MegaKernels in Iris 🔥#541
Draft
neoblizz wants to merge 42 commits into
Draft
Conversation
Single persistent Triton kernel running attention and MoE for all 36 layers of GPT-OSS-120B on one GPU (batch-1 decode), collapsing cosmic's multi-GPU design. FP4 expert weights dequantized in-kernel, BF16 compute. Validated against a PyTorch reference: greedy decode produces coherent output on real weights. Includes HF->.iris converter and kernel tests.
Expert GEMVs can now run native FP4xFP8 scaled matrix multiply via tl.dot_scaled (lowers to v_mfma_scale_f32_16x16x128_f8f6f4 on gfx950) with dynamic FP8-E4M3 activation quant, selected by --quant. Keeps the BF16 dequant path as default. Quantized path is ~2.8x faster (29ms vs 84ms TPOT on MI355X) and still decodes coherently. Adds dot_scaled and quantized-expert correctness tests plus a TPOT benchmark.
Block-of-rows GEMV tiling with max_contiguous/multiple_of hints so the compiler emits dwordx4 loads instead of per-element ushort. Fuse the attention RMSNorm into QKV and the MoE RMSNorm into the router GEMV, fold RoPE into attention, and stripe the residual/zeroing across all programs. This removes the serial single-program phases and their grid barriers. TPOT drops from 29.4 to 25.3 ms (quant) on MI355X; output unchanged.
In the quantized path, each program now owns whole 32-element SwiGLU output blocks and quantizes the block it just produced, so the producer and consumer are the same program and the separate act-quant barrier is gone (4 fewer grid barriers per layer). TPOT drops from 25.3 to 22.9 ms (quant) on MI355X; output unchanged.
Document the measured findings (zero spills, per-element loads, barrier cost) and the tiling + barrier-reduction passes. Refresh the TPOT table: quantized 29.4 -> 22.9 ms on MI355X.
Simplify the README to match the other examples, rewrite the module and test docstrings for clarity, and remove implementation-history detail. No functional changes; both decode paths still produce the same output.
The top-k experts are independent until the final accumulation, so run each expert phase (gate-up, SwiGLU, down) across all experts before the next barrier instead of one expert at a time. This drops the expert barriers from three per expert to three per layer. TPOT 22.7 -> 19.3 ms.
The scaled FP4 x FP8 expert GEMV used the weight as the dot operand whose contiguous dimension was not the inner one, and small block sizes left most programs idle. Make the weight the lhs so its contiguous K bytes coalesce, and use BLOCK_N=32, BLOCK_K=512, which raises the expert GEMV from ~0.6 to ~1.3 TB/s. TPOT 19.1 -> 11.7 ms.
BLOCK_M=16 lifts the BF16 tiled GEMV toward 5 TB/s on the LM-head shape.
Keep the attention output in its own buffer and defer the residual add to the layer's final accumulation, so the router and expert input read rmsnorm(x + o) directly. Removes the separate residual phase and its grid barrier, and the final residual is striped across all programs. TPOT 11.8 -> 11.5 ms.
Track a running (max, index) while computing the LM-head logits instead of writing all vocabulary logits to HBM and reading them back for the argmax. Removes a full logits round-trip and one barrier. TPOT 11.5 -> 11.2 ms.
A 1024-wide K block on the scaled FP4 x FP8 expert GEMV raises its bandwidth from ~1.3 to ~1.8 TB/s. TPOT 11.2 -> 10.6 ms.
The down projection (N=H) reaches best bandwidth at BLOCK_N=16 while the gate-up (N=2*I) prefers 32, so give them independent tile sizes.
Replace the per-position scalar KV scan with a blocked online-softmax over BLOCK_T positions at a time. This removes the linear per-token slowdown as the context grows, so decode TPOT stays flat. At a 100-token context TPOT drops from ~14 to ~9.5 ms; full-rollout TPOT is ~9.4 ms.
The kernel is bound by the grid-wide barriers between phases, not by GEMV occupancy. Dropping from 256 to 192 programs makes each barrier cheaper while the per-phase GEMVs stay well filled. TPOT 9.4 -> 7.4 ms.
Compute the current position's attention term in-register from the freshly projected k/v and read only the history from the cache, so the cache append and the attention can share a phase. Removes one grid barrier per layer. TPOT 7.4 -> 7.25 ms.
Run the router GEMV, the FP8 expert-input quantization and the accumulator reset in one phase (all depend only on x + o), then compute the top-k redundantly per program so the experts need no separate top-k barrier. TPOT 7.25 -> 7.15 ms.
The last expert's down projection owns the same MoE-output rows it just wrote, so it finalizes x += o + moe in its epilogue instead of a separate striped residual phase. Removes one more grid barrier per layer (quant path). TPOT 7.15 -> 6.97 ms.
Smaller output-row tiles fill the grid better for the attention-path and LM-head GEMVs. TPOT 6.97 -> 6.78 ms.
With the reduced barrier count and BLOCK_M=8, 184 programs measures fastest. TPOT 6.78 -> 6.56 ms.
With BLOCK_M=8, a 1024-wide K block markedly speeds the attention-path and LM-head BF16 GEMVs. TPOT 6.56 -> 5.87 ms.
The LM head spans the full vocabulary, so a separate BLOCK_M_LM lets it use a larger row tile than the small attention GEMVs.
bench_islosl.py sweeps ISL:OSL pairs and reports TTFT, TPOT and end-to-end latency. README now carries a measured table (100:100, 1024:100, 1024:1024, 2048:2048) on MI355X: TPOT is flat at ~5.9 ms/token; TTFT scales linearly with input length since prefill reuses the single-token decode kernel.
Convert the inner K-reduction of every GEMV (expert FP4 dot_scaled, the BF16 attention/router GEMVs and the LM head) from a plain while-loop to a tl.range pipelined loop so the next block's loads overlap the current dot. This lifts the FP4 expert GEMV from ~1.8 to ~2.7 TB/s. TPOT 5.78 -> 5.20 ms.
Contributor
There was a problem hiding this comment.
Pull request overview
Note
Copilot couldn't run its full agentic review because no GitHub Actions runner was available. Make sure your repository has a runner available to run Copilot's review, or add a copilot-setup-steps.yml file specifying one with the runs-on attribute. See the docs for more details.
Adds a GPT-OSS-120B “megakernel” example for Iris, including a single-GPU persistent Triton kernel implementation, reference math, multi-GPU prototype, and tooling for validation/benchmarking/accuracy evaluation.
Changes:
- Introduces the single-GPU persistent Triton megakernel + reusable
common/device-op library. - Adds reference implementation, phased Triton runner, and multiple validation/benchmark scripts.
- Adds a multi-GPU (1 attention/tail + 4 MoE ranks) prototype using Iris symmetric heap.
Reviewed changes
Copilot reviewed 32 out of 34 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| examples/33_gpt_oss_megakernel/tokenizer_util.py | Minimal tokenizer wrapper for HF snapshot tokenizer.json. |
| examples/33_gpt_oss_megakernel/test_quant_expert.py | Validates quantized expert GEMV path vs PyTorch reference. |
| examples/33_gpt_oss_megakernel/test_kernels.py | Validates phased Triton kernels vs reference math. |
| examples/33_gpt_oss_megakernel/test_dot_scaled.py | Validates tl.dot_scaled operand layout for FP4xFP8 GEMV. |
| examples/33_gpt_oss_megakernel/test_barrier.py | Tests grid-wide barrier semantics needed for persistent kernels. |
| examples/33_gpt_oss_megakernel/run_triton_phased.py | End-to-end decode using per-phase Triton kernels orchestrated by host. |
| examples/33_gpt_oss_megakernel/run_reference.py | End-to-end decode using PyTorch reference implementation. |
| examples/33_gpt_oss_megakernel/reference.py | Defines numerical ground-truth for decode forward pass. |
| examples/33_gpt_oss_megakernel/multi_gpu/run_multi_gpu.py | Multi-process multi-GPU driver using Iris symmetric heap exchanges. |
| examples/33_gpt_oss_megakernel/multi_gpu/protocol.py | Defines rank roles and exchange protocol for multi-GPU decode. |
| examples/33_gpt_oss_megakernel/multi_gpu/persistent_kernels.py | Persistent per-rank Triton kernels using device-side flag rendezvous. |
| examples/33_gpt_oss_megakernel/multi_gpu/moe_kernels.py | MoE-rank per-layer expert kernels + scatter-back helper. |
| examples/33_gpt_oss_megakernel/multi_gpu/attn_kernels.py | Attention-rank prologue/scatter/accumulate/lm_head kernels. |
| examples/33_gpt_oss_megakernel/load_hf.py | Loads HF checkpoint into the example’s weight layout + FP4 dequant helpers. |
| examples/33_gpt_oss_megakernel/kernels.py | Standalone Triton building blocks for phased execution. |
| examples/33_gpt_oss_megakernel/gpt_oss_120b_quantized_megakernel.py | Main single-GPU persistent megakernel + host driver. |
| examples/33_gpt_oss_megakernel/convert_to_iris.py | Converts HF checkpoint to .iris tensor archive for mmap/device loading. |
| examples/33_gpt_oss_megakernel/common/swiglu.py | Device ops for SwiGLU (BF16 + FP8-quant variants). |
| examples/33_gpt_oss_megakernel/common/router.py | Device op for top-k + softmax routing. |
| examples/33_gpt_oss_megakernel/common/rmsnorm.py | Device op to materialize residual+RMSNorm output for non-quant path. |
| examples/33_gpt_oss_megakernel/common/quant.py | Device op for fused residual+RMSNorm+FP8 activation quantization. |
| examples/33_gpt_oss_megakernel/common/gemv_fp8.py | FP8 weight-only GEMV device helpers (incl. RMSNorm/resid fused). |
| examples/33_gpt_oss_megakernel/common/gemv_fp4.py | MXFP4 expert GEMV device helpers (dequant path + dot_scaled path). |
| examples/33_gpt_oss_megakernel/common/gemv_bf16.py | BF16 GEMV device helpers (tiled + fused RMSNorm/resid RMSNorm). |
| examples/33_gpt_oss_megakernel/common/fp4.py | FP4 magnitude LUT helper for dequant path. |
| examples/33_gpt_oss_megakernel/common/barrier.py | Grid-wide barrier helper for persistent kernels. |
| examples/33_gpt_oss_megakernel/common/attention.py | Device ops for RoPE+KV append and per-head flash decode with sinks. |
| examples/33_gpt_oss_megakernel/common/init.py | Re-exports reusable device ops for megakernel and multi-GPU kernels. |
| examples/33_gpt_oss_megakernel/bench_tpot.py | TPOT benchmarking script for steady-state decode. |
| examples/33_gpt_oss_megakernel/bench_islosl.py | ISL/OSL sweep benchmark for TTFT/TPOT/E2E throughput. |
| examples/33_gpt_oss_megakernel/acc_eval.py | Accuracy comparison between BF16-attn and FP8-attn variants. |
| examples/33_gpt_oss_megakernel/README.md | Documentation: architecture, usage, benchmarking, and accuracy notes. |
| .gitignore | Adds tmp/ ignore entry. |
Comments suppressed due to low confidence (1)
examples/33_gpt_oss_megakernel/tokenizer_util.py:1
- Snapshot selection always returns the lexicographically-last snapshot directory, without checking that it actually contains
tokenizer.json. If the newest snapshot is partial/corrupt, tokenizer loading will fail later with a less actionable error. Consider filtering candidates to those containingtokenizer.json, and/or raising an error that includes attempted paths when none match.
Comment on lines
+56
to
+57
| NUM_WG = 180 | ||
| _NWG = tl.constexpr(NUM_WG) |
Comment on lines
+148
to
+157
| for t in range(lo, pos + 1): | ||
| kt = tl.load(kcache_ptr + t * kv_dim + kvh * DH + d).to(tl.float32) | ||
| score = tl.sum(q * kt, axis=0) * scale | ||
| m_new = tl.maximum(m, score) | ||
| alpha = tl.exp(m - m_new) | ||
| p = tl.exp(score - m_new) | ||
| l = l * alpha + p | ||
| vt = tl.load(vcache_ptr + t * kv_dim + kvh * DH + d).to(tl.float32) | ||
| acc = acc * alpha + p * vt | ||
| m = m_new |
Comment on lines
+15
to
+22
| tl.inline_asm_elementwise( | ||
| "buffer_inv sc0\n\ts_waitcnt vmcnt(0)", | ||
| "=r", | ||
| [], | ||
| dtype=tl.int32, | ||
| is_pure=False, | ||
| pack=1, | ||
| ) |
Comment on lines
+121
to
+122
| idx_path = os.path.join(snap, "model.safetensors.index.json") | ||
| weight_map = json.load(open(idx_path))["weight_map"] |
Comment on lines
+1
to
+3
| """Accuracy comparison: FP8-attention vs BF16-attention megakernel. | ||
|
|
||
| Both share the FP4 experts; this isolates the error introduced by quantizing the |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This pull request introduces the GPT-OSS-120B Megakernel example, providing a comprehensive, high-performance single-GPU implementation of the GPT-OSS-120B model using a persistent Triton kernel. The changes include detailed documentation, benchmarking and accuracy scripts, and a reusable set of device-level Triton operations for attention, quantization, and expert routing. Together, these updates enable end-to-end quantized inference, benchmarking, and accuracy evaluation for this large mixture-of-experts model.
Major additions and improvements:
1. Documentation and Usage:
README.mddescribing the GPT-OSS-120B Megakernel architecture, model details, quantization options, benchmarking methodology, accuracy tradeoffs, and file purposes.2. Benchmarking and Evaluation Scripts:
bench_tpot.pyfor measuring steady-state decode latency (TPOT) and comparing quantized vs. BF16 inference paths.bench_islosl.pyto benchmark prefill and decode latency across various input/output sequence lengths, reporting throughput and latency metrics.acc_eval.pyto compare accuracy between FP8 and BF16 attention (with shared FP4 experts), providing detailed metrics such as top-1 agreement, top-k overlap, KL divergence, and logit cosine similarity.3. Reusable Device Operations:
common/__init__.pyto expose a suite of reusable Triton device ops for attention, quantization, GEMV, routing, and SwiGLU, supporting both single-GPU and future multi-GPU kernels.common/attention.pycontaining Triton JIT helpers for RoPE+KV cache appending and per-head flash decode with attention sinks, facilitating both current and future kernel designs.