[Relax][Frontend][KVCache] Add masked sequence prefill helper for encoder valid lengths #19392

Open

xthomaswang wants to merge 1 commit into apache:main from xthomaswang:relax/masked-sequence-prefill

Conversation

@xthomaswang

Summary

Adds _attention_sequence_prefill_with_mask in python/tvm/relax/frontend/nn/llm/kv_cache.py — a masked variant of the existing sequence prefill kernel that supports right-padded encoder batches with per-sample valid_lens.

The existing _attention_sequence_prefill assumes all positions in [0, seq_len) are valid, which breaks for padded encoder inputs where each batch element has a different valid prefix length. This helper adds the masking semantics needed for correctness:

  • accepts a per-batch valid_lens input
  • ignores padded query rows and padded key/value positions
  • excludes padded (row, col) pairs from the online softmax update

It reuses the existing prefill kernel config and schedule — no new tuning knobs, no target-specific changes, no performance claims. Correctness only.
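The masking semantics described above can be sketched as a per-row online softmax in plain NumPy. This is an illustrative reference only, not the TIR kernel in kv_cache.py; the function name and single-sequence shape are hypothetical simplifications:

```python
import numpy as np

def masked_online_softmax_attention(q, k, v, valid_len):
    """Online-softmax attention for one right-padded sequence.

    q, k, v: (seq_len, d) arrays; positions >= valid_len are padding.
    Padded query rows are skipped entirely, and padded key/value
    columns never enter the running max/sum update, mirroring the
    masked kernel's exclusion of padded (row, col) pairs.
    Illustrative sketch only.
    """
    seq_len, d = q.shape
    out = np.zeros_like(q, dtype=np.float32)
    for row in range(seq_len):
        if row >= valid_len:              # padded query row: ignored
            continue
        m = -np.inf                       # running max of scores
        s = 0.0                           # running softmax denominator
        acc = np.zeros(d, dtype=np.float32)
        for col in range(valid_len):      # padded key cols excluded
            score = float(q[row] @ k[col]) / np.sqrt(d)
            m_new = max(m, score)
            scale = np.exp(m - m_new) if m != -np.inf else 0.0
            p = np.exp(score - m_new)
            s = s * scale + p             # rescale old sum, add new term
            acc = acc * scale + p * v[col].astype(np.float32)
            m = m_new
        out[row] = acc / s
    return out
```

For valid_len == 0 the inner work is skipped entirely and the output rows stay zero, which is the degenerate case the tests exercise.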

Motivation: encoder batch prefill for downstream consumers

This is the TVM-side primitive needed to support encoder batch prefill in downstream projects like mlc-llm, where padded encoder batches with valid_lens need to be lowered without materializing an explicit broadcast attention mask on the host.

The helper is generic and useful for any encoder-style sequence prefill consumer with per-sample valid lengths.

Tests

tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py covers:

  • valid_len == 0
  • valid_len == seq_len
  • mixed valid lengths within a batch
  • grouped-query attention (h_q > h_kv)

Run with:

python -m pytest -v tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
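A NumPy fp32 reference for the batched, grouped-query case in these tests might look like the following. The function name and shapes are illustrative, not the actual test code; each query head is mapped to its shared KV head via integer division, and output rows at or beyond each sample's valid length are left as zeros so the comparison covers valid rows only:

```python
import numpy as np

def ref_masked_attention(q, k, v, valid_lens):
    """Reference attention for right-padded batches with GQA.

    q: (b, h_q, s, d); k, v: (b, h_kv, s, d) with h_q % h_kv == 0.
    valid_lens[i] is sample i's valid prefix length; rows at
    positions >= valid_lens[i] stay zero. Illustrative sketch.
    """
    b, h_q, s, d = q.shape
    h_kv = k.shape[1]
    group = h_q // h_kv                   # query heads per KV head
    out = np.zeros_like(q, dtype=np.float32)
    for i in range(b):
        n = int(valid_lens[i])
        if n == 0:
            continue                      # whole sample is padding
        for hq in range(h_q):
            hk = hq // group              # shared KV head for this group
            scores = q[i, hq, :n] @ k[i, hk, :n].T / np.sqrt(d)
            scores = scores - scores.max(axis=-1, keepdims=True)
            p = np.exp(scores)
            p = p / p.sum(axis=-1, keepdims=True)
            out[i, hq, :n] = p @ v[i, hk, :n]
    return out
```

With h_q == h_kv this degenerates to ordinary multi-head attention, and mixed valid_lens entries exercise each sample's mask independently.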

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a new masked sequence prefill kernel, _attention_sequence_prefill_with_mask, designed for batched encoder-style inputs with right-padding. It includes a corresponding test suite covering various edge cases like zero-length sequences and grouped-query attention. Feedback includes removing an unused causal parameter, cleaning up unnecessary pylint suppressions for the lse buffer, and adding explicit bounds checks for kv_len during K and V loading to ensure robustness.

This adds a Relax frontend helper for sequence prefill with per-batch valid
lengths on right-padded encoder inputs.

The existing sequence prefill helper assumes that every position in
[0, seq_len) is valid for every sample in the batch. That assumption does not
hold for padded encoder batches, where each sample may carry a different
valid length. The new helper adds the masked variant needed for this case
while reusing the existing prefill kernel configuration and schedule.

The focused correctness test compares the kernel against a NumPy fp32
reference on valid rows only and covers zero-length, full-length, mixed
valid-length, and grouped-query cases. The test uses TVM GPU target
parametrization for CUDA and Metal.

This change is intentionally correctness-only. It does not introduce new
tuning knobs, target-specific schedule changes, or performance claims.

Tests:
- python -m pytest -v tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
@xthomaswang force-pushed the relax/masked-sequence-prefill branch from 4e09c7d to 459e1ff on April 11, 2026 19:14
@xthomaswang
Author

Addressed the suggestions from Gemini.

1 participant