add optimized prefill gdn kernels for qwen3_5 #2686
Pull request overview
This PR adds a new optimized, forward-only “prefill” path for the Triton gated-delta-rule (GDN) pipeline intended to reduce kernel launches and global memory traffic via kernel fusion and layout changes, while keeping the existing implementation intact.
Changes:
- Introduces new public APIs `chunk_gated_delta_rule_opt` and `chunk_gated_delta_rule_opt_vk` and wires them into `aiter.ops.triton.gated_delta_net`.
- Adds fused prefill kernels (fused cumsum+KKT, fused solve-tril+recompute) and new "opt/opt_vk" variants of the hidden-state and output kernels using head-major intermediates.
- Expands Triton op tests to cover the new optimized variants for both fixed-length and variable-length inputs.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| op_tests/triton_tests/test_gated_delta_rule.py | Adds correctness tests for chunk_gated_delta_rule_opt and chunk_gated_delta_rule_opt_vk, including varlen coverage. |
| aiter/ops/triton/gated_delta_net/gated_delta_rule.py | Adds high-level Python wrappers exposing the optimized forward-only prefill APIs. |
| aiter/ops/triton/gated_delta_net/__init__.py | Exports the new optimized APIs from the public module. |
| aiter/ops/triton/_triton_kernels/gated_delta_rule/prefill/fused_solve_tril_recompute.py | New fused kernel combining triangular solve and recomputation of w/u with register-resident inverse blocks. |
| aiter/ops/triton/_triton_kernels/gated_delta_rule/prefill/fused_cumsum_kkt.py | Adds an autotuned fused kernel for chunk-local gate cumsum + scaled-dot KKT. |
| aiter/ops/triton/_triton_kernels/gated_delta_rule/prefill/chunk_o.py | Adds optimized output kernels (opt / opt_vk) and Python entrypoints producing [B, T, H, V] directly. |
| aiter/ops/triton/_triton_kernels/gated_delta_rule/prefill/chunk_delta_h.py | Adds optimized hidden-state forward kernels (opt / opt_vk) and entrypoints using head-major intermediates. |
| aiter/ops/triton/_triton_kernels/gated_delta_rule/prefill/chunk.py | Adds optimized forward pipelines (chunk_gated_delta_rule_fwd_opt / _vk) composing the fused kernels. |
| aiter/ops/triton/_triton_kernels/gated_delta_rule/prefill/__init__.py | Re-exports the new optimized prefill kernels/helpers. |
| aiter/ops/triton/_triton_kernels/gated_delta_rule/__init__.py | Re-exports the new optimized kernels/helpers at the gated-delta-rule kernel package level. |
```python
    initial_state: torch.Tensor,
    output_final_state: bool,
```

The signature types `initial_state` as a required `torch.Tensor`, but higher-level code forwards `initial_state=None` (and the downstream Triton kernel supports the no-initial-state path). Update the annotation/default to `initial_state: torch.Tensor | None = None` (and consider defaulting `output_final_state: bool = False`) to match actual supported usage.

Suggested change:

```diff
-    initial_state: torch.Tensor,
-    output_final_state: bool,
+    initial_state: torch.Tensor | None = None,
+    output_final_state: bool = False,
```
```python
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
```

Same as `chunk_gated_delta_rule_fwd_opt`: `initial_state` is annotated as a required `torch.Tensor` but the implementation supports `None`. Update the signature to `initial_state: torch.Tensor | None = None` (and consider defaulting `output_final_state`) to reflect runtime behavior and avoid confusing API/type checking.
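The annotation pattern being requested can be sketched with a minimal, self-contained stand-in (illustrative only; `fwd_opt`, its arguments, and the zero-state fallback below are hypothetical, not the aiter signatures):

```python
from typing import Optional

# Hypothetical stand-in for the reviewed wrappers: the optional recurrent
# state defaults to None, and the no-initial-state path zero-initializes,
# so callers may simply omit the argument.
def fwd_opt(
    x: list[float],
    initial_state: Optional[list[float]] = None,
    output_final_state: bool = False,
):
    # No-initial-state path: start from a zero state.
    state = list(initial_state) if initial_state is not None else [0.0] * len(x)
    # Trivial placeholder "recurrence" so the example is runnable:
    # out_t = x_t + state_t.
    out = [xi + si for xi, si in zip(x, state)]
    return (out, state) if output_final_state else (out, None)
```

With the default in place, `fwd_opt(x)` and `fwd_opt(x, initial_state=None)` type-check and behave identically, which is the usage the higher-level code already relies on.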
```python
    """
    B, T, Hg, K, V = *k.shape, v.shape[-1]
    H = v.shape[-2]
    BT = A_raw.shape[-1]
```

`BT` is taken from `A_raw.shape[-1]`, but the underlying kernel is written specifically for 64-token chunks (fixed 0/16/32/48 offsets). Add a runtime validation like `if BT != 64: raise ValueError(...)` to prevent silent wrong results when this helper is used directly.

Suggested change:

```diff
     BT = A_raw.shape[-1]
+    if BT != 64:
+        raise ValueError(
+            f"fused_solve_tril_recompute_w_u only supports BT=64 because the "
+            f"underlying Triton kernel is specialized for 64-token chunks; got BT={BT}."
+        )
```
```python
    h11 = b11.to(tl.bfloat16)
    h22 = b22.to(tl.bfloat16)
    h33 = b33.to(tl.bfloat16)
    h44 = b44.to(tl.bfloat16)
    h21 = b21.to(tl.bfloat16)
    h31 = b31.to(tl.bfloat16)
    h32 = b32.to(tl.bfloat16)
    h41 = b41.to(tl.bfloat16)
    h42 = b42.to(tl.bfloat16)
    h43 = b43.to(tl.bfloat16)
```

The triangular-solve inverse blocks are unconditionally downcast to `tl.bfloat16` (even if inputs are fp16), which changes numerical behavior and can require bf16 hardware support. Consider casting to the input dtype (fp16 vs bf16) or making the downcast conditional via a `tl.constexpr` flag/type check.

Suggested change:

```diff
-    h11 = b11.to(tl.bfloat16)
-    h22 = b22.to(tl.bfloat16)
-    h33 = b33.to(tl.bfloat16)
-    h44 = b44.to(tl.bfloat16)
-    h21 = b21.to(tl.bfloat16)
-    h31 = b31.to(tl.bfloat16)
-    h32 = b32.to(tl.bfloat16)
-    h41 = b41.to(tl.bfloat16)
-    h42 = b42.to(tl.bfloat16)
-    h43 = b43.to(tl.bfloat16)
+    inv_block_dtype = rA43.dtype
+    h11 = b11.to(inv_block_dtype)
+    h22 = b22.to(inv_block_dtype)
+    h33 = b33.to(inv_block_dtype)
+    h44 = b44.to(inv_block_dtype)
+    h21 = b21.to(inv_block_dtype)
+    h31 = b31.to(inv_block_dtype)
+    h32 = b32.to(inv_block_dtype)
+    h41 = b41.to(inv_block_dtype)
+    h42 = b42.to(inv_block_dtype)
+    h43 = b43.to(inv_block_dtype)
```
```python
    """
    B, T, Hg, K = k.shape
    H = beta.shape[-1]
    BT = chunk_size
```

`chunk_size` is treated as a free variable (`BT = chunk_size`), but the docstring says it must be 64 and the optimized pipeline assumes `BT=64`. Add a runtime check enforcing `chunk_size == 64` (or relax the doc and ensure all downstream kernels truly support other sizes).

Suggested change:

```diff
-    BT = chunk_size
+    if chunk_size != 64:
+        raise ValueError(
+            f"Unsupported chunk_size={chunk_size}; "
+            f"fused_chunk_local_cumsum_scaled_dot_kkt_fwd requires chunk_size == 64."
+        )
+    BT = 64
```
Motivation
Add an optimized forward-only prefill GDN pipeline to aiter, reducing kernel launch overhead and global memory traffic through kernel fusion and intermediate tensor layout optimization.
Technical Details
Adds `chunk_gated_delta_rule_opt` alongside the existing implementation (no existing code modified).
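For intuition about what the fused kernels compute, a naive sequential reference of the gated-delta-rule recurrence can be sketched as below. This is a hedged reading of the GDN formulation, not code from this PR: `gated_delta_rule_ref` is a hypothetical helper, and the exact sign/scaling conventions in aiter's kernels may differ; the optimized pipeline evaluates the same kind of recurrence in fused 64-token chunks rather than token by token.

```python
import numpy as np

def gated_delta_rule_ref(q, k, v, g, beta, scale):
    """Naive O(T) sequential sketch. Shapes: q, k [T, K]; v [T, V]; g, beta [T]."""
    T, K = q.shape
    V = v.shape[-1]
    S = np.zeros((K, V))      # recurrent key->value state
    o = np.empty((T, V))
    for t in range(T):
        S = np.exp(g[t]) * S  # gated decay of the whole state
        # delta-rule update: write v_t at key k_t, first removing the
        # state's current prediction k_t @ S, scaled by the gate beta_t
        S = S + beta[t] * np.outer(k[t], v[t] - k[t] @ S)
        o[t] = (scale * q[t]) @ S  # query readout
    return o, S
```

The chunked kernels in this PR avoid this per-token loop by processing 64 tokens at once, which is why the BT=64 assumption shows up throughout the fused helpers.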
Test Plan
Covers equal-length and variable-length sequences, fp16/bf16, multiple B/T/H/D combos, L2-norm, and gate masking.