MegaKernels in Iris 🔥 #541

Draft
neoblizz wants to merge 42 commits into main from neoblizz/megakernel-perf

@neoblizz

Member

This pull request introduces the GPT-OSS-120B Megakernel example, providing a comprehensive, high-performance single-GPU implementation of the GPT-OSS-120B model using a persistent Triton kernel. The changes include detailed documentation, benchmarking and accuracy scripts, and a reusable set of device-level Triton operations for attention, quantization, and expert routing. Together, these updates enable end-to-end quantized inference, benchmarking, and accuracy evaluation for this large mixture-of-experts model.

Major additions and improvements:

1. Documentation and Usage:

  • Added a detailed README.md describing the GPT-OSS-120B Megakernel architecture, model details, quantization options, benchmarking methodology, accuracy tradeoffs, and file purposes.

2. Benchmarking and Evaluation Scripts:

  • Introduced bench_tpot.py for measuring steady-state decode latency (TPOT) and comparing quantized vs. BF16 inference paths.
  • Added bench_islosl.py to benchmark prefill and decode latency across various input/output sequence lengths, reporting throughput and latency metrics.
  • Implemented acc_eval.py to compare accuracy between FP8 and BF16 attention (with shared FP4 experts), providing detailed metrics such as top-1 agreement, top-k overlap, KL divergence, and logit cosine similarity.

3. Reusable Device Operations:

  • Created common/__init__.py to expose a suite of reusable Triton device ops for attention, quantization, GEMV, routing, and SwiGLU, supporting both single-GPU and future multi-GPU kernels.
  • Added common/attention.py containing Triton JIT helpers for RoPE+KV cache appending and per-head flash decode with attention sinks, facilitating both current and future kernel designs.
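The four accuracy metrics named for acc_eval.py can be sketched in plain Python. This is a minimal illustration rather than the script's code: every function name and the toy logits are hypothetical, and the KL direction (BF16 reference first) is an assumption.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_ids(logits, k):
    """Indices of the k largest logits (ties broken by lower index)."""
    return sorted(range(len(logits)), key=lambda i: (-logits[i], i))[:k]

def top1_agreement(ref, test):
    """True when both variants would pick the same greedy token."""
    return top_k_ids(ref, 1) == top_k_ids(test, 1)

def top_k_overlap(ref, test, k):
    """Fraction of the reference top-k set that the test top-k set shares."""
    return len(set(top_k_ids(ref, k)) & set(top_k_ids(test, k))) / k

def kl_divergence(ref, test):
    """KL(ref || test) between the two softmax distributions."""
    p, q = softmax(ref), softmax(test)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

def cosine_similarity(ref, test):
    """Cosine of the angle between the two raw logit vectors."""
    dot = sum(a * b for a, b in zip(ref, test))
    norm = math.sqrt(sum(a * a for a in ref)) * math.sqrt(sum(b * b for b in test))
    return dot / norm

# Toy logits standing in for one decode step of each variant (made up).
bf16_logits = [2.0, 1.0, 0.5, -1.0]
fp8_logits = [1.9, 1.1, 0.4, -0.8]
print(top1_agreement(bf16_logits, fp8_logits), top_k_overlap(bf16_logits, fp8_logits, 2))
```

In a real evaluation these would presumably be averaged over many decode positions; the sketch only shows the per-position definitions.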

neoblizz added 30 commits June 22, 2026 13:40
Single persistent Triton kernel running attention and MoE for all 36
layers of GPT-OSS-120B on one GPU (batch-1 decode), collapsing cosmic's
multi-GPU design. FP4 expert weights dequantized in-kernel, BF16 compute.
Validated against a PyTorch reference: greedy decode produces coherent
output on real weights. Includes HF->.iris converter and kernel tests.

Expert GEMVs can now run native FP4xFP8 scaled matrix multiply via
tl.dot_scaled (lowers to v_mfma_scale_f32_16x16x128_f8f6f4 on gfx950)
with dynamic FP8-E4M3 activation quant, selected by --quant. Keeps the
BF16 dequant path as default. Quantized path is ~2.8x faster (29 ms vs
84 ms TPOT on MI355X) and still decodes coherently. Adds dot_scaled and
quantized-expert correctness tests plus a TPOT benchmark.

Block-of-rows GEMV tiling with max_contiguous/multiple_of hints so the
compiler emits dwordx4 loads instead of per-element ushort. Fuse the
attention RMSNorm into QKV and the MoE RMSNorm into the router GEMV, fold
RoPE into attention, and stripe the residual/zeroing across all programs.
This removes the serial single-program phases and their grid barriers.
TPOT drops from 29.4 to 25.3 ms (quant) on MI355X; output unchanged.
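To illustrate what fusing an RMSNorm into its consumer GEMV means, here is a pure-Python sketch, not the Triton code from this PR. The function names, the `eps` default and the toy shapes are assumptions. The point is that the normalized vector is never materialized: the scale is computed from `x` and applied while the dot products are formed, so no separate normalization phase (and no barrier after it) is needed.

```python
import math

def rmsnorm(x, gamma, eps=1e-5):
    """Unfused reference: materialize the normalized vector."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * g for v, g in zip(x, gamma)]

def gemv(weight, vec):
    """Plain matrix-vector product, one output per weight row."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]

def fused_rmsnorm_gemv(weight, x, gamma, eps=1e-5):
    """Fused variant: normalize on the fly inside the reduction."""
    # In a kernel each program would recompute this scalar for itself.
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    out = []
    for row in weight:
        acc = 0.0
        for w, v, g in zip(row, x, gamma):
            acc += w * (v * inv_rms * g)
        out.append(acc)
    return out

x = [1.0, -2.0, 3.0, 0.5]
gamma = [1.0, 0.5, 2.0, 1.5]
weight = [[0.1, 0.2, 0.3, 0.4], [-1.0, 0.0, 1.0, 2.0]]
unfused = gemv(weight, rmsnorm(x, gamma))
fused = fused_rmsnorm_gemv(weight, x, gamma)
```

The trade-off is redundant work: every program repeats the sum of squares over `x`, which is cheap next to a grid-wide barrier.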
In the quantized path, each program now owns whole 32-element SwiGLU
output blocks and quantizes the block it just produced, so the producer
and consumer are the same program and the separate act-quant barrier is
gone (4 fewer grid barriers per layer). TPOT drops from 25.3 to 22.9 ms
(quant) on MI355X; output unchanged.
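A rough sketch of per-block dynamic activation quantization over 32-element blocks follows. It is an illustration under stated assumptions, not the kernel's code: it assumes a power-of-two scale per block (MX-style), assumes 448 as the largest finite E4M3 magnitude, and omits the rounding onto the FP8 grid. All names are hypothetical.

```python
import math

FP8_E4M3_MAX = 448.0  # assumed largest finite E4M3 magnitude
BLOCK = 32            # block width taken from the commit message

def block_scale(block):
    """Smallest power-of-two scale such that max|v| / scale <= FP8_E4M3_MAX."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0
    return 2.0 ** math.ceil(math.log2(amax / FP8_E4M3_MAX))

def quantize_blocks(values):
    """Return (per-block scales, scaled values clamped to the FP8 range).

    Rounding to representable FP8 values is intentionally left out.
    """
    scales, scaled = [], []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        s = block_scale(block)
        scales.append(s)
        scaled.extend(max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / s)) for v in block)
    return scales, scaled

# Two blocks with very different magnitudes get independent scales.
activations = [0.5] * BLOCK + [1000.0] * BLOCK
scales, scaled = quantize_blocks(activations)
```

Because each block's scale depends only on that block, the program that produced a block can quantize it immediately, which is what lets the producer and consumer be the same program.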
Document the measured findings (zero spills, per-element loads, barrier
cost) and the tiling + barrier-reduction passes. Refresh the TPOT table:
quantized 29.4 -> 22.9 ms on MI355X.

Simplify the README to match the other examples, rewrite the module and
test docstrings for clarity, and remove implementation-history detail.
No functional changes; both decode paths still produce the same output.

The top-k experts are independent until the final accumulation, so run
each expert phase (gate-up, SwiGLU, down) across all experts before the
next barrier instead of one expert at a time. This drops the expert
barriers from three per expert to three per layer. TPOT 22.7 -> 19.3 ms.

The scaled FP4 x FP8 expert GEMV used the weight as the dot operand whose
contiguous dimension was not the inner one, and small block sizes left most
programs idle. Make the weight the lhs so its contiguous K bytes coalesce,
and use BLOCK_N=32, BLOCK_K=512, which raises the expert GEMV from ~0.6 to
~1.3 TB/s. TPOT 19.1 -> 11.7 ms.

BLOCK_M=16 lifts the BF16 tiled GEMV toward 5 TB/s on the LM-head shape.

Keep the attention output in its own buffer and defer the residual add to
the layer's final accumulation, so the router and expert input read
rmsnorm(x + o) directly. Removes the separate residual phase and its grid
barrier, and the final residual is striped across all programs. TPOT
11.8 -> 11.5 ms.

Track a running (max, index) while computing the LM-head logits instead of
writing all vocabulary logits to HBM and reading them back for the argmax.
Removes a full logits round-trip and one barrier. TPOT 11.5 -> 11.2 ms.
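The running (max, index) idea can be shown without any GPU code. In this hedged sketch each "program" scans its own tiles of logits and keeps only a pair, and the pairs are merged at the end. The function names and the tie-break rule (lowest index wins) are assumptions, not taken from the PR.

```python
def scan_tiles(tiles, first_index):
    """Running argmax over one program's tiles; never stores the logits."""
    best_val, best_idx = float("-inf"), -1
    index = first_index
    for tile in tiles:
        for value in tile:
            if value > best_val:  # strict '>' keeps the earliest maximum
                best_val, best_idx = value, index
            index += 1
    return best_val, best_idx

def merge(a, b):
    """Combine two (max, index) pairs; ties go to the lower index."""
    if b[0] > a[0] or (b[0] == a[0] and b[1] < a[1]):
        return b
    return a

# Two hypothetical programs, each owning two tiles of two logits.
program0 = scan_tiles([[0.1, 2.0], [3.5, -1.0]], first_index=0)
program1 = scan_tiles([[3.5, 0.0], [1.2, 3.4]], first_index=4)
best = merge(program0, program1)
```

Only one pair per program has to be exchanged, instead of one value per vocabulary entry.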
A 1024-wide K block on the scaled FP4 x FP8 expert GEMV raises its
bandwidth from ~1.3 to ~1.8 TB/s. TPOT 11.2 -> 10.6 ms.

The down projection (N=H) reaches best bandwidth at BLOCK_N=16 while the
gate-up (N=2*I) prefers 32, so give them independent tile sizes.

Replace the per-position scalar KV scan with a blocked online-softmax over
BLOCK_T positions at a time. This removes the linear per-token slowdown as
the context grows, so decode TPOT stays flat. At a 100-token context TPOT
drops from ~14 to ~9.5 ms; full-rollout TPOT is ~9.4 ms.
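For readers unfamiliar with the technique, a blocked online softmax processes several cached positions per step while carrying a running maximum `m`, a running denominator `l` and an unnormalized accumulator. The pure-Python sketch below is a minimal illustration, not the kernel: the names, the list-based layout and `block_t` are assumptions, and attention sinks and sliding windows are left out.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attention_reference(q, keys, values, scale):
    """Two-pass softmax attention for a single query."""
    scores = [scale * dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

def attention_blocked(q, keys, values, scale, block_t):
    """One pass over the cache, block_t positions at a time."""
    dim = len(values[0])
    m, l, acc = float("-inf"), 0.0, [0.0] * dim
    for start in range(0, len(keys), block_t):
        k_blk = keys[start:start + block_t]
        v_blk = values[start:start + block_t]
        scores = [scale * dot(q, k) for k in k_blk]
        m_new = max(m, max(scores))
        alpha = math.exp(m - m_new)  # rescales earlier blocks; 0.0 on the first block
        p = [math.exp(s - m_new) for s in scores]
        l = l * alpha + sum(p)
        acc = [a * alpha + sum(pj * v[d] for pj, v in zip(p, v_blk))
               for d, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

q = [0.3, -0.7]
keys = [[0.1 * t, 1.0 - 0.2 * t] for t in range(7)]
values = [[float(t), float(t % 3)] for t in range(7)]
expected = attention_reference(q, keys, values, scale=0.5)
```

Any `block_t` gives the same result up to rounding; the block size only changes how much work each loop iteration does.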
The kernel is bound by the grid-wide barriers between phases, not by GEMV
occupancy. Dropping from 256 to 192 programs makes each barrier cheaper
while the per-phase GEMVs stay well filled. TPOT 9.4 -> 7.4 ms.

Compute the current position's attention term in-register from the freshly
projected k/v and read only the history from the cache, so the cache append
and the attention can share a phase. Removes one grid barrier per layer.
TPOT 7.4 -> 7.25 ms.
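One way to picture this change: the online-softmax state is seeded from the current token's freshly projected k/v, and the loop then visits only the cached history. The sketch below is an assumption-laden illustration in plain Python (scalar loop, hypothetical names), not the PR's kernel. It checks the result against a softmax over history plus current token.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def decode_attention(q, k_cur, v_cur, k_hist, v_hist, scale):
    """Seed the running state with the current token, then scan history only."""
    m = scale * dot(q, k_cur)  # current score, no cache read needed
    l = 1.0                    # exp(score - m) with m equal to the score
    acc = list(v_cur)
    for k, v in zip(k_hist, v_hist):
        s = scale * dot(q, k)
        m_new = max(m, s)
        alpha = math.exp(m - m_new)
        p = math.exp(s - m_new)
        l = l * alpha + p
        acc = [a * alpha + p * x for a, x in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]

def full_softmax_attention(q, keys, values, scale):
    """Reference over all positions, as if the append had already happened."""
    scores = [scale * dot(q, k) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / total
            for d in range(len(values[0]))]

q = [0.5, -0.25]
k_hist = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
v_hist = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
k_cur, v_cur = [2.0, -1.0], [7.0, 8.0]
fused = decode_attention(q, k_cur, v_cur, k_hist, v_hist, scale=1.0)
reference = full_softmax_attention(q, k_hist + [k_cur], v_hist + [v_cur], scale=1.0)
```

Since the attention never reads the slot being written, the cache append does not have to be visible to other programs before attention starts.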
Run the router GEMV, the FP8 expert-input quantization and the accumulator
reset in one phase (all depend only on x + o), then compute the top-k
redundantly per program so the experts need no separate top-k barrier.
TPOT 7.25 -> 7.15 ms.
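Computing the top-k redundantly works because routing is a deterministic function of the router logits: every program that runs it gets the same experts and weights, so nothing has to be broadcast. A minimal Python sketch, assuming a softmax over the selected logits and a lowest-index tie-break (both assumptions; the function name is hypothetical):

```python
import math

def route_top_k(router_logits, k):
    """Pick the k highest-scoring experts and softmax over just those."""
    order = sorted(range(len(router_logits)), key=lambda e: (-router_logits[e], e))
    chosen = order[:k]
    m = max(router_logits[e] for e in chosen)
    exps = [math.exp(router_logits[e] - m) for e in chosen]
    total = sum(exps)
    return chosen, [x / total for x in exps]

logits = [0.1, 2.0, -1.0, 2.0, 0.5]
# Each "program" repeats the same small computation independently.
per_program = [route_top_k(logits, 2) for _ in range(4)]
experts, weights = per_program[0]
```

The redundant work is tiny (one pass over the expert logits per program) compared with a grid-wide barrier.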
The last expert's down projection owns the same MoE-output rows it just
wrote, so it finalizes x += o + moe in its epilogue instead of a separate
striped residual phase. Removes one more grid barrier per layer (quant
path). TPOT 7.15 -> 6.97 ms.

Smaller output-row tiles fill the grid better for the attention-path and
LM-head GEMVs. TPOT 6.97 -> 6.78 ms.

With the reduced barrier count and BLOCK_M=8, a grid of 184 programs
measures fastest. TPOT 6.78 -> 6.56 ms.

With BLOCK_M=8, a 1024-wide K block markedly speeds the attention-path and
LM-head BF16 GEMVs. TPOT 6.56 -> 5.87 ms.

The LM head spans the full vocabulary, so a separate BLOCK_M_LM lets it use
a larger row tile than the small attention GEMVs.

bench_islosl.py sweeps ISL:OSL pairs and reports TTFT, TPOT and end-to-end
latency. README now carries a measured table (100:100, 1024:100, 1024:1024,
2048:2048) on MI355X: TPOT is flat at ~5.9 ms/token; TTFT scales linearly
with input length since prefill reuses the single-token decode kernel.
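The linear TTFT claim follows from a simple cost model, sketched here in Python. This is an estimate, not a measurement: it assumes every prompt token costs roughly one decode-kernel step, and the function name and the example step time are illustrative only.

```python
def estimate_ttft_ms(input_len, step_ms):
    """Rough TTFT model: the prompt is fed one token per kernel step."""
    return input_len * step_ms

# Doubling the input length doubles the estimate under this model.
short = estimate_ttft_ms(1024, 5.9)
long = estimate_ttft_ms(2048, 5.9)
```

A batched prefill kernel would break this proportionality, which is why the model only applies while prefill reuses the decode path.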
Convert the inner K-reduction of every GEMV (expert FP4 dot_scaled, the
BF16 attention/router GEMVs and the LM head) from a plain while-loop to a
tl.range pipelined loop so the next block's loads overlap the current dot.
This lifts the FP4 expert GEMV from ~1.8 to ~2.7 TB/s. TPOT 5.78 -> 5.20 ms.
Copilot AI review requested due to automatic review settings June 24, 2026 15:56
@neoblizz neoblizz requested review from BKP and mawad-amd as code owners June 24, 2026 15:56
@github-actions github-actions bot added the in-progress (We are working on it) and iris (Iris project issue) labels Jun 24, 2026
@neoblizz neoblizz marked this pull request as draft June 24, 2026 15:57

Copilot AI left a comment

Contributor

Pull request overview

Note

Copilot couldn't run its full agentic review because no GitHub Actions runner was available. Make sure your repository has a runner available to run Copilot's review, or add a copilot-setup-steps.yml file specifying one with the runs-on attribute. See the docs for more details.

Adds a GPT-OSS-120B “megakernel” example for Iris, including a single-GPU persistent Triton kernel implementation, reference math, multi-GPU prototype, and tooling for validation/benchmarking/accuracy evaluation.

Changes:

  • Introduces the single-GPU persistent Triton megakernel + reusable common/ device-op library.
  • Adds reference implementation, phased Triton runner, and multiple validation/benchmark scripts.
  • Adds a multi-GPU (1 attention/tail + 4 MoE ranks) prototype using Iris symmetric heap.

Reviewed changes

Copilot reviewed 32 out of 34 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| examples/33_gpt_oss_megakernel/tokenizer_util.py | Minimal tokenizer wrapper for HF snapshot tokenizer.json. |
| examples/33_gpt_oss_megakernel/test_quant_expert.py | Validates quantized expert GEMV path vs PyTorch reference. |
| examples/33_gpt_oss_megakernel/test_kernels.py | Validates phased Triton kernels vs reference math. |
| examples/33_gpt_oss_megakernel/test_dot_scaled.py | Validates tl.dot_scaled operand layout for FP4xFP8 GEMV. |
| examples/33_gpt_oss_megakernel/test_barrier.py | Tests grid-wide barrier semantics needed for persistent kernels. |
| examples/33_gpt_oss_megakernel/run_triton_phased.py | End-to-end decode using per-phase Triton kernels orchestrated by host. |
| examples/33_gpt_oss_megakernel/run_reference.py | End-to-end decode using PyTorch reference implementation. |
| examples/33_gpt_oss_megakernel/reference.py | Defines numerical ground-truth for decode forward pass. |
| examples/33_gpt_oss_megakernel/multi_gpu/run_multi_gpu.py | Multi-process multi-GPU driver using Iris symmetric heap exchanges. |
| examples/33_gpt_oss_megakernel/multi_gpu/protocol.py | Defines rank roles and exchange protocol for multi-GPU decode. |
| examples/33_gpt_oss_megakernel/multi_gpu/persistent_kernels.py | Persistent per-rank Triton kernels using device-side flag rendezvous. |
| examples/33_gpt_oss_megakernel/multi_gpu/moe_kernels.py | MoE-rank per-layer expert kernels + scatter-back helper. |
| examples/33_gpt_oss_megakernel/multi_gpu/attn_kernels.py | Attention-rank prologue/scatter/accumulate/lm_head kernels. |
| examples/33_gpt_oss_megakernel/load_hf.py | Loads HF checkpoint into the example’s weight layout + FP4 dequant helpers. |
| examples/33_gpt_oss_megakernel/kernels.py | Standalone Triton building blocks for phased execution. |
| examples/33_gpt_oss_megakernel/gpt_oss_120b_quantized_megakernel.py | Main single-GPU persistent megakernel + host driver. |
| examples/33_gpt_oss_megakernel/convert_to_iris.py | Converts HF checkpoint to .iris tensor archive for mmap/device loading. |
| examples/33_gpt_oss_megakernel/common/swiglu.py | Device ops for SwiGLU (BF16 + FP8-quant variants). |
| examples/33_gpt_oss_megakernel/common/router.py | Device op for top-k + softmax routing. |
| examples/33_gpt_oss_megakernel/common/rmsnorm.py | Device op to materialize residual+RMSNorm output for non-quant path. |
| examples/33_gpt_oss_megakernel/common/quant.py | Device op for fused residual+RMSNorm+FP8 activation quantization. |
| examples/33_gpt_oss_megakernel/common/gemv_fp8.py | FP8 weight-only GEMV device helpers (incl. RMSNorm/resid fused). |
| examples/33_gpt_oss_megakernel/common/gemv_fp4.py | MXFP4 expert GEMV device helpers (dequant path + dot_scaled path). |
| examples/33_gpt_oss_megakernel/common/gemv_bf16.py | BF16 GEMV device helpers (tiled + fused RMSNorm/resid RMSNorm). |
| examples/33_gpt_oss_megakernel/common/fp4.py | FP4 magnitude LUT helper for dequant path. |
| examples/33_gpt_oss_megakernel/common/barrier.py | Grid-wide barrier helper for persistent kernels. |
| examples/33_gpt_oss_megakernel/common/attention.py | Device ops for RoPE+KV append and per-head flash decode with sinks. |
| examples/33_gpt_oss_megakernel/common/__init__.py | Re-exports reusable device ops for megakernel and multi-GPU kernels. |
| examples/33_gpt_oss_megakernel/bench_tpot.py | TPOT benchmarking script for steady-state decode. |
| examples/33_gpt_oss_megakernel/bench_islosl.py | ISL/OSL sweep benchmark for TTFT/TPOT/E2E throughput. |
| examples/33_gpt_oss_megakernel/acc_eval.py | Accuracy comparison between BF16-attn and FP8-attn variants. |
| examples/33_gpt_oss_megakernel/README.md | Documentation: architecture, usage, benchmarking, and accuracy notes. |
| .gitignore | Adds tmp/ ignore entry. |
Comments suppressed due to low confidence (1)

examples/33_gpt_oss_megakernel/tokenizer_util.py:1

  • Snapshot selection always returns the lexicographically-last snapshot directory, without checking that it actually contains tokenizer.json. If the newest snapshot is partial/corrupt, tokenizer loading will fail later with a less actionable error. Consider filtering candidates to those containing tokenizer.json, and/or raising an error that includes attempted paths when none match.
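The suggestion above could look roughly like the following. This is a hedged sketch, not the PR's tokenizer_util.py: the function name, the directory layout (one subdirectory per snapshot) and the error wording are all assumptions.

```python
import os

def find_snapshot_with_tokenizer(snapshots_dir):
    """Return the lexicographically-last snapshot that contains tokenizer.json.

    Raises FileNotFoundError listing every path that was tried.
    """
    names = sorted(
        name for name in os.listdir(snapshots_dir)
        if os.path.isdir(os.path.join(snapshots_dir, name))
    )
    tried = []
    for name in reversed(names):  # newest-looking snapshot first
        snap = os.path.join(snapshots_dir, name)
        candidate = os.path.join(snap, "tokenizer.json")
        tried.append(candidate)
        if os.path.isfile(candidate):
            return snap
    raise FileNotFoundError(
        "no snapshot under %r contains tokenizer.json; tried: %s"
        % (snapshots_dir, ", ".join(tried) or "(no snapshot directories)")
    )
```

A partial snapshot is skipped instead of being selected, and when nothing matches the error names the attempted paths.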

Comment on lines +56 to +57
NUM_WG = 180
_NWG = tl.constexpr(NUM_WG)
Comment on lines +148 to +157
for t in range(lo, pos + 1):
    kt = tl.load(kcache_ptr + t * kv_dim + kvh * DH + d).to(tl.float32)
    score = tl.sum(q * kt, axis=0) * scale
    m_new = tl.maximum(m, score)
    alpha = tl.exp(m - m_new)
    p = tl.exp(score - m_new)
    l = l * alpha + p
    vt = tl.load(vcache_ptr + t * kv_dim + kvh * DH + d).to(tl.float32)
    acc = acc * alpha + p * vt
    m = m_new
Comment on lines +15 to +22
tl.inline_asm_elementwise(
    "buffer_inv sc0\n\ts_waitcnt vmcnt(0)",
    "=r",
    [],
    dtype=tl.int32,
    is_pure=False,
    pack=1,
)
Comment on lines +121 to +122
idx_path = os.path.join(snap, "model.safetensors.index.json")
weight_map = json.load(open(idx_path))["weight_map"]
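Whatever the specific concern raised on these lines, `json.load(open(...))` leaves the file handle to be closed by the garbage collector. A context manager closes it deterministically. The helper below is a hypothetical sketch (the name `load_weight_map` and the explicit UTF-8 encoding are assumptions, not code from the PR):

```python
import json
import os

def load_weight_map(snap):
    """Read the safetensors index and return its weight_map dictionary."""
    idx_path = os.path.join(snap, "model.safetensors.index.json")
    # The 'with' block closes the file even if json.load raises.
    with open(idx_path, encoding="utf-8") as f:
        return json.load(f)["weight_map"]
```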
Comment on lines +1 to +3
"""Accuracy comparison: FP8-attention vs BF16-attention megakernel.

Both share the FP4 experts; this isolates the error introduced by quantizing the