[GQA] Make present_key/present_value outputs optional and add Gemma4 support #28242
apsonawane wants to merge 19 commits into main from
Conversation
Pull request overview
Adds support for KV-shared decoder layers by allowing com.microsoft.GroupQueryAttention to optionally consume pre-computed external K/V tensors (instead of maintaining/updating its own KV cache), enabling architectures like Gemma4-style KV sharing.
Changes:
- Extended the GroupQueryAttention schema with optional inputs external_key/external_value (indices 14/15).
- Added new parameters + validation helpers to detect/configure "external KV" mode and enforce do_rotary=0.
- Updated CPU and CUDA kernels to source KV from external tensors and bypass KV-cache update / RoPE-on-KV paths.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 9 comments.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/core/graph/contrib_ops/bert_defs.cc | Adds schema inputs for external KV (but type/shape inference also needs external-KV awareness). |
| onnxruntime/contrib_ops/cpu/bert/attention_parameters.h | Adds use_external_kv and external_kv_sequence_length to GQA parameters. |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h | Adds Q-only checks and external-KV shape validation/configuration helpers; updates CheckInputs to distinguish packed-QKV vs Q-only. |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | Plumbs external KV inputs into CPU kernel and skips K/V transpose + rotary when external KV is used. |
| onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h | Updates CPU attention core to skip KV concatenation and copy external KV to present outputs once per KV head. |
| onnxruntime/contrib_ops/cuda/bert/attention_data.h | Adds external KV pointers to CUDA attention data struct. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | Plumbs external KV inputs into CUDA kernel and enforces do_rotary=0 for external KV mode. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu | Adds external-KV path in PrepareQKV (copy external KV to present and skip append/RoPE). |
tianleiwu
left a comment
There is no need to add extra inputs; you can use key/value for that, and make past_key/past_value/present_key/present_value optional.
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
tianleiwu
left a comment
I found one additional correctness issue on the current head. There are also already-open current-head threads covering the optional-present tests and the CUDA/documentation mismatch, so I am not duplicating those here.
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 7 comments.
tianleiwu
left a comment
Most of my concerns on this head are already covered by existing review threads (CPU GEMM/concat invariant, CUDA hard-error vs. optional schema, mixed-output configurations, scratch buffer sizing on CUDA, seqlen_present_kv_cache initialization, test tolerance). Two items I did not see covered:
1. PR description does not describe this PR. The description discusses adding external_key/external_value inputs (inputs 14/15), Check_Q_Only, CheckExternalKV, use_external_kv, CUDA attention_data.h fields, etc. None of that appears in this diff — the actual change makes existing present_key/present_value outputs optional. Please rewrite the description to match the implementation; downstream tooling and release notes rely on it.
2. CUDA test coverage is missing. All new tests use DefaultCpuExecutionProvider() only. The CUDA path has its own contracts added in this PR — the early guard in group_query_attention.cc (claims first-prompt is supported) and the unconditional rejection in PrepareQKV (group_query_attention_impl.cu). These two messages contradict each other, but no CUDA test exercises omit_present=true, so the contradiction is invisible to CI. Please add at least one CUDA-gated negative test that asserts the kernel rejects omitted present outputs with a stable error message (or, if the intent is to support it on CUDA, a positive equivalence test). Otherwise a future refactor of PrepareQKV can silently change the user-facing error path.
No new inline comments — the existing threads already pinpoint the code locations.
tianleiwu
left a comment
Review of head 6cbe62c
The two concerns I raised on 9a14803a have been addressed:
- PR description body now accurately describes the optional-present-output change.
- CUDA test (OptionalPresent_CudaOmitMatchesConnected) was added.
The is_first_prompt validation in both CPU and CUDA kernels properly constrains the omitted-present case, and the scratch-buffer approach on CUDA correctly keeps data.present_key/data.present_value non-null for downstream kernels. I resolved my earlier thread on gqa_attention_base.h line 177.
One remaining concern:
WebGPU EP still has no guard for omitted present outputs. The schema relaxation (OpSchema::Optional on outputs 1/2) is global across all EPs. The WebGPU kernel (onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc ~line 240) calls context.Output(1, present_kv_shape) and then passes the result to ApplyFlashAttention/ApplyAttention without a nullptr check. A model that omits these outputs will crash at runtime on WebGPU. Please add a validation that rejects present_key == nullptr || present_value == nullptr with a clear message (e.g., "WebGPU GroupQueryAttention requires present_key and present_value outputs").
Nitpick: The PR title still says "Add external_key/external_value inputs" — please update to match the description body.
tianleiwu
left a comment
Review Summary
The approach of making present_key/present_value optional in GQA is the right choice for KV-shared layers (MHA is not viable since it lacks kv_num_heads, seqlens_k, sliding window, rotary, and KV quantization). The CPU null-guards and WebGPU rejection are well done. However, two design issues need addressing:
1. CUDA scratch allocation negates memory savings — Allocating full-size scratch buffers (present_shape.Size() * sizeof(U) × 2) when present outputs are omitted defeats the purpose. MHA's CUDA PrepareQkv_MHA_NoPast path handles the equivalent scenario by using K/V input buffers directly for flash/MEA kernels — zero scratch allocation. GQA should follow the same pattern: skip scratch allocation and modify PrepareQKV in group_query_attention_impl.cu to use K/V input data directly (or the rotary output buffer if rotary is applied).
2. is_first_prompt constraint too restrictive for Gemma 4 — The guard rejects omitting present when sequence_length != total_sequence_length. But in KV-shared layers:
- During decode: Q has length 1, but borrowed KV has length = total_seq_len → is_first_prompt = false → rejection fires
- During prefill with different Q/KV lengths: same issue

