A custom, educational implementation of FlashAttention (Forward & Backward) using CUDA and C++ PyTorch extensions. This project explores kernel optimizations including tiling, shared memory buffering, tensor core (WMMA) usage, and numerically safe online softmax.
- OS: Windows 10/11 (or Linux)
- GPU: NVIDIA GPU (Compute Capability 8.0+ recommended, e.g., RTX 30xx/40xx)
- CUDA: Toolkit v11.6+ (Tested with v13.x)
- Python: 3.8+
- PyTorch: installed with CUDA support
Run the following command from the root directory to compile the CUDA kernels and install the package in editable mode:
pip install -e .
or
python setup.py install

Note: If encountering compiler errors on Windows, ensure cl.exe (MSVC) is on your PATH via the "x64 Native Tools Command Prompt" and that nvcc is discoverable.
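For reference, the package builds like a standard PyTorch C++/CUDA extension. Below is a minimal setup.py sketch of that kind of configuration; the source paths under csrc/, the extension name, and the compile flags are placeholders, not this repo's actual build settings:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="flash_attn_lab",
    ext_modules=[
        CUDAExtension(
            name="flash_attn_lab._C",              # hypothetical extension module name
            sources=[
                "csrc/flash_attention.cpp",        # hypothetical C++ binding file
                "csrc/flash_attention_kernel.cu",  # hypothetical CUDA kernel file
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": ["-O3", "--use_fast_math"],
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)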
To gauge performance against PyTorch's native scaled_dot_product_attention (SDPA), run the benchmark script:
python bench/benchmark_final.py

The script validates correctness (FP16 precision) and compares execution time.
Benchmarking: B=4, H=8, N=4096, D=64
--------------------------------------------------
PyTorch Output Shape: torch.Size([4, 8, 4096, 64])
Custom Output Shape: torch.Size([4, 8, 4096, 64])
Max Difference: 0.000122 (FP16 tolerance met)
--------------------------------------------------
PyTorch SDPA Time: 6.29 ms | 21.83 TFLOPS
Our Custom Kernel: 105.060 ms | 1.31 TFLOPS
Relative Performance: ~5.88% (Work in progress)
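For a quick manual comparison outside the benchmark script, both paths can be timed with CUDA events. A minimal sketch, assuming the flash_attn_lab.flash_attention_cuda entry point shown in the usage example below; warm-up and iteration counts are arbitrary:

import torch
import torch.nn.functional as F
import flash_attn_lab

def time_cuda(fn, iters=20):
    # Time a CUDA function with events, after a few warm-up calls
    for _ in range(3):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

B, H, N, D = 4, 8, 4096, 64
Q, K, V = (torch.randn(B, H, N, D, device='cuda', dtype=torch.float16) for _ in range(3))

print("SDPA   :", time_cuda(lambda: F.scaled_dot_product_attention(Q, K, V)), "ms")
print("Custom :", time_cuda(lambda: flash_attn_lab.flash_attention_cuda(Q, K, V)), "ms")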
Run unit tests and profiling:

# Profile the CUDA kernel
python tests/profile_cuda.py

Key kernel optimizations:

- Tiling: Q, K, and V are loaded block-by-block into shared memory.
- Tensor Cores: Uses the nvcuda::wmma API for $Q \times K^T$.
- Online Softmax: Computes softmax statistics (max, sum) on the fly to avoid global memory round-trips (see the sketch after this list).
- Vectorization: Uses half2 types for efficient Global $\to$ Shared memory loads.
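The online-softmax recurrence is easiest to see outside the kernel. Below is a block-wise PyTorch sketch of the same update (written for clarity, not speed); block_size and the function name are illustrative:

import torch

def online_softmax_attention(Q, K, V, block_size=128):
    # Iterate over K/V blocks while maintaining a running max (m) and sum (l),
    # so the full N x N score matrix is never materialized.
    B, H, N, D = Q.shape
    scale = D ** -0.5
    O = torch.zeros_like(Q, dtype=torch.float32)
    m = torch.full((B, H, N, 1), float('-inf'), device=Q.device)
    l = torch.zeros((B, H, N, 1), device=Q.device)
    for start in range(0, N, block_size):
        Kb = K[:, :, start:start + block_size].float()
        Vb = V[:, :, start:start + block_size].float()
        S = (Q.float() @ Kb.transpose(-1, -2)) * scale        # scores for this tile
        m_new = torch.maximum(m, S.amax(dim=-1, keepdim=True))
        P = torch.exp(S - m_new)                              # probabilities rebased to new max
        alpha = torch.exp(m - m_new)                          # rescale factor for old statistics
        l = l * alpha + P.sum(dim=-1, keepdim=True)
        O = O * alpha + P @ Vb
        m = m_new
    return (O / l).to(Q.dtype)

In the CUDA kernel the same running max and sum are kept per query tile (in registers/shared memory), which is what avoids the global memory round-trips mentioned above.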
Implementation roadmap:

- Naive Implementation
- Tiled Implementation (SRAM)
- Online Softmax (Safe FP16)
- Tensor Cores (WMMA) for QK^T
- Tensor Cores for PV accumulation
- Pipeline optimizations (Warp Specialization)
- Double buffering
Example usage:

import torch
import flash_attn_lab
# Define inputs
Q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
K = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
V = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
# Run kernel
output = flash_attn_lab.flash_attention_cuda(Q, K, V)
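To sanity-check the result, the output can be compared against PyTorch's SDPA reference, as the benchmark script does; the tolerances below are illustrative values for FP16:

import torch.nn.functional as F

# PyTorch reference attention on the same inputs
ref = F.scaled_dot_product_attention(Q, K, V)

# FP16 accumulation differences keep the outputs close but not bit-identical
print("max abs diff:", (output - ref).abs().max().item())
torch.testing.assert_close(output, ref, atol=1e-3, rtol=1e-3)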