
add optimized prefill gdn kernels for qwen3_5 #2686

Open
yiijin wants to merge 3 commits into ROCm:main from yiijin:prefill_gdn_opt

Conversation

@yiijin yiijin commented Apr 10, 2026

Motivation

Add an optimized forward-only prefill GDN pipeline to aiter, reducing kernel launch overhead and global memory traffic through kernel fusion and intermediate tensor layout optimization.

Technical Details

Adds chunk_gated_delta_rule_opt alongside the existing implementation (no existing code modified).

  • Fused cumsum+KKT kernel: Merges gate cumsum and scaled-dot KKT into one kernel, eliminating one intermediate tensor round-trip.
  • Fused triangular-solve+recompute kernel (new file): Keeps 16×16 inverse blocks in registers, removing the intermediate Ai tensor from global memory.
  • Optimized h/o kernels: GQA-aware strides, head-major [B,H,T,K/V] intermediate layout, output o written directly as [B,T,H,V] (no post-kernel transpose).
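
As a rough illustration of what the fused cumsum+KKT step computes, here is a dense per-chunk NumPy reference. The exact semantics and names are assumed for illustration, not taken from the aiter kernels:

```python
import numpy as np

def gated_kkt_reference(k, beta, g, scale):
    """Dense per-chunk reference for a fused gate-cumsum + scaled KKT step.

    Assumed semantics (illustrative only, not the aiter kernel):
      gc = chunk-local cumulative sum of the log-gates g
      A[i, j] = scale * beta[i] * <k_i, k_j> * exp(gc[i] - gc[j])  for j < i
    """
    gc = np.cumsum(g)                          # gate cumsum, fused into the same pass
    dots = (k * beta[:, None]) @ k.T * scale   # scaled K @ K^T
    decay = np.exp(gc[:, None] - gc[None, :])  # pairwise gate decay
    return np.tril(dots * decay, -1)           # keep the strictly lower triangle
```

Fusing the two steps means the gate cumsum never round-trips through global memory as a separate intermediate tensor, which is the traffic reduction the first bullet describes.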

Test Plan

Unit tests cover equal-length and variable-length sequences, fp16/bf16, multiple B/T/H/D combos, L2-norm, and gate masking.
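
Variable-length batches are typically packed along the token dimension and described by cumulative offsets (the `cu_seqlens` argument the wrappers accept). A small illustrative helper, not part of the aiter API:

```python
def make_cu_seqlens(lengths):
    """Build cumulative sequence offsets for a packed variable-length batch.

    cu_seqlens[i] is where sequence i starts in the packed token dimension;
    cu_seqlens[-1] is the total token count. Illustrative helper only.
    """
    cu = [0]
    for n in lengths:
        cu.append(cu[-1] + n)
    return cu
```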

Test Result

(test result screenshot omitted)

Submission Checklist

@yiijin yiijin requested a review from a team April 10, 2026 09:45
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label           Tests
ci:triton-355   Run Triton tests on MI355 in addition to MI325
ci:sglang       SGLang integration tests
ci:atom         ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm         vLLM benchmark
ci:all          All of the above

Add labels via the sidebar or gh pr edit 2686 --add-label <label>

@yiijin yiijin requested review from ganyi1996ppo and huizzhan April 10, 2026 09:45
@yiijin yiijin force-pushed the prefill_gdn_opt branch 3 times, most recently from e1ad4e6 to 558ac3c on April 13, 2026 03:00
huizzhan
huizzhan previously approved these changes Apr 13, 2026
valarLip
valarLip previously approved these changes Apr 13, 2026
Copilot AI review requested due to automatic review settings April 14, 2026 15:27
@yiijin yiijin dismissed stale reviews from valarLip and huizzhan via 431ee45 April 14, 2026 15:27
@yiijin yiijin force-pushed the prefill_gdn_opt branch 2 times, most recently from 431ee45 to 80eabcc on April 14, 2026 15:28
Contributor

Copilot AI left a comment

Pull request overview

This PR adds a new optimized, forward-only “prefill” path for the Triton gated-delta-rule (GDN) pipeline intended to reduce kernel launches and global memory traffic via kernel fusion and layout changes, while keeping the existing implementation intact.

Changes:

  • Introduces new public APIs chunk_gated_delta_rule_opt and chunk_gated_delta_rule_opt_vk and wires them into aiter.ops.triton.gated_delta_net.
  • Adds fused prefill kernels (fused cumsum+KKT, fused solve-tril+recompute) and new “opt/opt_vk” variants of the hidden-state and output kernels using head-major intermediates.
  • Expands Triton op tests to cover the new optimized variants for both fixed-length and variable-length inputs.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.

Summary per file:

  • op_tests/triton_tests/test_gated_delta_rule.py: Adds correctness tests for chunk_gated_delta_rule_opt and chunk_gated_delta_rule_opt_vk, including varlen coverage.
  • aiter/ops/triton/gated_delta_net/gated_delta_rule.py: Adds high-level Python wrappers exposing the optimized forward-only prefill APIs.
  • aiter/ops/triton/gated_delta_net/__init__.py: Exports the new optimized APIs from the public module.
  • aiter/ops/triton/_triton_kernels/gated_delta_rule/prefill/fused_solve_tril_recompute.py: New fused kernel combining triangular solve and recomputation of w/u with register-resident inverse blocks.
  • aiter/ops/triton/_triton_kernels/gated_delta_rule/prefill/fused_cumsum_kkt.py: Adds an autotuned fused kernel for chunk-local gate cumsum + scaled-dot KKT.
  • aiter/ops/triton/_triton_kernels/gated_delta_rule/prefill/chunk_o.py: Adds optimized output kernels (opt / opt_vk) and Python entrypoints producing [B, T, H, V] directly.
  • aiter/ops/triton/_triton_kernels/gated_delta_rule/prefill/chunk_delta_h.py: Adds optimized hidden-state forward kernels (opt / opt_vk) and entrypoints using head-major intermediates.
  • aiter/ops/triton/_triton_kernels/gated_delta_rule/prefill/chunk.py: Adds optimized forward pipelines (chunk_gated_delta_rule_fwd_opt / _vk) composing the fused kernels.
  • aiter/ops/triton/_triton_kernels/gated_delta_rule/prefill/__init__.py: Re-exports the new optimized prefill kernels/helpers.
  • aiter/ops/triton/_triton_kernels/gated_delta_rule/__init__.py: Re-exports the new optimized kernels/helpers at the gated-delta-rule kernel package level.


Comment on lines +124 to +125
initial_state: torch.Tensor,
output_final_state: bool,

Copilot AI Apr 14, 2026


The signature types initial_state as a required torch.Tensor, but higher-level code forwards initial_state=None (and the downstream Triton kernel supports the no-initial-state path). Update the annotation/default to initial_state: torch.Tensor | None = None (and consider defaulting output_final_state: bool = False) to match actual supported usage.

Suggested change:
- initial_state: torch.Tensor,
- output_final_state: bool,
+ initial_state: torch.Tensor | None = None,
+ output_final_state: bool = False,

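A minimal sketch of what the suggested signature fix buys callers. The function name and parameters here are hypothetical stand-ins, not the actual aiter wrapper:

```python
from typing import Any, Optional

def chunk_fwd_sketch(
    q: Any,
    initial_state: Optional[Any] = None,
    output_final_state: bool = False,
):
    # Hypothetical wrapper: callers may omit initial_state entirely, and the
    # no-initial-state kernel path is selected instead of requiring a tensor.
    has_state = initial_state is not None
    return has_state, output_final_state
```

With the defaults in place, `chunk_fwd_sketch(q)` is valid and type-checks, matching the runtime behavior the review describes.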
Comment on lines +203 to +207
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,

Copilot AI Apr 14, 2026


Same as chunk_gated_delta_rule_fwd_opt: initial_state is annotated as required torch.Tensor but the implementation supports None. Update the signature to initial_state: torch.Tensor | None = None (and consider defaulting output_final_state) to reflect runtime behavior and avoid confusing API/type checking.

"""
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
BT = A_raw.shape[-1]

Copilot AI Apr 14, 2026


BT is taken from A_raw.shape[-1], but the underlying kernel is written specifically for 64-token chunks (fixed 0/16/32/48 offsets). Add a runtime validation like if BT != 64: raise ValueError(...) to prevent silent wrong results when this helper is used directly.

Suggested change:
- BT = A_raw.shape[-1]
+ BT = A_raw.shape[-1]
+ if BT != 64:
+     raise ValueError(
+         f"fused_solve_tril_recompute_w_u only supports BT=64 because the "
+         f"underlying Triton kernel is specialized for 64-token chunks; got BT={BT}."
+     )

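The fixed 0/16/32/48 offsets come from tiling a 64×64 lower-triangular system into 16×16 blocks. A NumPy sketch of that blockwise inversion structure, for illustration only (the real kernel keeps these blocks in registers and works on Triton tiles, not NumPy arrays):

```python
import numpy as np

def blockwise_tril_inverse(A, bs=16):
    """Invert a lower-triangular matrix tile by tile (bs x bs blocks).

    Mirrors the tiling a 64-token kernel would use (blocks at row/column
    offsets 0/16/32/48). Uses the identity B_ii = inv(A_ii) and, for i > j,
    B_ij = -B_ii @ sum_{m=j}^{i-1} A_im @ B_mj.
    """
    n = A.shape[0] // bs
    B = np.zeros_like(A)

    def blk(M, i, j):
        return M[i * bs:(i + 1) * bs, j * bs:(j + 1) * bs]

    for i in range(n):
        Bii = np.linalg.inv(blk(A, i, i))      # diagonal block inverse
        B[i * bs:(i + 1) * bs, i * bs:(i + 1) * bs] = Bii
        for j in range(i):                     # off-diagonal blocks below the diagonal
            acc = sum(blk(A, i, m) @ blk(B, m, j) for m in range(j, i))
            B[i * bs:(i + 1) * bs, j * bs:(j + 1) * bs] = -Bii @ acc
    return B
```

Because every block of the inverse depends only on already-computed blocks, a kernel can hold the 16×16 tiles in registers for the whole solve, which is why the intermediate Ai tensor can be dropped from global memory.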
Comment on lines +189 to +198
h11 = b11.to(tl.bfloat16)
h22 = b22.to(tl.bfloat16)
h33 = b33.to(tl.bfloat16)
h44 = b44.to(tl.bfloat16)
h21 = b21.to(tl.bfloat16)
h31 = b31.to(tl.bfloat16)
h32 = b32.to(tl.bfloat16)
h41 = b41.to(tl.bfloat16)
h42 = b42.to(tl.bfloat16)
h43 = b43.to(tl.bfloat16)

Copilot AI Apr 14, 2026


The triangular-solve inverse blocks are unconditionally downcast to tl.bfloat16 (even if inputs are fp16), which changes numerical behavior and can require bf16 hardware support. Consider casting to the input dtype (fp16 vs bf16) or making the downcast conditional via a tl.constexpr flag/type check.

Suggested change:
- h11 = b11.to(tl.bfloat16)
- h22 = b22.to(tl.bfloat16)
- h33 = b33.to(tl.bfloat16)
- h44 = b44.to(tl.bfloat16)
- h21 = b21.to(tl.bfloat16)
- h31 = b31.to(tl.bfloat16)
- h32 = b32.to(tl.bfloat16)
- h41 = b41.to(tl.bfloat16)
- h42 = b42.to(tl.bfloat16)
- h43 = b43.to(tl.bfloat16)
+ inv_block_dtype = rA43.dtype
+ h11 = b11.to(inv_block_dtype)
+ h22 = b22.to(inv_block_dtype)
+ h33 = b33.to(inv_block_dtype)
+ h44 = b44.to(inv_block_dtype)
+ h21 = b21.to(inv_block_dtype)
+ h31 = b31.to(inv_block_dtype)
+ h32 = b32.to(inv_block_dtype)
+ h41 = b41.to(inv_block_dtype)
+ h42 = b42.to(inv_block_dtype)
+ h43 = b43.to(inv_block_dtype)
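To see why an unconditional bf16 downcast matters for fp16 inputs: bf16 keeps only 8 explicit mantissa bits versus fp16's 10, so values fp16 represents exactly can be lost. A small stdlib sketch that emulates bf16 by truncating a float32 encoding (truncation rather than round-to-nearest, which is close enough to show the precision loss):

```python
import struct

def truncate_to_bf16(x: float) -> float:
    """Emulate bf16 by zeroing the low 16 bits of the float32 bit pattern.

    Real hardware conversion rounds to nearest; truncation is used here
    purely to illustrate the mantissa-width difference vs fp16.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

# 1 + 2**-10 is exactly representable in fp16 but collapses to 1.0 in bf16,
# so an fp16 pipeline routed through bf16 intermediates changes numerics.
```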
"""
B, T, Hg, K = k.shape
H = beta.shape[-1]
BT = chunk_size

Copilot AI Apr 14, 2026


chunk_size is treated as a free variable (BT = chunk_size), but the docstring says it must be 64 and the optimized pipeline assumes BT=64. Add a runtime check enforcing chunk_size == 64 (or relax the doc and ensure all downstream kernels truly support other sizes).

Suggested change:
- BT = chunk_size
+ if chunk_size != 64:
+     raise ValueError(f"Unsupported chunk_size={chunk_size}; fused_chunk_local_cumsum_scaled_dot_kkt_fwd requires chunk_size == 64.")
+ BT = 64

@yiijin yiijin force-pushed the prefill_gdn_opt branch 4 times, most recently from df0dedd to bd69224 on April 17, 2026 06:10
@yiijin yiijin requested a review from huizzhan April 17, 2026 07:12

Labels: none yet
Projects: none yet
4 participants