This means the feature only works for prefill when Q and KV happen to have equal length — blocking the decode use case entirely. The real safety condition is past_key != nullptr (past KV exists and must be concatenated into present), not is_first_prompt. MHA supports query_length != key_length without KV cache for the same pattern.
Other notes:
- past_present_share_buffer evaluates to true when both past and present are nullptr (no bug, but semantically misleading — consider adding a present_key_data != nullptr && prefix)
- Please confirm deleted external KV tests (ExternalKV_*) are no longer needed (feature fully removed)
```cpp
IAllocatorUniquePtr<void> present_key_scratch;
IAllocatorUniquePtr<void> present_value_scratch;
if (present_key_output == nullptr || present_value_output == nullptr) {
  size_t present_kv_bytes = present_shape.Size() * sizeof(U);
```
This scratch allocation (present_shape.Size() * sizeof(U) × 2) negates the memory savings of omitting present outputs.
MHA's CUDA PrepareQkv_MHA_NoPast path uses K/V input buffers directly (data.k = const_cast<T*>(data.key)) for flash/MEA kernels when there's no past/present — zero scratch allocation.
Since omitting present is only allowed with no past, GQA should follow the same pattern: skip scratch allocation and modify PrepareQKV in group_query_attention_impl.cu to use the K/V input data directly (or the rotary output buffer if rotary is applied) rather than allocating a separate buffer that duplicates the input.
GQA's PrepareQKV always runs LaunchUnpackRopeAppend which transposes BSNH->BNSH into the destination buffer. Flash attention for GQA reads from BNSH buffer. So, we need a buffer of size BNSH.
To eliminate this, we need to make GQA's flash attention accept BSNH input directly (which is not relevant to this PR).
I can implement it in another PR
```cpp
// When past exists, the attention GEMMs use total_seqlen which requires a
// concatenated past+current KV buffer built by ConcatStateChunkGQA into present.
if ((present_k == nullptr || present_v == nullptr) && !parameters.is_first_prompt) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
```
The is_first_prompt guard (sequence_length == total_sequence_length) is too restrictive for Gemma 4's KV-shared layers.
During decode steps: the KV-shared layer receives the source layer's full KV (length = total_seq_len) while Q has length 1. With this guard, is_first_prompt = false → rejection fires → the feature only works for prefill when Q and KV happen to have equal length.
The real safety condition is whether past KV exists and needs concatenation:
```cpp
if ((present_k == nullptr || present_v == nullptr) && past_key != nullptr) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "present_key and present_value outputs are required when past_key is provided.");
}
```

