[Training Camp] FlashAttention Integration @simon_chou#128

Open
Simon-CHOU wants to merge 6 commits into InfiniTensor:master from Simon-CHOU:simon_chou-20260316

Conversation

@Simon-CHOU

[Feat] FlashAttention Integration (GPT-2/LLaMA-3)

Summary

Integrates FlashAttention kernels into InfiniTrain to optimize memory usage and support long-sequence training. Aligns with project requirements for the 2025 Winter Training Camp.

Key Changes

  • Kernels: Added FP32-based FlashAttentionForwardKernel and FlashAttentionBackwardKernel with causal masking and scaling support.
  • Ops: Implemented ScaledDotProductAttention autograd function, mirroring PyTorch's interface.
  • Models: Enabled --flash flag for GPT-2 and LLaMA-3 to toggle between baseline and FlashAttention.
  • Build: Updated CMake to support CUDA kernels and enforce Release optimization.
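The forward kernel described above follows the standard FlashAttention idea: compute attention block by block with an online softmax, so the full (T, T) score matrix is never materialized. As a minimal NumPy sketch (not the project's actual CUDA code; function and parameter names are illustrative), with causal masking and 1/sqrt(d) scaling:

```python
import numpy as np

def flash_attention_forward(q, k, v, block=32, causal=True):
    """Tiled attention: O = softmax(Q K^T / sqrt(d)) V, computed one
    key/value block at a time with running softmax statistics."""
    T, d = q.shape
    scale = 1.0 / np.sqrt(d)
    o = np.zeros_like(q)
    m = np.full(T, -np.inf)   # running row-wise max of scores
    l = np.zeros(T)           # running sum of exp(scores - m)
    for j in range(0, T, block):
        kj, vj = k[j:j + block], v[j:j + block]
        s = (q @ kj.T) * scale                     # (T, block) partial scores
        if causal:
            qi = np.arange(T)[:, None]
            ki = np.arange(j, j + kj.shape[0])[None, :]
            s = np.where(ki <= qi, s, -np.inf)     # mask future positions
        m_new = np.maximum(m, s.max(axis=1))
        alpha = np.exp(m - m_new)                  # rescale old accumulators
        p = np.exp(s - m_new[:, None])
        l = l * alpha + p.sum(axis=1)
        o = o * alpha[:, None] + p @ vj
        m = m_new
    return o / l[:, None]
```

The same rescaling trick is what makes the kernel's accumulators safe to keep in registers/shared memory per block; the result matches a naive softmax-attention reference to floating-point tolerance.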

Verification

  • Precision (GPT-2): Loss alignment within 0.2% of baseline (FP32), verifying numerical correctness.
  • Stability: Fixed NaN issues by correcting shared memory initialization and gradient accumulation.
  • Functionality: LLaMA-3 (1B) training runs successfully without OOM.
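A loss-alignment check of the kind behind the "within 0.2%" claim can be sketched as a per-step relative-error comparison between the two runs (the helper name, loss values, and threshold below are illustrative, not taken from the PR):

```python
def max_rel_err(baseline_losses, flash_losses):
    """Largest per-step relative deviation of the FlashAttention run
    from the baseline run."""
    return max(abs(f - b) / abs(b)
               for b, f in zip(baseline_losses, flash_losses))

# Made-up example values: two training-loss curves that agree closely.
baseline = [4.3812, 3.9027, 3.5541]
flash    = [4.3820, 3.9040, 3.5533]
assert max_rel_err(baseline, flash) < 0.002   # within 0.2%
```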

Performance

  • Memory: ~15.7% reduction in peak memory usage on GPT-2 (SeqLen=1024).
  • Throughput: Comparable to baseline (FP32 SIMT kernel used for precision).

@kilinchange
Collaborator

  1. Please remove the unnecessary commits from this PR; it should contain only the code changes. Please send the project-report material as an email attachment instead.
  2. Please resolve the current conflicts between this PR and the master branch.

@kilinchange kilinchange self-requested a review March 17, 2026 06:17
@kilinchange kilinchange self-assigned this Mar 17, 2026
@Simon-CHOU
Author

Thank you for the detailed review. I'm working through these points and expect to push a new version by tomorrow (2026/03/23).

Regarding point #1: The project report has already been sent via email (mrsimonzhou@gmail.com) as requested.

@Simon-CHOU Simon-CHOU force-pushed the simon_chou-20260316 branch from c6498f0 to 16e102b on March 23, 2026 12:01
