
Fix CUDA Attention dispatch: skip MEA when head_size != v_head_size in GQA#28358

Open
justinchuby wants to merge 1 commit into main from fix-attention-head-size-mismatch

Conversation

@justinchuby
Contributor

Summary

Problem

The Memory-Efficient Attention (MEA) path crashes with cudaErrorMisalignedAddress when:

  • GQA mode (q_num_heads != kv_num_heads)
  • head_size != v_head_size (e.g., Q.head_dim=256, K.head_dim=512)
  • seq_len >= 4 (Flash Attention not eligible due to attention mask)

This is because MEA's LaunchUngroup requires equal head sizes, but the dispatch logic only checked this constraint for the past_key case (line 1380), not the general GQA case.

Fix

Skip MEA for GQA when head sizes differ. The Unfused Attention fallback handles this correctly.
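
For illustration, a minimal Python sketch of the dispatch rule (hypothetical function and parameter names; the actual check lives in the CUDA Attention kernel's C++ dispatch code):

def can_use_memory_efficient_attention(q_num_heads, kv_num_heads,
                                       head_size, v_head_size):
    # MEA's LaunchUngroup assumes a single head size when expanding the KV
    # heads, so GQA with head_size != v_head_size must use the unfused path.
    is_gqa = q_num_heads != kv_num_heads
    if is_gqa and head_size != v_head_size:
        return False
    return True  # other eligibility checks (dtype, SM version, ...) omitted

# Shapes from the affected layers: 8 Q heads of 256, 1 KV head of 512.
assert not can_use_memory_efficient_attention(8, 1, 256, 512)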

Affected Models

Gemma4 (google/gemma-4-e2b-it) with KV sharing:

  • Layers 15-34 borrow K,V from source layers
  • Q projection: 1536 → 2048 (8 heads × 256)
  • K/V from source: [batch, 1, seq, 512]
  • head_size = 256, v_head_size = 512

Testing

Minimal repro (from #28357):

# Attention(Q=[1,S,2048], K=[1,S,512], V=[1,S,512], q_num_heads=8, kv_num_heads=1)
# Before fix: seq=4+ crashes with misaligned address
# After fix: all seq lengths work
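
A runnable version of that repro could look like the sketch below. It assumes an onnx/onnxruntime release that supports the opset-23 Attention op and a CUDA-enabled onnxruntime build; names and details are illustrative, not taken from the PR itself.

import numpy as np
from onnx import TensorProto, helper
import onnxruntime as ort

S = 4  # sequence lengths >= 4 hit the MEA path before the fix

node = helper.make_node("Attention", ["Q", "K", "V"], ["Y"],
                        q_num_heads=8, kv_num_heads=1)
graph = helper.make_graph(
    [node], "attention_repro",
    [
        helper.make_tensor_value_info("Q", TensorProto.FLOAT, [1, S, 2048]),
        helper.make_tensor_value_info("K", TensorProto.FLOAT, [1, S, 512]),
        helper.make_tensor_value_info("V", TensorProto.FLOAT, [1, S, 512]),
    ],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 23)])

sess = ort.InferenceSession(model.SerializeToString(),
                            providers=["CUDAExecutionProvider"])
feeds = {
    "Q": np.random.randn(1, S, 2048).astype(np.float32),
    "K": np.random.randn(1, S, 512).astype(np.float32),
    "V": np.random.randn(1, S, 512).astype(np.float32),
}
# Before the fix: cudaErrorMisalignedAddress on the CUDA EP at seq >= 4.
# After the fix: falls back to unfused attention and runs for all seq lengths.
print(sess.run(None, feeds)[0].shape)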

Full Gemma4 decoder (35 layers, 15 GQA + 20 standard Attention):

  • Prefill seq=32: ✅
  • Decode seq=1: ✅

Fixes #28357

Fix CUDA Attention dispatch: skip MEA when head_size != v_head_size in GQA

The Memory-Efficient Attention (MEA) path in the CUDA Attention kernel
crashes with misaligned address when q_num_heads != kv_num_heads (GQA
mode) and head_size != v_head_size. This happens because MEA's
LaunchUngroup requires equal head sizes, but the dispatch logic only
checked this constraint for the past_key case, not the general GQA case.

Add the missing check: skip MEA for GQA when head_size != v_head_size,
allowing the Unfused Attention fallback to handle it correctly.

This fixes Gemma4 models with KV sharing where Q has head_dim=256 but
shared K,V have head_dim=512. CPU EP handled this correctly; CUDA EP
crashed at seq_len >= 4 when Flash Attention was not eligible.

Fixes #28357

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchu@microsoft.com>


Development

Successfully merging this pull request may close these issues.

CUDA Attention kernel crashes with mismatched Q/K head dimensions (head_size != v_head_size)