Note: Using MHA instead is not viable (no kv_num_heads, no seqlens_k, no sliding window/rotary/KV quantization). Modifying GQA is correct, but the constraint needs to be relaxed to support decode.
The guard is relaxed now, but the intended decode shape still does not get through input validation. Check_Q_K_V still requires query_dims[1] == key_dims[1] and query_dims[1] == value_dims[1], so a KV-shared decode call with Q length 1 and borrowed K/V length total_sequence_length is rejected before it reaches this code. Please either relax the shared validation and thread a separate KV sequence length through transpose/rotary/attention sizing, or explicitly scope this PR to prefill-only behavior.
```cpp
  }
  EXPECT_FALSE(all_zero) << "CUDA output should not be all zeros";
}
```
Add a test case with query_sequence_length != kv_sequence_length (i.e. sequence_length != total_sequence_length). That is the decoding case.
This decode-coverage test is still needed on the latest head. The current tests all keep total_seq_len == sequence_length; they do not exercise the KV-shared decode shape where Q has length 1 and borrowed K/V has the full cache length. That case should currently fail in Check_Q_K_V before compute, so adding this test would catch the remaining gap.
```diff
               "kv_sequence_length.",
-              "T_CACHE")
+              "T_CACHE",
+              OpSchema::Optional)
```
What is the desired op schema for the GQA ops in the shared KV cache layers after marking the present KV cache outputs as optional?
- Option 1: The existing KV caches are passed in directly as the past KV cache with empty KV inputs.
- Option 2: The existing KV caches are passed in directly as the KV inputs with empty past/present KV caches. This would be similar to how cross-attention is handled in MultiHeadAttention for Whisper, where the past KV caches are passed in directly.
It shall be option 2. If the KV cache needs concatenation, it shall be passed as past_*. Since there is no concatenation in this case, it can be passed directly as K/V.
tianleiwu
left a comment
Thanks for the follow-up changes. I found two remaining current-head correctness issues in the KV-shared decode path. Both come from treating borrowed full-context K/V as though they were newly appended after a local past cache, which can lead to invalid memory access or incomplete KV contents. I am also not duplicating the existing decode-coverage thread, but CUDA still needs a q_seq_len != kv_seq_len regression because the current CUDA test only covers equal Q/K/V lengths.
```diff
 const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size;  // T x H

-if (!past_present_share_buffer) {
+if (present_key && !past_present_share_buffer) {
```
This still treats KV-shared decode with no local past as if there were a past prefix whenever present_key is connected. For the new decode shape (sequence_length == 1, kv_sequence_length == total_sequence_length, past_key == nullptr), is_first_prompt is false, so past_seqlen becomes total_seqlen - sequence_length. ConcatStateChunkGQA then copies past_chunk_length from the null past_key pointer and appends the full K input after that synthetic offset, which can read null/invalid memory and overflow the present buffer. Please distinguish actual-past decode from borrowed-KV/no-past decode here; the no-past case should copy K/V from offset 0 (or reject connected present outputs for that mode if omitted-only is the intended contract).
```cpp
// uses a single sequence_length for its thread grid.
if (kv_sequence_length != sequence_length) {
  // Process Q only (sequence_length tokens, no K/V)
  ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend<T, U>(
```
The split launch passes null K/V pointers for the Q-only pass, but LaunchUnpackRoPEAppend still builds a grid over num_heads + 2 * kv_num_heads and its kernel unconditionally loads query, key, or value based on head_idx. That means this launch can dereference null key/value, and the second launch can dereference null query. Even with guarded loads, the K/V pass would still append at past_seq_lens[b] == total_len - query_seq_len for borrowed full-context K/V, so it would only populate the tail of the cache. Please add a launcher/kernel mode that dispatches only the requested head types and uses zero append offset for no-past borrowed K/V, or route this case through a path that consumes K/V directly instead of treating them as newly appended cache entries.
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.
```cpp
// When kv_sequence_length differs from sequence_length (KV-shared decode),
// we must call LaunchUnpackRoPEAppend separately for Q and K/V since the kernel
// uses a single sequence_length for its thread grid.
if (kv_sequence_length != sequence_length) {
  // Process Q only (sequence_length tokens, no K/V)
  ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend<T, U>(
      nullptr,                                 // no packed_qkv
      reinterpret_cast<const T*>(data.query),  // Q input
      nullptr,                                 // no K
      nullptr,                                 // no V
      q_out, k, v, data.k_scale, data.v_scale,
      num_heads, kv_num_heads, head_size, sequence_length, batch_size,
      max_cache_length, data.past_seq_lens,
      reinterpret_cast<const T*>(data.cos_cache), reinterpret_cast<const T*>(data.sin_cache),
      parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved,
      is_cache_bnsh, parameters.k_quant_type,
      stream, max_threads_per_block)));

  // Process K/V only (kv_sequence_length tokens, no Q)
  ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend<T, U>(
      nullptr,                                 // no packed_qkv
      nullptr,                                 // no Q
      reinterpret_cast<const T*>(data.key),    // K input
      reinterpret_cast<const T*>(data.value),  // V input
      nullptr, k, v, data.k_scale, data.v_scale,  // no Q output
      num_heads, kv_num_heads, head_size, kv_sequence_length, batch_size,
      max_cache_length, data.past_seq_lens,
      reinterpret_cast<const T*>(data.cos_cache), reinterpret_cast<const T*>(data.sin_cache),
      parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved,
      is_cache_bnsh, parameters.k_quant_type,
      stream, max_threads_per_block)));
} else {
```
```cpp
GroupQueryAttentionParameters* output_parameters = reinterpret_cast<GroupQueryAttentionParameters*>(parameters);
output_parameters->batch_size = batch_size;
output_parameters->sequence_length = sequence_length;        // sequence length of Q
output_parameters->kv_sequence_length = kv_sequence_length;  // sequence length of K/V inputs
```
```cpp
// Regression (CUDA): omitting present outputs on CUDA EP must produce the same
// attention output as when present outputs are connected. The CUDA path allocates
// internal scratch buffers to serve as KV workspace for flash/MEA/unfused kernels.
TEST(GroupQueryAttentionTest, OptionalPresent_CudaOmitMatchesConnected) {
```
```cpp
  }
}

// KV-shared first-prompt: longer sequence with no past, present omitted.
```

