Flash Attention Lab

A custom, educational implementation of FlashAttention (forward and backward passes) using CUDA and C++ PyTorch extensions. This project explores kernel optimizations including tiling, shared-memory buffering, Tensor Core (WMMA) usage, and online softmax.

Installation

Prerequisites

  • OS: Windows 10/11 (or Linux)
  • GPU: NVIDIA GPU (Compute Capability 8.0+ recommended, e.g., RTX 30xx/40xx)
  • CUDA: Toolkit v11.6+ (Tested with v13.x)
  • Python: 3.8+
  • PyTorch: installed with CUDA support
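
A quick sanity check that PyTorch sees a suitable GPU (a minimal sketch using only standard PyTorch introspection; the compute-capability threshold mirrors the recommendation above):

import torch

# Confirm a CUDA device is visible and report its compute capability.
assert torch.cuda.is_available(), "No CUDA device visible to PyTorch"
major, minor = torch.cuda.get_device_capability(0)
print(f"GPU: {torch.cuda.get_device_name(0)} (SM {major}.{minor})")
print(f"PyTorch CUDA build: {torch.version.cuda}")
if major < 8:
    print("Warning: compute capability 8.0+ is recommended for the WMMA/FP16 paths")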

Build & Install

Run the following command from the root directory to compile the CUDA kernels and install the package in editable mode:

pip install -e .

or

python setup.py install 

Note: If you encounter compiler errors on Windows, ensure cl.exe (MSVC) is on your PATH via the "x64 Native Tools Command Prompt" and that nvcc is discoverable.

Benchmarking

To gauge performance against PyTorch's native scaled_dot_product_attention (SDPA), run the benchmark script:

python bench/benchmark_final.py

Expected Output

The script validates numerical correctness (within FP16 tolerance) and compares execution times.

Benchmarking: B=4, H=8, N=4096, D=64
--------------------------------------------------
PyTorch Output Shape: torch.Size([4, 8, 4096, 64])
Custom Output Shape: torch.Size([4, 8, 4096, 64])
Max Difference: 0.000122 (FP16 tolerance met)
--------------------------------------------------
PyTorch SDPA Time: 6.29 ms | 21.83 TFLOPS
Our Custom Kernel: 105.060 ms | 1.31 TFLOPS
Relative Performance: ~5.88% (Work in progress)
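
For reference, the TFLOPS figures are consistent with the standard attention FLOP count of about 4 * B * H * N^2 * D (the Q K^T and P V matmuls each contribute 2 * B * H * N^2 * D). A minimal timing sketch using CUDA events (standalone, not the benchmark script's actual code):

import torch
import flash_attn_lab

B, H, N, D = 4, 8, 4096, 64
Q, K, V = (torch.randn(B, H, N, D, device='cuda', dtype=torch.float16) for _ in range(3))

flash_attn_lab.flash_attention_cuda(Q, K, V)  # warm-up run
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
out = flash_attn_lab.flash_attention_cuda(Q, K, V)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end)

# Attention FLOPs: Q @ K^T and P @ V are each 2*B*H*N*N*D floating-point ops.
flops = 4 * B * H * N * N * D
print(f"{ms:.3f} ms | {flops / (ms * 1e-3) / 1e12:.2f} TFLOPS")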

Testing

Profile and validate the CUDA kernel:

# Profile cuda kernel
python tests/profile_cuda.py

Implementation Details

Forward Pass (kernels/src/cuda/forward.cu)

  1. Tiling: Q, K, and V are processed in blocks sized to fit in shared memory.
  2. Tensor Cores: Uses the nvcuda::wmma API for the $Q \times K^T$ product.
  3. Online Softmax: Computes softmax statistics (max, sum) on the fly to avoid global-memory round-trips (see the reference sketch after this list).
  4. Vectorization: Uses half2 types for efficient global $\to$ shared memory loads.
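
For intuition, here is a pure-PyTorch reference of the tiled, online-softmax loop that the CUDA kernel implements. This is an illustrative sketch: the block size and variable names are not the kernel's actual tile parameters.

import math
import torch

def flash_attention_reference(Q, K, V, block=128):
    # Q, K, V: (B, H, N, D). K and V are processed in tiles while keeping
    # running softmax statistics (row max m, row sum l), so the full N x N
    # score matrix is never materialized.
    B, H, N, D = Q.shape
    scale = 1.0 / math.sqrt(D)
    O = torch.zeros(B, H, N, D, device=Q.device)
    m = torch.full((B, H, N, 1), float('-inf'), device=Q.device)
    l = torch.zeros(B, H, N, 1, device=Q.device)
    Qf = Q.float()
    for j in range(0, N, block):
        Kj = K[:, :, j:j + block].float()
        Vj = V[:, :, j:j + block].float()
        S = (Qf @ Kj.transpose(-2, -1)) * scale             # tile of attention scores
        m_new = torch.maximum(m, S.max(dim=-1, keepdim=True).values)
        alpha = torch.exp(m - m_new)                         # rescale previous statistics
        P = torch.exp(S - m_new)                             # unnormalized tile probabilities
        l = alpha * l + P.sum(dim=-1, keepdim=True)
        O = alpha * O + P @ Vj                               # accumulate the PV product
        m = m_new
    return (O / l).to(Q.dtype)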

Optimization Roadmap

  • Naive Implementation
  • Tiled Implementation (SRAM)
  • Online Softmax (Safe FP16)
  • Tensor Cores (WMMA) for QK^T
  • Tensor Cores for PV accumulation
  • Pipeline optimizations (Warp Specialization)
  • Double buffering

Usage in Python

import torch
import flash_attn_lab

# Define inputs
Q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
K = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
V = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)

# Run kernel
output = flash_attn_lab.flash_attention_cuda(Q, K, V)
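
The result can be checked against PyTorch's native SDPA on the same inputs, as the benchmark does (the tolerance below is just a typical FP16 ballpark, not a project-defined threshold):

import torch.nn.functional as F

ref = F.scaled_dot_product_attention(Q, K, V)
max_diff = (output - ref).abs().max().item()
print(f"Max difference: {max_diff:.6f}")
assert max_diff < 1e-2, "Output deviates from SDPA beyond FP16 tolerance"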
