[Relax][Frontend][KVCache] Add masked sequence prefill helper for encoder valid lengths #19392
Open

xthomaswang wants to merge 1 commit into apache:main
Conversation
Contributor
Code Review
This pull request introduces a new masked sequence prefill kernel, _attention_sequence_prefill_with_mask, designed for batched encoder-style inputs with right-padding. It includes a corresponding test suite covering various edge cases like zero-length sequences and grouped-query attention. Feedback includes removing an unused causal parameter, cleaning up unnecessary pylint suppressions for the lse buffer, and adding explicit bounds checks for kv_len during K and V loading to ensure robustness.
This adds a Relax frontend helper for sequence prefill with per-batch valid lengths on right-padded encoder inputs.

The existing sequence prefill helper assumes that every position in [0, seq_len) is valid for every sample in the batch. That assumption does not hold for padded encoder batches, where each sample may carry a different valid length. The new helper adds the masked variant needed for this case while reusing the existing prefill kernel configuration and schedule.

The focused correctness test compares the kernel against a NumPy fp32 reference on valid rows only and covers zero-length, full-length, mixed valid-length, and grouped-query cases. The test uses TVM GPU target parametrization for CUDA and Metal.

This change is intentionally correctness-only. It does not introduce new tuning knobs, target-specific schedule changes, or performance claims.

Tests:
- python -m pytest -v tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
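The NumPy fp32 reference check described above can be sketched as follows. This is an illustrative reference implementation, not code from the PR; the function name `masked_attention_ref` and the shapes are assumptions. Per batch element, only the first `valid_len` positions participate as both queries and keys, matching right-padded encoder semantics:

```python
import numpy as np

def masked_attention_ref(q, k, v, valid_lens):
    """Hypothetical fp32 reference for masked sequence prefill.

    q, k, v: (batch, heads, seq_len, head_dim); valid_lens: (batch,).
    Rows at or beyond valid_len are left as zeros, since the kernel
    makes no guarantee about padded positions.
    """
    b, h, s, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(q, dtype=np.float32)
    for i in range(b):
        n = int(valid_lens[i])
        if n == 0:
            continue  # zero-length sequence: no valid rows to compute
        # Attention restricted to the valid prefix of this sample.
        scores = np.einsum("hqd,hkd->hqk", q[i, :, :n], k[i, :, :n]) * scale
        scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
        probs = np.exp(scores)
        probs /= probs.sum(axis=-1, keepdims=True)
        out[i, :, :n] = np.einsum("hqk,hkd->hqd", probs, v[i, :, :n])
    return out
```

A test would then compare the kernel output against this reference only on rows below each sample's valid length.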
Force-pushed from 4e09c7d to 459e1ff
Author

Addressed the suggestions from gemini.
Summary
Adds _attention_sequence_prefill_with_mask in python/tvm/relax/frontend/nn/llm/kv_cache.py — a masked variant of the existing sequence prefill kernel that supports right-padded encoder batches with per-sample valid_lens.

The existing _attention_sequence_prefill assumes all positions in [0, seq_len) are valid, which breaks for padded encoder inputs where each batch element has a different valid prefix length. This helper adds the masking semantics needed for correctness:

- takes a per-sample valid_lens input
- excludes padded (row, col) pairs from the online softmax update

It reuses the existing prefill kernel config and schedule — no new tuning knobs, no target-specific changes, no performance claims. Correctness only.
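To make the masking semantics concrete, here is a scalar sketch of one query row of an online-softmax update that skips padded columns. This is not the TIR kernel from the PR; the function name and layout are hypothetical, and the point is only to show where a `valid_len` check excludes (row, col) pairs from the running max/normalizer:

```python
import numpy as np

def online_softmax_masked_row(scores_row, values, valid_len):
    """One query row of a masked online-softmax update (illustrative).

    scores_row: (seq_len,) attention scores for this row.
    values:     (seq_len, head_dim) value vectors.
    valid_len:  number of valid (non-padded) key positions.
    """
    m = -np.inf                                   # running max
    l = 0.0                                       # running normalizer
    acc = np.zeros(values.shape[1], dtype=np.float32)
    for col in range(len(scores_row)):
        if col >= valid_len:
            continue                              # padded column: skip the update
        m_new = max(m, scores_row[col])
        correction = np.exp(m - m_new) if l > 0 else 0.0
        w = np.exp(scores_row[col] - m_new)
        l = l * correction + w                    # rescale old sum, add new term
        acc = acc * correction + w * values[col]  # rescale old accumulator too
        m = m_new
    return acc / l if l > 0 else acc              # valid_len == 0: all zeros
```

On valid rows this yields the same result as masking padded scores to -inf before a full softmax, but without ever touching padded columns.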
Motivation: encoder batch prefill for downstream consumers
This is the TVM-side primitive needed to support encoder batch prefill in downstream projects like
mlc-llm, where padded encoder batches withvalid_lensneed to be lowered without materializing an explicit broadcast attention mask on the host.The helper is generic and useful for any encoder-style sequence prefill consumer with per-sample valid lengths.
Tests
tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py covers:

- valid_len == 0 (zero-length sequences)
- valid_len == seq_len (full-length sequences)
- mixed valid lengths across the batch
- grouped-query attention (h_q > h_kv)

Run with:

python -m pytest -v tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
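Since the test compares against the reference on valid rows only, the comparison helper might look like the sketch below. The name `assert_valid_rows_close` and the tolerances are assumptions, not taken from the test file; the point is that padded rows are deliberately excluded from the check:

```python
import numpy as np

def assert_valid_rows_close(out, ref, valid_lens, rtol=1e-3, atol=1e-3):
    """Compare out vs ref only on rows [0, valid_len) per batch element.

    out, ref: (batch, heads, seq_len, head_dim). Rows at or beyond the
    valid length are ignored, because the kernel makes no guarantee
    about the contents of padded positions.
    """
    for i, n in enumerate(valid_lens):
        np.testing.assert_allclose(out[i, :, :n], ref[i, :, :n],
                                   rtol=rtol, atol=atol)
```

This keeps the test robust even if a future schedule change writes different garbage into padded rows.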