Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@
python scripts/test_perf.py --verbose
```

- 单请求推理服务测试
```bash
python test/service/request.py --content="text:Image 1:" --content="image_url:xxx.jpg" --content="text:Image 2:" --content="image_url:xxxx.jpg" --content="text:Compare the 2 images."
```

- 运行推理基准测试(C-Eval/MMLU)

```bash
Expand Down
174 changes: 126 additions & 48 deletions python/infinilm/llm/cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

Expand Down Expand Up @@ -45,9 +46,9 @@ class BlockManager:
"""

def __init__(self, num_blocks: int, block_size: int):
assert (
num_blocks > 0 and block_size > 0
), "num_blocks and block_size must be positive"
assert num_blocks > 0 and block_size > 0, (
"num_blocks and block_size must be positive"
)
self.num_blocks = num_blocks
self.block_size = block_size

Expand All @@ -67,12 +68,20 @@ def reset_req_blocks(self) -> None:
self.req_block_ids.clear()

@classmethod
def compute_hash(
    cls,
    token_ids: List[int],
    prefix_hash: int = -1,
    mm_data_identifiers: Optional[List[str]] = None,
) -> int:
    """Compute a 64-bit hash for a token block, with optional prefix chaining.

    Args:
        token_ids: Token ids belonging to this block.
        prefix_hash: Hash of the preceding block, or -1 (sentinel) for the
            first block. Chaining makes identical token blocks that follow
            different prefixes hash differently.
        mm_data_identifiers: Optional identifiers (e.g. media URLs) of
            multimodal inputs that start inside this block. Mixing them in
            keeps blocks with identical placeholder tokens but different
            media from colliding in the prefix cache.

    Returns:
        The xxh64 digest of the combined inputs as an int.
    """
    h = xxhash.xxh64()
    if prefix_hash != -1:
        h.update(prefix_hash.to_bytes(8, "little"))
    # int32 is assumed sufficient for vocabulary ids; tobytes() gives a
    # deterministic little-endian byte layout for hashing.
    h.update(np.array(token_ids, dtype=np.int32).tobytes())
    if mm_data_identifiers is not None:
        for identifier in mm_data_identifiers:
            h.update(identifier.encode("utf-8"))
    return h.intdigest()

def _allocate_partial_block(self, block_id: int) -> Block:
Expand Down Expand Up @@ -100,9 +109,9 @@ def _allocate_full_block(self, block_id: int) -> Block:
def _deallocate_block(self, block_id: int):
"""Deallocate a block and return it to free list."""
block = self.blocks[block_id]
assert (
block.ref_count == 0
), f"Block {block_id} ref_count not zero, cannot deallocate"
assert block.ref_count == 0, (
f"Block {block_id} ref_count not zero, cannot deallocate"
)

if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id:
del self.hash_to_block_id[block.hash]
Expand All @@ -115,73 +124,142 @@ def can_allocate(self, num_required_blocks: int) -> bool:
return len(self.free_block_ids) >= num_required_blocks

def allocate_blocks(
self, token_ids: List[int], block_table: List[int] = None
self,
token_ids: List[int],
block_table: List[int] = None,
mm_token_index_mappings: List[dict] = None,
) -> tuple[List[int], List[int], int]:
"""Allocate cache blocks for new request with prefix caching support.

Args:
token_ids: Input token sequence
block_table: Existing block_table (for decode phase)

