Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces comprehensive support for a paged KV cache, significantly enhancing memory efficiency and potentially improving throughput for large language models. By abstracting KV cache memory into fixed-size pages, the system can better manage fragmented memory and reduce overhead, especially for variable-length sequences. This foundational change impacts attention mechanisms, memory allocation, and request management across the LightLLM framework.
Code Review
This pull request introduces paged attention support for the FlashAttention v3 (FA3) and FlashInfer backends, enhancing memory efficiency by managing the KV cache in fixed-size pages. Key changes include a new get_page_size utility, memory managers that align allocations with page sizes, and request managers that handle paged memory indices. New paged attention backend classes (PagedFa3AttBackend, PagedMlaFa3AttBackend, PagedFlashInferAttBackend, PagedMlaFlashInferAttBackend) are implemented, along with a PagedRadixCache for dynamic prompt caching. Compatibility checks in api_start.py restrict combining paged attention with certain features such as MTP, RPyC PD split mode, DP prefill balance, and CPU cache.

The review comments below highlight several potential issues:

- a memory leak in _expand_to_page_mem_indices due to incomplete freeing logic;
- incorrect kv_last_page_len calculations in the paged FlashInfer backends when b_page_len is zero;
- an inconsistency in the block_end_loc calculation in repack_kv_index that can lead to out-of-bounds access;
- several integer divisions and slices that can produce incorrect tensor sizes, shape mismatches, or wrong memory access patterns when input dimensions are not exactly divisible or structured as assumed.
Note: Security Review is unavailable for this PR.
```python
base_indices = free_token_index[free_token_index % page_size == 0]
if len(base_indices) == 0:
    return free_token_index
page_offsets = torch.arange(page_size, dtype=base_indices.dtype, device=base_indices.device)
return (base_indices[:, None] + page_offsets[None, :]).reshape(-1)
```
The logic in _expand_to_page_mem_indices for freeing tokens might be incomplete. If free_token_index contains indices that are part of a page but none of them are page-aligned (i.e., free_token_index % page_size == 0 is never true), then base_indices will be empty, and the function will return the original free_token_index. This means that partial pages might not be fully freed, potentially leading to memory leaks or an inconsistent state in the memory manager. The intention should be to free all tokens belonging to the pages that contain any of the free_token_index elements.
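The suggested fix can be sketched in plain Python (a hypothetical, list-based stand-in for `_expand_to_page_mem_indices`; the real code operates on torch tensors): collect every page touched by a freed token, then emit all token slots of those pages.

```python
def expand_to_page_mem_indices(free_token_index, page_size):
    # Hypothetical pure-Python sketch of the reviewer's suggestion; the real
    # implementation works on torch tensors.
    # Collect every page that contains at least one freed token, then free
    # all token slots of those pages, so partial pages are not leaked.
    pages = {idx // page_size for idx in free_token_index}
    return sorted(p * page_size + off for p in pages for off in range(page_size))
```

Note that this assumes no still-live token shares a page with a freed one; otherwise whole-page freeing would reclaim live slots, which is why the allocator must hand out pages atomically.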
```python
q_starts = self.infer_state.b1_cu_q_seq_len.int()
kv_starts = self.infer_state.b1_cu_kv_seq_len.int()
b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size)
kv_starts[1:] = b_page_len.cumsum(0)
```
The calculation kv_last_page_len = self.infer_state.b_seq_len - (b_page_len - 1) * self.backend.page_size could be problematic if b_page_len is 0 for any sequence (e.g., an empty sequence). In such a case, b_page_len - 1 would be -1, leading to an incorrect kv_last_page_len. It's safer to handle the case where b_page_len is 0 explicitly, perhaps by clamping b_page_len - 1 to a minimum of 0.
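The clamped computation can be illustrated with a small scalar sketch (a hypothetical helper; the real code is vectorized over a torch tensor of sequence lengths):

```python
import math

def kv_last_page_len(seq_len, page_size):
    # Hypothetical sketch of the clamped computation suggested above.
    # Without the clamp, an empty sequence (page_len == 0) would yield
    # seq_len - (-1) * page_size == page_size instead of 0.
    page_len = math.ceil(seq_len / page_size)
    return seq_len - max(page_len - 1, 0) * page_size
```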
```python
page_data = token_data // page_size
# ...
offs_seq = start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len)
```
There's an inconsistency in how block_end_loc is calculated and used. On line 56, block_end_loc is calculated in terms of token length (* page_size), but on line 65, it's used for page-based indexing without * page_size, while cur_batch_seq_len remains token-based. This could lead to incorrect masking or out-of-bounds access when tl.store is called. cur_batch_seq_len should be converted to page length (e.g., triton.cdiv(cur_batch_seq_len, page_size)) when used for page-based indexing.
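The unit conversion can be sketched as a scalar helper (hypothetical names; the real code is a Triton kernel operating on blocks): when the block offsets index pages, the upper bound must also be expressed in pages.

```python
import math

def page_block_end(start_seq_n, seq_block, cur_batch_seq_len, page_size):
    # Hypothetical sketch: cur_batch_seq_len is a token count, so it is
    # ceil-divided by page_size before being used as a page-index bound.
    cur_batch_page_len = math.ceil(cur_batch_seq_len / page_size)
    return min((start_seq_n + 1) * seq_block, cur_batch_page_len)
```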
```python
device = self.infer_state.input_ids.device
model = self.backend.model
b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size)
self.kv_last_page_len_buffer = self.infer_state.b_seq_len - (b_page_len - 1) * self.backend.page_size
```
```python
page_table_copy(
    page_table=self.page_table[:, :table_len],
    req_to_token_indexs=model.req_manager.req_to_token_indexs,
    b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)],
    # ...
```
The slicing self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)] implies a very specific structure and length for b_req_idx. If b_req_idx does not conform to this pattern (e.g., its length is not a multiple of args_mtp_step + 1), this slice might produce an unexpected result or an empty tensor, leading to incorrect page table copies.
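A defensive version of the strided slice can be sketched as follows (a hypothetical guard, shown here on plain lists; the real code slices a torch tensor):

```python
def mtp_req_indices(b_req_idx, args_mtp_step):
    # Hypothetical guard for the strided slice flagged above: the pattern
    # b_req_idx[step :: step + 1] only selects one entry per request group
    # when the length is an exact multiple of (args_mtp_step + 1).
    group = args_mtp_step + 1
    assert len(b_req_idx) % group == 0, (
        f"b_req_idx length {len(b_req_idx)} is not a multiple of {group}"
    )
    return b_req_idx[args_mtp_step::group]
```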
```python
args_mtp_step = get_env_start_args().mtp_step
if args_mtp_step > 0:
    mtp_size = args_mtp_step + 1
    b_q_seq_len = torch.full(
        # ...
```
```python
if args_mtp_step > 0:
    mtp_size = args_mtp_step + 1
    b_q_seq_len = torch.full(
        (self.infer_state.b_seq_len.shape[0] // mtp_size,),
        # ...
```
The integer division self.infer_state.b_seq_len.shape[0] // mtp_size might lead to incorrect sizing if self.infer_state.b_seq_len.shape[0] is not perfectly divisible by mtp_size. This could result in an array of an unexpected size. Consider using triton.cdiv or adding an assertion to ensure exact divisibility if that's the expectation.
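The suggested divisibility check can be sketched as a small helper (hypothetical name; the real code computes this inline):

```python
def att_batch_size(total_batch, mtp_size):
    # Hypothetical helper illustrating the check suggested above: fail
    # loudly rather than silently truncating via floor division.
    assert total_batch % mtp_size == 0, (
        f"batch size {total_batch} is not divisible by mtp_size {mtp_size}"
    )
    return total_batch // mtp_size
```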
```python
self.kv_indices = torch.empty(buffer_len, dtype=torch.int32, device=device)
# ...
b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size)
self.kv_starts[1:] = b_page_len.cumsum(0)
```
```python
self.page_table = page_buffer[self.infer_state.microbatch_index][
    : att_batch_size * shared_table_len
].reshape(att_batch_size, shared_table_len)
```
The slicing page_buffer[self.infer_state.microbatch_index][: att_batch_size * shared_table_len] assumes that att_batch_size * shared_table_len is always less than or equal to the actual length of the buffer for the given microbatch_index. If att_batch_size is miscalculated due to integer division (as noted in line 118), this could lead to incorrect slicing or out-of-bounds access. It's safer to ensure the calculated length is within bounds or use min to prevent over-indexing.
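The bounds check can be sketched with a list-based stand-in for the flat buffer (hypothetical helper; the real code slices and reshapes a torch tensor):

```python
def slice_page_table(page_buffer, att_batch_size, shared_table_len):
    # Hypothetical sketch of the bounds check suggested above: verify the
    # flat buffer is large enough before slicing, then reshape it into
    # att_batch_size rows of shared_table_len entries each.
    need = att_batch_size * shared_table_len
    assert need <= len(page_buffer), (
        f"page buffer too small: need {need}, have {len(page_buffer)}"
    )
    flat = page_buffer[:need]
    return [flat[i * shared_table_len:(i + 1) * shared_table_len]
            for i in range(att_batch_size)]
```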
```python
self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()
# ...
att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
```
Similar to line 105, the integer division self.infer_state.batch_size // (args_mtp_step + 1) could cause issues if self.infer_state.batch_size is not perfectly divisible by args_mtp_step + 1. This might lead to an incorrect att_batch_size which could propagate errors in subsequent operations like reshaping or slicing.