FlashAttention CUDA: A High-Performance Attention Kernel


Note: This project is a highly optimized, educational implementation of the FlashAttention algorithm (Dao et al.), designed to demonstrate advanced CUDA programming techniques including shared memory tiling, register caching, and vectorization.

🚀 Overview

Attention is the main bottleneck in modern Transformers: its cost scales quadratically, $O(N^2)$, with sequence length. FlashAttention addresses this by reordering operations to minimize High Bandwidth Memory (HBM) accesses, keeping the active attention tiles in fast SRAM (shared memory) instead of materializing the full attention matrix.

This repository provides a clean, modular, and extensible C++ CUDA implementation that features:

  • Tiled Matrix Multiplication: Block-based computation ($16 \times 16$ tiles) to maximize data reuse.
  • Online Softmax: Numerically stable softmax without intermediate passes over global memory (the three-pass safe softmax fused into a single pass).
  • Vectorized Memory Access: Uses float4 (128-bit) loads/stores to saturate memory bandwidth.
  • Warp-Level Primitives: __shfl_sync-based reductions that stay in registers within a warp, avoiding shared-memory round trips and bank conflicts.
  • Python Bindings: Seamless integration with PyTorch via pybind11.
  • Rigorous Testing: GoogleTest suite for numerical verification against CPU references.

🛠 Architecture

The codebase is structured as a production-grade C++ library:

├── include/
│   └── kernels.cuh       # Kernel Interface
├── src/
│   ├── kernels/
│   │   └── flash_fwd_kernel.cu  # The Core Logic
│   ├── python/           # Pybind11 Bindings
│   └── benchmark.cu      # Standalone Benchmark Tool
├── tests/                # GoogleTest Unit Tests
└── CMakeLists.txt        # Modern CMake Build System

⚡ Performance Optimization Details

1. Memory Hierarchy Aware Tiling

We partition the $Q$, $K$, and $V$ matrices into blocks that fit in on-chip L1/shared memory, as sketched after the list below.

  • Block Size: $B_r=16, B_c=16$.
  • Shared Memory Layout: [Br][D] float arrays to minimize bank conflicts during the inner product.
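
To make the tiling concrete, the sketch below stages one $B_c \times D$ block of $K$ into shared memory before the inner-product loop. This is an illustrative fragment, not the repository's kernel: the kernel name, the BLOCK and HEAD_DIM constants, and the sK array are assumptions.

// Illustrative sketch only (names are assumptions, not the repo's kernel).
#include <cuda_runtime.h>

constexpr int BLOCK = 16;     // Br = Bc = 16
constexpr int HEAD_DIM = 64;  // example head dimension

__global__ void stage_k_tile_example(const float* __restrict__ K, int kv_block) {
    __shared__ float sK[BLOCK][HEAD_DIM];   // [Bc][D] layout in shared memory

    // Threads cooperatively copy one K block from HBM into SRAM; consecutive
    // threads read consecutive addresses, so the global loads are coalesced.
    for (int idx = threadIdx.x; idx < BLOCK * HEAD_DIM; idx += blockDim.x) {
        int row = idx / HEAD_DIM;
        int col = idx % HEAD_DIM;
        sK[row][col] = K[(kv_block * BLOCK + row) * HEAD_DIM + col];
    }
    __syncthreads();  // tile is now resident in SRAM and reused for Q·K^T

    // ... score computation and the online-softmax update would follow here ...
}

With $B_r = B_c = 16$ and, for example, $D = 64$, each tile is only $16 \times 64$ floats (about 4 KB), so the Q, K, and V tiles fit comfortably in a single block's shared memory.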

2. Register Caching

Accumulators for the output matrix $O$ are held entirely in registers (float acc[8] per thread) during the inner loop over $K, V$ blocks, preventing repeated global memory writes.
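
A stripped-down sketch of this pattern follows; the kernel and variable names are illustrative assumptions, and the real inner loop accumulates $P \cdot V$ contributions rather than raw inputs.

// Illustrative sketch (not the repository's kernel): each thread keeps an
// 8-wide accumulator in registers across the loop over blocks and writes to
// global memory only once at the end. Intended for a single-block launch;
// `in` holds num_blocks * blockDim.x * 8 floats.
#include <cuda_runtime.h>

__global__ void register_accumulate_example(const float* __restrict__ in,
                                            float* __restrict__ out,
                                            int num_blocks) {
    float acc[8];                         // lives in registers, not memory
    #pragma unroll
    for (int i = 0; i < 8; ++i) acc[i] = 0.0f;

    // Accumulate contributions from every block without touching global out.
    for (int b = 0; b < num_blocks; ++b) {
        #pragma unroll
        for (int i = 0; i < 8; ++i)
            acc[i] += in[(b * blockDim.x + threadIdx.x) * 8 + i];
    }

    // Single write-back per thread after the loop completes.
    #pragma unroll
    for (int i = 0; i < 8; ++i)
        out[threadIdx.x * 8 + i] = acc[i];
}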

3. Vectorization

All global memory reads for $Q, K, V$ are performed using float4 instructions, loading 4 floats (16 bytes) per instruction. This significantly reduces the instruction overhead and improves bus utilization.
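
The sketch below shows the vectorized access pattern in isolation as a plain float4 copy; the names are assumptions, but the tile loads follow the same idea.

// Illustrative sketch: a float4 (128-bit) vectorized copy.
#include <cuda_runtime.h>

__global__ void copy_float4_example(const float4* __restrict__ src,
                                    float4* __restrict__ dst, int n_vec4) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n_vec4)
        dst[i] = src[i];   // a single instruction moves 4 floats (16 bytes)
}
// Note: the underlying pointers must be 16-byte aligned and the float count a
// multiple of 4 for the float4 reinterpretation to be valid.

In the attention kernel itself, the same pattern presumably reinterprets rows of $Q$, $K$, and $V$ as float4 whenever the head dimension $d$ is a multiple of 4.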

4. Online Softmax (Safe Softmax)

We track the running maximum ($m$) and running sum ($l$) of exponentials. When updating $m$, we rescale the current accumulator $O$ by $e^{m_{old} - m_{new}}$, ensuring we never overflow fp32 range while computing the exact Softmax result in a single pass.
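
For a scalar view of this recurrence, here is a minimal, self-contained sketch in plain C++ (the kernel applies the same update per row and per output column; names other than $m$, $l$, and $O$ are illustrative):

// Minimal sketch of the online-softmax update applied per score.
#include <algorithm>
#include <cmath>

struct OnlineSoftmaxState {
    float m = -INFINITY;  // running maximum of the scores seen so far
    float l = 0.0f;       // running sum of exp(score - m)
    float o = 0.0f;       // running, still-unnormalized output accumulator
};

// Fold one new score s (with associated value v) into the running state.
inline void online_softmax_update(OnlineSoftmaxState& st, float s, float v) {
    float m_new   = std::max(st.m, s);
    float rescale = std::exp(st.m - m_new);        // e^{m_old - m_new}
    st.l = st.l * rescale + std::exp(s - m_new);
    st.o = st.o * rescale + std::exp(s - m_new) * v;
    st.m = m_new;
}
// After the last score, the exact softmax-weighted output is st.o / st.l.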

📦 Build & Usage

Prerequisites

  • NVIDIA GPU (Ampere/Hopper recommended for max performance)
  • CUDA Toolkit 11+
  • CMake 3.18+
  • C++17 Compiler

Building from Source

mkdir build && cd build
cmake ..
make -j

Running Benchmarks

The benchmark executable runs a performance test for the given batch size and sequence length:

./benchmark [BatchSize] [SeqLen]
# Example:
./benchmark 16 4096

Running Tests

We use GoogleTest to verify correctness against a simple CPU implementation.

./unit_tests

Python Integration

You can build the Python extension to use the kernel directly in Python scripts:

# Ensure pybind11 is available
make flash_attn_cuda

Then in Python:

import flash_attn_cuda
import torch

# ... setup pointers ...
flash_attn_cuda.flash_attn_fwd(q_ptr, k_ptr, v_ptr, o_ptr, B, N, d)
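
For reference, here is a hypothetical sketch of how the pybind11 binding in src/python/ might expose this entry point; only the Python-visible signature above is taken from the repository, the rest is an assumption.

// Hypothetical sketch of a pybind11 binding for the forward launcher.
#include <pybind11/pybind11.h>
#include <cstdint>

// Assumed forward declaration of the CUDA launcher (see include/kernels.cuh).
void flash_attn_fwd(const float* q, const float* k, const float* v, float* o,
                    int B, int N, int d);

PYBIND11_MODULE(flash_attn_cuda, m) {
    m.def("flash_attn_fwd",
          [](std::uintptr_t q, std::uintptr_t k, std::uintptr_t v,
             std::uintptr_t o, int B, int N, int d) {
              flash_attn_fwd(reinterpret_cast<const float*>(q),
                             reinterpret_cast<const float*>(k),
                             reinterpret_cast<const float*>(v),
                             reinterpret_cast<float*>(o), B, N, d);
          },
          "FlashAttention forward pass over raw device pointers");
}

On the Python side, q_ptr, k_ptr, v_ptr, and o_ptr would then be the device addresses of contiguous CUDA tensors (for example, torch.Tensor.data_ptr()); check the actual binding code for the exact calling convention.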

📈 Roadmap

  • FP32 Forward Pass with Tiling & Vectorization (implemented)
  • Python Bindings (implemented)
  • Tensor Core Support (FP16/BF16): Implementing wmma intrinsics for 2x-4x throughput on Tensor Cores. (planned)
  • Backward Pass: Gradient computation for training. (planned)
  • Variable Sequence Lengths: Support for ragged batches. (planned)

📄 License

MIT License. See LICENSE for details.


Built with ❤️ and CUDA by Shawn Ray.