mm_token_index_mappings: List of multimodal token index mappings
Returns:
Tuple of (block_table, slot_mapping, num_cached_tokens)
"""
if block_table is None:
block_table = []

# Static args
num_tokens = len(token_ids)
num_blocks = (num_tokens + self.block_size - 1) // self.block_size
num_full_blocks = num_tokens // self.block_size
remain_tokens = num_tokens % self.block_size
num_mm_inputs = (
0 if not mm_token_index_mappings else len(mm_token_index_mappings)
)

# Variables
slot_mapping = []
num_cached_tokens = 0
prefix_hash = -1
cache_miss = False
mm_start_counter = 0
mm_caching_queue = queue.Queue(maxsize=len(mm_token_index_mappings))
blocks_blueprint = [] # [{"prefix_hash": int or -1 if not a full block, "block_id": int or -1 if not cached}, ...]
max_blocks_to_reuse = num_full_blocks

for block_idx in range(num_blocks):
start_idx = block_idx * self.block_size
end_idx = min(start_idx + self.block_size, num_tokens)
block_tokens = token_ids[start_idx:end_idx]

# Only full blocks can be hashed for reuse
if len(block_tokens) == self.block_size:
prefix_hash = self.compute_hash(block_tokens, prefix_hash)

# Try to reuse existing block
if not cache_miss:
cached_block_id = self.hash_to_block_id.get(prefix_hash, -1)
if (
cached_block_id != -1
and self.blocks[cached_block_id].token_ids == block_tokens
):
# Check if all tokens are cached
if num_cached_tokens + self.block_size == len(token_ids):
cache_miss = True
else:
# Reuse successful
block = self.blocks[cached_block_id]
block.ref_count += 1
block_table.append(cached_block_id)
num_cached_tokens += self.block_size
continue
else:
cache_miss = True
else:
prefix_hash = -1

# Cannot reuse, allocate new block
if not self.free_block_ids:
raise RuntimeError("No available cache blocks")
# Process multimodal token index mappings for this block
mm_data_identifiers = []
while (
mm_start_counter < num_mm_inputs
and mm_token_index_mappings[mm_start_counter]["start_index"] < end_idx
and mm_token_index_mappings[mm_start_counter]["start_index"]
>= start_idx
):
# for all mm_data whose start_index is within this block's token range, add its identifier to the list
mm_data_identifiers.append(
mm_token_index_mappings[mm_start_counter]["identifier"]
)
mm_caching_queue.put((mm_start_counter))
mm_start_counter += 1

prefix_hash = (
self.compute_hash(block_tokens, prefix_hash, mm_data_identifiers)
if len(block_tokens) == self.block_size
else -1
)

# Try to reuse existing block if no previous cache miss yet
cached_block_id = (
self.hash_to_block_id.get(prefix_hash, -1) if not cache_miss else -1
)
if (
cached_block_id != -1
and self.blocks[cached_block_id].token_ids != block_tokens
):
cached_block_id = -1
if end_idx == num_tokens and remain_tokens == 0:
# Spicial case, when the last block is fully packed, we cannot reuse it because we need to leave at least one uncached token for forward
cached_block_id = -1

# Deal with the first cache miss
if not cache_miss and cached_block_id == -1:
max_blocks_to_reuse = min(max_blocks_to_reuse, block_idx)
cache_miss = True

if not cache_miss:
# pop fully cached mm_data
while (
not mm_caching_queue.empty()
and mm_token_index_mappings[mm_caching_queue.queue[0]]["end_index"]
< end_idx
):
mm_caching_queue.get()

blocks_blueprint.append(
{"prefix_hash": prefix_hash, "block_id": cached_block_id}
)

# If there is one incomplete mm_data, tailing blocks need to fall back until all included mm_data are complete
if not mm_caching_queue.empty():
incomplete_mm = mm_token_index_mappings[mm_caching_queue.get()]
incomplete_mm_start = incomplete_mm[
"start_index"
] # Fall back until this index is no longer included in the block
max_blocks_to_reuse = min(
max_blocks_to_reuse, incomplete_mm_start // self.block_size
)

num_cached_tokens = max_blocks_to_reuse * self.block_size

for block_id in range(num_blocks):
n_block_tokens = self.block_size

if block_id < max_blocks_to_reuse:
# Reuse block
block = self.blocks[blocks_blueprint[block_id]["block_id"]]
block.ref_count += 1

new_block_id = self.free_block_ids[0]
if prefix_hash != -1:
block = self._allocate_full_block(new_block_id)
block.update(prefix_hash, block_tokens)
else:
block = self._allocate_partial_block(new_block_id)
block_table.append(new_block_id)

# Generate slot_mapping
for i in range(len(block_tokens)):
slot_mapping.append(new_block_id * self.block_size + i)
new_block_id = self.free_block_ids[0]
if blocks_blueprint[block_id]["prefix_hash"] != -1:
start_idx = block_id * self.block_size
end_idx = start_idx + self.block_size
block_tokens = token_ids[start_idx:end_idx]
block = self._allocate_full_block(new_block_id)
block.update(
blocks_blueprint[block_id]["prefix_hash"], block_tokens
)
else:
block = self._allocate_partial_block(new_block_id)
n_block_tokens = remain_tokens
slot_mapping.extend(
list(
range(
block.block_id * self.block_size,
block.block_id * self.block_size + n_block_tokens,
)
)
)

block_table.append(block.block_id)

return block_table, slot_mapping, num_cached_tokens

Expand Down
18 changes: 15 additions & 3 deletions python/infinilm/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def add_request(
if request_id is None:
request_id = f"cmpl-{uuid.uuid4().hex}"

images, videos, audios = None, None, None
mm_index_mappings = None
processed_inputs = None

if prompt_token_ids is not None:
Expand All @@ -708,12 +708,23 @@ def add_request(
messages, add_generation_prompt=add_generation_prompt
)

images, videos, audios = resolve_multimodal_inputs(messages)
mm_inputs = resolve_multimodal_inputs(messages)

processed_inputs = self.engine.process(
prompt, images, videos, audios, return_tensors="pt"
prompt,
mm_inputs["images"],
mm_inputs["videos"],
mm_inputs["audios"],
return_tensors="pt",
)

prompt_token_ids = processed_inputs.get("input_ids").flatten().tolist()
mm_index_mappings = self.engine.processor.get_mm_token_index_list(
prompt_token_ids,
image_ids=mm_inputs["image_urls"],
video_ids=mm_inputs["video_urls"],
audio_ids=mm_inputs["audio_urls"],
)

if sampling_params is None:
sampling_params = SamplingParams(max_tokens=self.config.max_tokens)
Expand All @@ -726,6 +737,7 @@ def add_request(
prompt=prompt,
prompt_token_ids=prompt_token_ids,
processed_inputs=processed_inputs,
mm_token_index_mappings=mm_index_mappings,
sampling_params=sampling_params,
eos_token_ids=self.engine.eos_token_ids,
request_data=request_data,
Expand Down
5 changes: 5 additions & 0 deletions python/infinilm/llm/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
prompt: Optional[str] = None,
prompt_token_ids: Optional[List[int]] = None,
processed_inputs: Optional[dict] = None,
mm_token_index_mappings: Optional[List[dict]] = None,
sampling_params: Optional[SamplingParams] = None,
eos_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
Expand All @@ -122,6 +123,7 @@ def __init__(
self.prompt_token_ids: List[int] = prompt_token_ids or []
self.prompt_length: int = len(self.prompt_token_ids)
self.processed_inputs: Optional[dict] = processed_inputs
self.mm_token_index_mappings: Optional[List[dict]] = mm_token_index_mappings

# Sampling parameters
self.sampling_params: SamplingParams = sampling_params or SamplingParams()
Expand Down Expand Up @@ -186,6 +188,9 @@ def get_num_blocks_required(self, block_size: int) -> int:
def get_max_tokens(self) -> Optional[int]:
return self.sampling_params.max_tokens

def get_mm_token_index_mappings(self) -> Optional[List[dict]]:
    """Return the request's multimodal token index mappings, or None.

    Each mapping dict is presumably of the form produced by
    ``get_mm_token_index_list`` (start/end token indices plus an
    identifier) — confirm against the processor. Used by the cache
    manager for multimodal-aware prefix caching.
    """
    return self.mm_token_index_mappings

def is_finished(self) -> bool:
return self.status in [
RequestStatus.FINISHED,
Expand Down
4 changes: 3 additions & 1 deletion python/infinilm/llm/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def schedule(self) -> Optional[SchedulerOutput]:

# Allocate blocks with automatic prefix caching support
req.block_table, req.slot_mapping, req.num_cached_tokens = (
self.cache_manager.allocate_blocks(req_tokens, req.block_table)
self.cache_manager.allocate_blocks(
req_tokens, req.block_table, req.get_mm_token_index_mappings()
)
)

req.num_blocks = len(req.block_table)
Expand Down
37 changes: 33 additions & 4 deletions python/infinilm/multimodal/multimodal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
from typing import List, Union
from typing import List, Optional, Union
from PIL import Image
import xxhash


def has_multimodal_inputs(messages: Union[List[dict], dict]) -> bool:
    """Return True if any message contains a multimodal content part.

    Args:
        messages: A single message dict or a list of message dicts. Each
            message's "content" may be a plain string (text-only) or a list
            of typed content parts such as {"type": "image_url", ...}.

    Returns:
        True if at least one content part has type "image_url", "video_url"
        or "audio_url"; False otherwise.
    """
    if isinstance(messages, dict):
        messages = [messages]

    multimodal_types = ("image_url", "video_url", "audio_url")
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            # Plain-string content is text-only for this message; keep
            # scanning the remaining messages. (Bug fix: the previous
            # early `return False` aborted the whole scan, missing
            # multimodal parts in later messages.)
            continue

        for item in content:
            if item.get("type") in multimodal_types:
                return True

    return False


def resolve_multimodal_inputs(messages: Union[List[dict], dict]):
Expand All @@ -8,8 +26,11 @@ def resolve_multimodal_inputs(messages: Union[List[dict], dict]):
messages = [messages]

images = []
image_urls = []
videos = []
video_urls = []
audios = []
audio_urls = []

for msg in messages:
content = msg.get("content", [])
Expand All @@ -19,11 +40,19 @@ def resolve_multimodal_inputs(messages: Union[List[dict], dict]):
for item in content:
if item.get("type") == "text":
pass
elif item.get("type") == "image":
elif item.get("type") == "image_url":
# TODO support other image url formats
images.append(Image.open(item["image_url"]))
images.append(Image.open(item["image_url"]["url"]))
image_urls.append(item["image_url"]["url"])

else: # TODO support video/audio
raise NotImplementedError("Only image input is supported for now")

return images, videos, audios
return {
"images": images,
"image_urls": image_urls,
"videos": videos,
"video_urls": video_urls,
"audios": audios,
"audio_urls": audio_urls,
}
Loading