Note: This project is a highly optimized, educational implementation of the FlashAttention algorithm (Dao et al.), designed to demonstrate advanced CUDA programming techniques including shared memory tiling, register caching, and vectorization.
Attention mechanisms are the bottleneck of modern Transformers, scaling quadratically with sequence length in both compute and memory traffic.
This repository provides a clean, modular, and extensible C++ CUDA implementation that features:
- Tiled Matrix Multiplication: Block-based computation ($16 \times 16$ tiles) to maximize data reuse.
- Online Softmax: Numerical stability without intermediate global memory passes (the 3-pass safe softmax fused into a single pass).
- Vectorized Memory Access: Uses `float4` (128-bit) loads/stores to saturate memory bandwidth.
- Warp-Level Primitives: Efficient `__shfl_sync` reductions to minimize shared memory bank conflicts (see the sketch after this list).
- Python Bindings: Seamless integration with PyTorch via `pybind11`.
- Rigorous Testing: GoogleTest suite for numerical verification against CPU references.
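To make the warp-level point concrete, here is a minimal sketch (assumed for illustration, not the repository's actual code) of a butterfly max reduction built on the `__shfl_sync` family:

```cuda
// Hypothetical sketch: warp-wide max reduction via shuffle intrinsics,
// avoiding shared memory (and hence bank conflicts) entirely.
__device__ float warp_reduce_max(float val) {
    // Butterfly pattern: each step halves the number of distinct values.
    for (int offset = 16; offset > 0; offset >>= 1) {
        val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset));
    }
    return val;  // every lane now holds the warp-wide maximum
}
```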
The codebase is structured as a production-grade C++ library:
```
├── include/
│   └── kernels.cuh              # Kernel Interface
├── src/
│   ├── kernels/
│   │   └── flash_fwd_kernel.cu  # The Core Logic
│   ├── python/                  # Pybind11 Bindings
│   └── benchmark.cu             # Standalone Benchmark Tool
├── tests/                       # GoogleTest Unit Tests
└── CMakeLists.txt               # Modern CMake Build System
```
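For orientation, the entry point declared in `include/kernels.cuh` presumably resembles the following; the exact signature is an assumption inferred from the Python binding shown later in this README:

```cuda
// Hypothetical sketch of the interface in include/kernels.cuh.
// B = batch size, N = sequence length, d = head dimension.
#pragma once

void flash_attn_fwd(const float* q,  // [B, N, d] queries (device pointer)
                    const float* k,  // [B, N, d] keys
                    const float* v,  // [B, N, d] values
                    float* o,        // [B, N, d] output
                    int B, int N, int d);
```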
We partition the Q, K, and V matrices into blocks that are staged through shared memory:

- Block Size: $B_r = 16$, $B_c = 16$.
- Shared Memory Layout: `[Br][D]` float arrays, arranged to minimize bank conflicts during the inner product.
Accumulators for the output matrix are kept in registers (`float acc[8]` per thread) during the inner loop over the K/V tiles, so partial results never leave the register file.
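Putting the tiling and register caching together, the kernel skeleton looks roughly like this (a simplified sketch: the head dimension `D = 64` and the elided bodies are assumptions, not the repository's exact code):

```cuda
// Simplified sketch of the tiling + register-caching pattern; the real
// kernel in flash_fwd_kernel.cu is more involved.
constexpr int Br = 16, Bc = 16, D = 64;  // D = 64 is an assumed head dim

__global__ void flash_fwd_sketch(const float* Q, const float* K,
                                 const float* V, float* O, int N) {
    __shared__ float Qs[Br][D];  // row block of Q, loaded once per CTA
    __shared__ float Ks[Bc][D];  // current column block of K

    float acc[8] = {0.0f};  // per-thread output accumulators in registers

    // ... load Qs for this row block ...
    for (int tile = 0; tile < N / Bc; ++tile) {
        // ... load Ks (and the matching V tile) for this column block ...
        __syncthreads();
        // ... compute Qs * Ks^T scores, update softmax statistics,
        //     and fold results into acc[] without leaving registers ...
        __syncthreads();
    }
    // ... rescale acc[] by the softmax denominator and write to O ...
}
```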
All global memory reads are performed with `float4` instructions, loading four floats (16 bytes) per instruction. This significantly reduces instruction overhead and improves bus utilization.
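A minimal sketch of such a vectorized copy (assuming 16-byte-aligned pointers and a head dimension divisible by 4; the helper name is hypothetical):

```cuda
// Sketch of a vectorized tile load: one float4 (16 bytes) per instruction
// instead of four scalar loads.
__device__ void load_row_vec4(const float* __restrict__ src,
                              float* __restrict__ dst, int d) {
    const float4* src4 = reinterpret_cast<const float4*>(src);
    float4*       dst4 = reinterpret_cast<float4*>(dst);
    for (int i = threadIdx.x; i < d / 4; i += blockDim.x) {
        dst4[i] = src4[i];  // one 128-bit load + one 128-bit store
    }
}
```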
We track the running maximum ($m$) and the running sum of exponentials ($\ell$) for each output row, rescaling previously accumulated results whenever a new tile raises the maximum. This is how the 3-pass safe softmax collapses into a single pass over the data.
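The per-row update applied after each tile follows the standard online-softmax recurrence, sketched below with hypothetical names (`tile_max`, `tile_sum`, `tile_acc` are the statistics of the tile just processed, with exponentials taken relative to the new maximum):

```cuda
// Sketch of the online-softmax recurrence applied after each K/V tile.
// m: running row maximum; l: running sum of exponentials; acc: running
// weighted sums of V rows.
__device__ void online_softmax_update(float& m, float& l, float acc[8],
                                      float tile_max, float tile_sum,
                                      const float tile_acc[8]) {
    float m_new = fmaxf(m, tile_max);
    float scale = __expf(m - m_new);  // rescales all previously seen state
    l = l * scale + tile_sum;
    for (int i = 0; i < 8; ++i) {
        acc[i] = acc[i] * scale + tile_acc[i];
    }
    m = m_new;
    // After the final tile, the output row is acc[i] / l.
}
```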
To build and run the project you will need:

- NVIDIA GPU (Ampere/Hopper recommended for maximum performance)
- CUDA Toolkit 11+
- CMake 3.18+
- C++17 Compiler
```bash
mkdir build && cd build
cmake ..
make -j
```

The benchmark executable runs a performance test:
```bash
./benchmark [BatchSize] [SeqLen]
# Example:
./benchmark 16 4096
```

We use GoogleTest to verify correctness against a simple CPU implementation:
```bash
./unit_tests
```

You can build the Python extension to use the kernel directly in Python scripts:
```bash
# Ensure pybind11 is available
make flash_attn_cuda
```

Then in Python:
```python
import flash_attn_cuda
import torch

# ... setup pointers ...
flash_attn_cuda.flash_attn_fwd(q_ptr, k_ptr, v_ptr, o_ptr, B, N, d)
```

Roadmap:

- FP32 Forward Pass with Tiling & Vectorization
- Python Bindings
- Tensor Core Support (FP16/BF16): Implementing `wmma` intrinsics for 2x-4x throughput on Tensor Cores (see the sketch after this list).
- Backward Pass: Gradient computation for training.
- Variable Sequence Lengths: Support for ragged batches.
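As a preview of the Tensor Core item, a single 16x16x16 FP16 tile multiply-accumulate with `wmma` looks like this (illustrative only; not part of the current FP32 codebase):

```cuda
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// Illustrative only: one warp computes a single 16x16x16 FP16 tile
// multiply-accumulate on Tensor Cores, accumulating in FP32.
__global__ void wmma_tile_sketch(const half* A, const half* B, float* C) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c;

    wmma::fill_fragment(c, 0.0f);
    wmma::load_matrix_sync(a, A, 16);  // leading dimension = 16
    wmma::load_matrix_sync(b, B, 16);
    wmma::mma_sync(c, a, b, c);        // c += a * b on Tensor Cores
    wmma::store_matrix_sync(C, c, 16, wmma::mem_row_major);
}
```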
MIT License. See LICENSE for details.
Built with ❤️ and CUDA by Shawn Ray.