feat: page size > 1 support #1224

Open

blueswhen wants to merge 1 commit into main from page_size_support

Conversation

@blueswhen
Collaborator

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive support for paged KV cache, significantly enhancing memory efficiency and potentially improving throughput for large language models. By abstracting KV cache memory into fixed-size pages, the system can better manage fragmented memory and reduce overhead, especially for variable-length sequences. This foundational change impacts attention mechanisms, memory allocation, and request management across the LightLLM framework.

Highlights

  • Paged KV Cache Support: Introduced comprehensive support for paged KV cache, enabling more efficient memory utilization and management for attention mechanisms.
  • Integrated Paged Attention Backends: Integrated new paged attention backends for both FlashAttention v3 (Fa3) and FlashInfer, including their Multi-Layer Attention (MLA) variants, which are conditionally selected based on the configured page size.
  • Memory Management Enhancements: Updated core memory management components, including KV cache buffers, request managers, and radix cache, to align memory allocations and operations with page-based structures.
  • Compatibility and Checks: Added compatibility checks and assertions to ensure that the paged memory feature (when PAGE_SIZE > 1) does not conflict with existing functionalities like MTP, RPyC PD split mode, DP prefill balance, and CPU cache.


Changelog
  • lightllm/common/basemodel/attention/create_utils.py
    • Imported get_page_size from envs_utils.
    • Conditionally selected paged attention backends (PagedFa3AttBackend, PagedFlashInferAttBackend, PagedMlaFa3AttBackend, PagedMlaFlashInferAttBackend) based on the PAGE_SIZE environment variable.
  • lightllm/common/basemodel/attention/fa3/fp.py
    • Corrected a typo in method name from _nomarl_prefill_att to _normal_prefill_att.
  • lightllm/common/basemodel/attention/flashinfer/fp.py
    • Corrected a typo in method name from _nomarl_prefill_att to _normal_prefill_att.
  • lightllm/common/basemodel/attention/paged_fa3/fp.py
    • Added PagedFa3AttBackend and its associated prefill/decode states to support paged attention using FlashAttention v3.
  • lightllm/common/basemodel/attention/paged_fa3/mla.py
    • Added PagedMlaFa3AttBackend and its associated prefill/decode states to support paged attention with MLA FlashAttention v3.
  • lightllm/common/basemodel/attention/paged_flashinfer/fp.py
    • Added PagedFlashInferAttBackend and its associated prefill/decode states to support paged attention using FlashInfer.
  • lightllm/common/basemodel/attention/paged_flashinfer/mla.py
    • Added PagedMlaFlashInferAttBackend and its associated prefill/decode states to support paged attention with MLA FlashInfer.
  • lightllm/common/basemodel/attention/triton/fp.py
    • Corrected a typo in method name from _nomarl_prefill_att to _normal_prefill_att.
  • lightllm/common/basemodel/triton_kernel/fa3_utils.py
    • Imported get_page_size.
    • Modified page_table_copy to handle page-aligned indexing when page_size > 1.
  • lightllm/common/basemodel/triton_kernel/repack_kv_index.py
    • Imported get_page_size.
    • Added _fwd_kernel_repack_page_kv_index_from_tokens and paged_repack_kv_index for managing paged KV cache indices.
  • lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py
    • Imported triton and get_page_size.
    • Adjusted _init_buffers to allocate KV buffer size based on page_size for page alignment.
  • lightllm/common/kv_cache_mem_manager/mem_manager.py
    • Imported triton and get_page_size.
    • Modified __init__ and _init_buffers to align memory allocation to page_size boundaries.
  • lightllm/common/req_manager.py
    • Imported triton and get_page_size.
    • Introduced new methods (calc_real_need_token_num, calc_last_mem_index_in_prefill, alloc_mem_indices, _expand_to_page_mem_indices, _expand_by_page_size, _alloc_paged_mem_indices, _get_need_paged_token_num) to manage paged memory allocation and deallocation for requests.
  • lightllm/server/api_start.py
    • Imported get_page_size.
    • Added compatibility assertions to prevent page_size > 1 from being used with MTP, RPyC PD split mode, DP prefill balance, and CPU cache.
  • lightllm/server/router/dynamic_prompt/paged_radix_cache.py
    • Added a new file implementing PagedRadixCache and TreeNode classes, designed to manage paged KV cache entries within a radix tree structure.
  • lightllm/server/router/model_infer/infer_batch.py
    • Added last_kv_mem_index attribute to InferReq.
    • Updated _match_radix_cache to store the last_kv_mem_index for paged cache handling.
  • lightllm/server/router/model_infer/mode_backend/base_backend.py
    • Imported PagedRadixCache.
    • Modified init_model to instantiate PagedRadixCache if page_size > 1, otherwise use RadixCache.
  • lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py
    • Updated padded_prepare_prefill_inputs and padded_prepare_decode_inputs to utilize the new paged memory allocation and last_kv_mem_index tracking from req_manager.
  • lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
    • Updated prepare_prefill_inputs and prepare_decode_inputs to utilize the new paged memory allocation and last_kv_mem_index tracking from req_manager.
  • lightllm/server/router/req_queue/chunked_prefill/impl.py
    • Imported get_page_size.
    • Adjusted the _can_add_new_req method's token capacity calculation to account for page_size overhead.
  • lightllm/utils/envs_utils.py
    • Added get_page_size function to retrieve the PAGE_SIZE environment variable, defaulting to 1.
  • unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py
    • Added test_paged_repack_kv_index to verify the correctness of paged KV index repack logic.
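
The get_page_size helper described in the changelog can be pictured as a thin wrapper over an environment lookup. This is an illustrative sketch, not LightLLM's actual envs_utils code, which may add caching or validation:

```python
import os

def get_page_size() -> int:
    # Sketch of the helper: read the PAGE_SIZE environment variable,
    # defaulting to 1 (i.e., the non-paged behavior) when unset.
    return int(os.getenv("PAGE_SIZE", "1"))
```

With this default, all the `page_size > 1` code paths stay disabled unless the operator opts in explicitly.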

@gemini-code-assist bot left a comment


Code Review

This pull request introduces paged attention support for the FlashAttention v3 (FA3) and FlashInfer backends, enhancing memory efficiency by managing the KV cache in pages. Key changes include the get_page_size utility, memory managers that align allocations with page sizes, and request managers that handle paged memory indices. New paged attention backend classes (PagedFa3AttBackend, PagedMlaFa3AttBackend, PagedFlashInferAttBackend, PagedMlaFlashInferAttBackend) are implemented, along with a PagedRadixCache for dynamic prompt caching. Compatibility checks were added to api_start.py to restrict paged attention usage with certain features like MTP, RPyC PD split mode, DP prefill balance, and CPU cache.

Review comments highlight several potential issues:
  • a memory leak in _expand_to_page_mem_indices due to incomplete freeing logic;
  • incorrect kv_last_page_len calculations in the paged FlashInfer backends when b_page_len is zero;
  • an inconsistency in the block_end_loc calculation in repack_kv_index that could lead to out-of-bounds access;
  • several instances of integer division or slicing that could cause incorrect tensor sizing, shape mismatches, or incorrect memory access patterns if input dimensions are not perfectly divisible or structured as assumed.

Note: Security Review is unavailable for this PR.

Comment on lines +121 to +125
base_indices = free_token_index[free_token_index % page_size == 0]
if len(base_indices) == 0:
    return free_token_index
page_offsets = torch.arange(page_size, dtype=base_indices.dtype, device=base_indices.device)
return (base_indices[:, None] + page_offsets[None, :]).reshape(-1)

critical

The logic in _expand_to_page_mem_indices for freeing tokens might be incomplete. If free_token_index contains indices that are part of a page but none of them are page-aligned (i.e., free_token_index % page_size == 0 is never true), then base_indices will be empty, and the function will return the original free_token_index. This means that partial pages might not be fully freed, potentially leading to memory leaks or an inconsistent state in the memory manager. The intention should be to free all tokens belonging to the pages that contain any of the free_token_index elements.
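
A fix in the spirit of this comment could map every freed token to the page that contains it and then release all tokens of those pages. The following is a minimal plain-Python sketch; the hypothetical expand_to_page_tokens stands in for the project's torch-based helper:

```python
def expand_to_page_tokens(free_token_index, page_size):
    # Map each freed token index to its containing page, then emit every
    # token index of those pages, so partial pages are fully released
    # instead of being leaked when no index happens to be page-aligned.
    pages = sorted({idx // page_size for idx in free_token_index})
    return [p * page_size + off for p in pages for off in range(page_size)]
```

Whether whole-page freeing is safe depends on the invariant that a page is never shared between live and freed tokens; the real allocator would need to guarantee that.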

q_starts = self.infer_state.b1_cu_q_seq_len.int()
kv_starts = self.infer_state.b1_cu_kv_seq_len.int()
b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size)
kv_starts[1:] = b_page_len.cumsum(0)

high

The calculation kv_last_page_len = self.infer_state.b_seq_len - (b_page_len - 1) * self.backend.page_size could be problematic if b_page_len is 0 for any sequence (e.g., an empty sequence). In such a case, b_page_len - 1 would be -1, leading to an incorrect kv_last_page_len. It's safer to handle the case where b_page_len is 0 explicitly, perhaps by clamping b_page_len - 1 to a minimum of 0.
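
The suggested clamp can be illustrated per sequence in plain Python. The function name and scalar signature here are hypothetical; the actual code operates on batched torch tensors:

```python
def kv_last_page_len(seq_len: int, page_size: int) -> int:
    # Number of pages via ceiling division (the scalar analogue of
    # triton.cdiv). Clamp the zero-page case so an empty sequence
    # yields 0 instead of seq_len + page_size.
    num_pages = -(-seq_len // page_size)
    if num_pages == 0:
        return 0
    return seq_len - (num_pages - 1) * page_size
```

In tensor form the equivalent guard would be something like clamping `b_page_len - 1` to a minimum of 0 before the multiply.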

page_data = token_data // page_size

offs_seq = start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len)

high

There's an inconsistency in how block_end_loc is calculated and used. On line 56, block_end_loc is calculated in terms of token length (* page_size), but on line 65, it's used for page-based indexing without * page_size, while cur_batch_seq_len remains token-based. This could lead to incorrect masking or out-of-bounds access when tl.store is called. cur_batch_seq_len should be converted to page length (e.g., triton.cdiv(cur_batch_seq_len, page_size)) when used for page-based indexing.

device = self.infer_state.input_ids.device
model = self.backend.model
b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size)
self.kv_last_page_len_buffer = self.infer_state.b_seq_len - (b_page_len - 1) * self.backend.page_size

high

This line has the same potential issue as line 51 regarding the calculation of kv_last_page_len_buffer when b_page_len might be 0. Ensure that b_page_len - 1 does not result in a negative value that could lead to an incorrect calculation of the last page length.

page_table_copy(
    page_table=self.page_table[:, :table_len],
    req_to_token_indexs=model.req_manager.req_to_token_indexs,
    b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)],

medium

The slicing self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)] implies a very specific structure and length for b_req_idx. If b_req_idx does not conform to this pattern (e.g., its length is not a multiple of args_mtp_step + 1), this slice might produce an unexpected result or an empty tensor, leading to incorrect page table copies.
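
The structural assumption behind that stride slice can be made concrete with hypothetical values (plain Python lists rather than the actual tensors):

```python
# With mtp_step = 2, b_req_idx is assumed to hold groups of
# (mtp_step + 1) entries, and the slice picks the last entry of
# each group. A length that is not a multiple of mtp_step + 1
# would silently drop or misalign the trailing partial group.
mtp_step = 2
b_req_idx = list(range(9))                      # length 9 = 3 groups of 3
selected = b_req_idx[mtp_step :: mtp_step + 1]  # one index per group
```

An assertion on `len(b_req_idx) % (mtp_step + 1) == 0` at the call site would turn the silent misalignment into a loud failure.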

args_mtp_step = get_env_start_args().mtp_step
if args_mtp_step > 0:
    mtp_size = args_mtp_step + 1
    b_q_seq_len = torch.full(

medium

Similar to paged_fa3/fp.py, the integer division self.infer_state.b_seq_len.shape[0] // mtp_size could lead to an incorrectly sized tensor if self.infer_state.b_seq_len.shape[0] is not perfectly divisible by mtp_size. This can cause downstream issues with tensor operations.

if args_mtp_step > 0:
    mtp_size = args_mtp_step + 1
    b_q_seq_len = torch.full(
        (self.infer_state.b_seq_len.shape[0] // mtp_size,),

medium

The integer division self.infer_state.b_seq_len.shape[0] // mtp_size might lead to incorrect sizing if self.infer_state.b_seq_len.shape[0] is not perfectly divisible by mtp_size. This could result in an array of an unexpected size. Consider using triton.cdiv or adding an assertion to ensure exact divisibility if that's the expectation.
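
The divisibility guard the reviewer suggests could look like the sketch below. The function name is hypothetical; it only illustrates asserting exact divisibility before sizing a tensor:

```python
def att_batch_size(batch_size: int, mtp_step: int) -> int:
    # Guard the integer division: a batch size that is not a multiple
    # of (mtp_step + 1) should fail loudly instead of silently
    # truncating and producing a mis-sized tensor downstream.
    mtp_size = mtp_step + 1
    assert batch_size % mtp_size == 0, (
        f"batch_size {batch_size} is not a multiple of mtp_step + 1 = {mtp_size}"
    )
    return batch_size // mtp_size
```
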

self.kv_indices = torch.empty(buffer_len, dtype=torch.int32, device=device)

b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size)
self.kv_starts[1:] = b_page_len.cumsum(0)

medium

Similar to paged_flashinfer/fp.py, modifying self.kv_starts in place with self.kv_starts[1:] = b_page_len.cumsum(0) could have unintended side effects if the original self.infer_state.b1_cu_kv_seq_len is expected to be immutable or is used elsewhere in its initial state.
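
The aliasing hazard behind this comment can be demonstrated with plain Python lists (the real code deals with torch tensors, where `.int()` may or may not copy depending on the dtype):

```python
# If kv_starts aliases the original cumulative-length buffer, the
# in-place write also mutates that buffer.
orig = [0, 3, 6, 9]
kv_starts = orig               # hypothetical alias: no copy taken
kv_starts[1:] = [2, 4, 6]      # orig is now [0, 2, 4, 6] too

# Taking an explicit copy first confines the mutation.
orig2 = [0, 3, 6, 9]
kv_starts2 = list(orig2)       # explicit copy
kv_starts2[1:] = [2, 4, 6]     # orig2 is unchanged
```

In torch terms the safe pattern would be an explicit `.clone()` before the in-place slice assignment.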

Comment on lines +128 to +130
self.page_table = page_buffer[self.infer_state.microbatch_index][
: att_batch_size * shared_table_len
].reshape(att_batch_size, shared_table_len)

medium

The slicing page_buffer[self.infer_state.microbatch_index][: att_batch_size * shared_table_len] assumes that att_batch_size * shared_table_len is always less than or equal to the actual length of the buffer for the given microbatch_index. If att_batch_size is miscalculated due to integer division (as noted in line 118), this could lead to incorrect slicing or out-of-bounds access. It's safer to ensure the calculated length is within bounds or use min to prevent over-indexing.

self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()

att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)

medium

Similar to line 105, the integer division self.infer_state.batch_size // (args_mtp_step + 1) could cause issues if self.infer_state.batch_size is not perfectly divisible by args_mtp_step + 1. This might lead to an incorrect att_batch_size which could propagate errors in subsequent operations like reshaping or slicing.
