Fix CUDA Attention dispatch: skip MEA when head_size != v_head_size in GQA#28358
Open
justinchuby wants to merge 1 commit intomainfrom
Open
Fix CUDA Attention dispatch: skip MEA when head_size != v_head_size in GQA#28358justinchuby wants to merge 1 commit intomainfrom
justinchuby wants to merge 1 commit intomainfrom
Conversation
…n GQA The Memory-Efficient Attention (MEA) path in the CUDA Attention kernel crashes with misaligned address when q_num_heads != kv_num_heads (GQA mode) and head_size != v_head_size. This happens because MEA's LaunchUngroup requires equal head sizes, but the dispatch logic only checked this constraint for the past_key case, not the general GQA case. Add the missing check: skip MEA for GQA when head_size != v_head_size, allowing the Unfused Attention fallback to handle it correctly. This fixes Gemma4 models with KV sharing where Q has head_dim=256 but shared K,V have head_dim=512. CPU EP handled this correctly; CUDA EP crashed at seq_len >= 4 when Flash Attention was not eligible. Fixes #28357 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <justinchu@microsoft.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Problem
The Memory-Efficient Attention (MEA) path crashes with
cudaErrorMisalignedAddresswhen:q_num_heads != kv_num_heads)head_size != v_head_size(e.g., Q.head_dim=256, K.head_dim=512)seq_len >= 4(Flash Attention not eligible due to attention mask)This is because MEA's
LaunchUngrouprequires equal head sizes, but the dispatch logic only checked this constraint for the past_key case (line 1380), not the general GQA case.Fix
Skip MEA for GQA when head sizes differ. The Unfused Attention fallback handles this correctly.
Affected Models
Gemma4 (google/gemma-4-e2b-it) with KV sharing:
head_size = 256,v_head_size = 512Testing
Minimal repro (from #28357):
Full Gemma4 decoder (35 layers, 15 GQA + 20 standard Attention):
Fixes #28357