```cpp
// KV-shared decode: Q has length 1, K/V have full context length (kv_seq_len=8),
// no past, present omitted. Verifies that the output matches when present outputs
// are connected vs omitted for the decode shape.
TEST(GroupQueryAttentionTest, OptionalPresent_KVSharedDecode) {
```
[GQA] Make present_key/present_value outputs optional and support KV-shared decode
Summary

Marks present_key and present_value outputs of GroupQueryAttention as OpSchema::Optional and threads a separate kv_sequence_length through validation and compute, enabling KV-shared layers (e.g., Gemma 4) to omit present outputs during both prefill and decode.

Motivation

Gemma 4 has 20 KV-shared decoder layers that borrow K,V from a source layer's present outputs. These layers don't maintain their own KV cache — they only compute attention using borrowed KV. Making present outputs optional avoids allocating redundant KV tensors for these layers.

During decode, Q has length 1 while borrowed K/V has the full context length. This required relaxing the existing query.dim[1] == key.dim[1] constraint and threading kv_sequence_length through the entire compute pipeline.

Changes

- Schema (bert_defs.cc): Output(1, "present_key") and Output(2, "present_value") marked OpSchema::Optional
- Shared validation (group_query_attention_helper.h):
  - Check_Q_K_V relaxed: key.dim[1] no longer required to equal query.dim[1]; value validated against key (key.dim[1] == value.dim[1])
  - New kv_sequence_length output parameter, set in CheckInputs as parameters.kv_sequence_length
  - present_key/present_value must be both provided or both omitted
  - Present outputs required when past_key is provided (KV concatenation required)
- CPU kernel (group_query_attention.cc, gqa_attention_base.h):
  - MaybeTransposeToBNSH uses kv_sequence_length
  - kv_input_chunk_length in ComputeAttentionProbs and ComputeVxAttentionScore uses kv_sequence_length (not Q's sequence_length)
  - seqlen_present_kv_cache falls back to parameters.total_sequence_length when present_key is nullptr
  - memset calls guarded with nullptr checks
- CUDA kernel (group_query_attention.cc, group_query_attention_impl.cu):
  - Scratch buffers allocated via GetScratchBuffer<void>() when present outputs are omitted
  - PrepareQKV: when kv_sequence_length != sequence_length, calls LaunchUnpackRoPEAppend twice — once for Q only (sequence_length) and once for K/V only (kv_sequence_length). Standard path unchanged.
- WebGPU kernel (group_query_attention.cc): static_cast<int> → onnxruntime::narrow<int>() for checked narrowing of device limits
- Tests (group_query_attention_op_test.cc):
  - OptionalPresent_OmittingDoesNotChangeOutput (EXPECT_NEAR(1e-5))
  - OptionalPresent_BatchedOmitMatchesConnected (batch_size > 1)
  - OptionalPresent_KVSharedFirstPrompt (seq_len=8, no past, present omitted)
  - OptionalPresent_KVSharedDecode
  - OptionalPresent_RejectWithPastKey (past_key provided with omitted present)
  - OptionalPresent_CudaOmitMatchesConnected

Constraints

- Omitting present outputs requires no past_key/past_value (no KV concatenation needed)
- present_key and present_value must be both provided or both omitted
- key.dim[1] must equal value.dim[1] (K/V sequence lengths must match)