Support Attention head_dim=512 #412

Merged

CC-Yeh merged 2 commits into main from support_headdim_512_attention on May 14, 2026

Conversation

Contributor

CC-Yeh commented May 14, 2026

This PR adds:

  • MatmulArguments gains b_offset, b_leading_dimension, and b_transpose (the last exposed as a VARIANTS axis on gemm.metal).
  • A head_dim=512 dispatch (suffix > 8): per-group matmul(Q, K^T) → mask → softmax → matmul(P, V) → scatter, sketched below. Smaller cases keep the existing single_pass/two_pass kernels (which also gained HEAD_DIM=512).
  • A new generic Softmax kernel under kernel/softmax/.
  • Two attention helpers: ScatterScores and ScatterValues.
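
A minimal CPU reference of that per-group pipeline, to pin down the math. Shapes, helper names, and the causal mask here are illustrative assumptions, not this repo's Metal API; the scatter step is omitted since it is purely a layout concern:

```cpp
// CPU reference of the per-group head_dim=512 path:
// S = Q @ K^T, mask, P = softmax(S), O = P @ V.
// Shapes (one group): Q [seq_q, D], K/V [seq_k, D], D = 512.
#include <algorithm>
#include <cmath>
#include <vector>

using Mat = std::vector<float>;  // row-major

static Mat matmul_nt(const Mat& a, const Mat& b, int m, int n, int k) {
    // C[m,n] = A[m,k] @ B[n,k]^T  (B accessed transposed, like K^T)
    Mat c(size_t(m) * n, 0.f);
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j) {
            float acc = 0.f;
            for (int p = 0; p < k; ++p) acc += a[i * k + p] * b[j * k + p];
            c[i * n + j] = acc;
        }
    return c;
}

static Mat matmul_nn(const Mat& a, const Mat& b, int m, int n, int k) {
    // C[m,n] = A[m,k] @ B[k,n]
    Mat c(size_t(m) * n, 0.f);
    for (int i = 0; i < m; ++i)
        for (int p = 0; p < k; ++p)
            for (int j = 0; j < n; ++j) c[i * n + j] += a[i * k + p] * b[p * n + j];
    return c;
}

// O = softmax(mask(Q @ K^T * scale)) @ V for a single head/group.
Mat attention_reference(const Mat& Q, const Mat& K, const Mat& V,
                        int seq_q, int seq_k, int D, bool causal) {
    const float scale = 1.f / std::sqrt(float(D));
    Mat S = matmul_nt(Q, K, seq_q, seq_k, D);      // scores
    for (int i = 0; i < seq_q; ++i)                // scale + optional causal mask
        for (int j = 0; j < seq_k; ++j) {
            float& s = S[i * seq_k + j];
            s *= scale;
            if (causal && j > i) s = -INFINITY;
        }
    for (int i = 0; i < seq_q; ++i) {              // row-wise softmax -> P
        float mx = -INFINITY, sum = 0.f;
        for (int j = 0; j < seq_k; ++j) mx = std::max(mx, S[i * seq_k + j]);
        for (int j = 0; j < seq_k; ++j) {
            float e = std::exp(S[i * seq_k + j] - mx);
            S[i * seq_k + j] = e;
            sum += e;
        }
        for (int j = 0; j < seq_k; ++j) S[i * seq_k + j] /= sum;
    }
    return matmul_nn(S, V, seq_q, D, seq_k);       // O = P @ V
}
```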

Contributor Author

CC-Yeh commented May 14, 2026

Why BD=512 fused GEMM attention didn't work

  • Threadgroup memory ceiling is 32 KB on Apple Silicon. Fused GEMM needs the Q-tile and the K/V-tile resident at the same time (a back-of-envelope footprint check follows this list):

    BQ   BK   q_smem + kv_smem
    32   32   66 KB — overflows
    32   16   49 KB — overflows
    16   16   32 KB — at the limit
    16    8   24 KB — fits
  • The only tile that fit (BQ=16, BK=8) was 2–3× slower than the unfused matmul path:

    seq    fused BD=512   unfused matmul   ratio
    512    4.23 ms        1.80 ms          2.4× slower
    2048   66.4 ms        22.5 ms          2.9× slower
  • Why: BK=8 means 128 FMAs per KV reload vs SteelMatmul's 2048 (16× worse bandwidth amortization). BQ=16 means 2 simdgroup-matrix rows — too few to overlap math with loads.
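
For context on the footprints in the first table, a back-of-envelope check lines up with the reported numbers, assuming half-precision Q and K/V tiles of the full 512-wide head_dim and ignoring padding (my arithmetic, not taken from the kernel source):

$$
\text{smem} \approx (B_Q + B_K)\times D \times 2\,\text{B}:\qquad
(16+8)\times 512\times 2 = 24\,\text{KB},\qquad
(16+16)\times 512\times 2 = 32\,\text{KB}.
$$

The 49 KB and 66 KB rows are the same formula (48 KB and 64 KB) plus what is presumably per-row padding.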

Why we use the unfused matmul pipeline instead

  • At HEAD_DIM=512 the work is essentially two big GEMMs (Q @ K^T and P @ V). The GEMM kernels are already peak-tuned for those shapes; a hand-written attention kernel is unlikely to beat them.
  • MLX does the same: when head_dim ∉ {64, 80, 128} it sets use_fallback = true and runs matmul + softmax + matmul at the graph level (see the sketch below).
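
A minimal sketch of that graph-level decision, with a hypothetical predicate name (this is neither MLX's nor this repo's exact code):

```cpp
#include <cstdint>

// Fused SDPA kernels exist only for a few tuned head_dims; everything else,
// including head_dim = 512, falls back to matmul -> softmax -> matmul.
bool sdpa_use_fallback(int64_t head_dim) {
    return head_dim != 64 && head_dim != 80 && head_dim != 128;
}
```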

CC-Yeh marked this pull request as ready for review May 14, 2026 15:25
CC-Yeh enabled auto-merge (squash) May 14, 2026 15:26
CC-Yeh merged commit 5ff6f3a into main May 14, 2026
7 checks passed
CC-Yeh deleted the support_headdim_512_attention branch May 14, 2026 15:34