[GQA] Make present_key/present_value outputs optional and add Gemma4 support #28242

Open
apsonawane wants to merge 19 commits into main from asonawane/gemma4

Conversation

@apsonawane
Contributor

@apsonawane apsonawane commented Apr 27, 2026

[GQA] Make present_key/present_value outputs optional and support KV-shared decode

Summary

Marks present_key and present_value outputs of GroupQueryAttention as OpSchema::Optional and threads a separate kv_sequence_length through validation and compute, enabling KV-shared layers (e.g., Gemma 4) to omit present outputs during both prefill and decode.

Motivation

Gemma 4 has 20 KV-shared decoder layers that borrow K,V from a source layer's present outputs. These layers don't maintain their own KV cache — they only compute attention using borrowed KV. Making present outputs optional avoids allocating redundant KV tensors for these layers.

During decode, Q has length 1 while borrowed K/V has the full context length. This required relaxing the existing query.dim[1] == key.dim[1] constraint and threading kv_sequence_length through the entire compute pipeline.
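
For concreteness, the decode-time shapes at a KV-shared layer look roughly like this (an illustrative sketch; the dimension names follow the GQA schema, and the values are not taken from the diff):

// query : (batch_size, 1,                     num_heads    * head_size)   // one new token
// key   : (batch_size, total_sequence_length, kv_num_heads * head_size)   // borrowed from the source layer
// value : (batch_size, total_sequence_length, kv_num_heads * head_size)   // borrowed from the source layer
//
// The old check query.dim[1] == key.dim[1] rejects this combination, hence the
// separate kv_sequence_length threaded through validation and compute.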

Changes

Schema (bert_defs.cc):

  • Output(1, "present_key") and Output(2, "present_value") marked OpSchema::Optional
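
As a sketch, the registration in bert_defs.cc changes roughly as follows (descriptions elided with "..."; the substantive change is the trailing OpSchema::Optional argument):

.Output(1, "present_key", "...", "T_CACHE", OpSchema::Optional)
.Output(2, "present_value", "...", "T_CACHE", OpSchema::Optional)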

Shared validation (group_query_attention_helper.h):

  • Check_Q_K_V relaxed: key.dim[1] no longer required to equal query.dim[1]; value validated against key (key.dim[1] == value.dim[1])
  • Added kv_sequence_length output parameter, set in CheckInputs as parameters.kv_sequence_length
  • present_key/present_value must be both provided or both omitted
  • Omitting present rejected when past_key is provided (KV concatenation required)
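
A rough sketch of the added checks, in the helper's existing ORT_MAKE_STATUS style (how the present output pointers reach the helper and the exact variable names are assumptions, not the PR's literal code):

// present_key and present_value must be both provided or both omitted.
if ((present_key == nullptr) != (present_value == nullptr)) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "present_key and present_value must both be provided or both be omitted.");
}

// Omitting present is only valid when there is no past KV to concatenate.
if (present_key == nullptr && past_key != nullptr) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "present_key and present_value outputs are required when past_key is provided.");
}

// K and V must agree on sequence length; Q may differ (KV-shared decode).
if (key_dims[1] != value_dims[1]) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "key and value must have the same sequence length (dim 1).");
}
output_parameters->kv_sequence_length = static_cast<int>(key_dims[1]);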

CPU kernel (group_query_attention.cc, gqa_attention_base.h):

  • K/V transpose via MaybeTransposeToBNSH uses kv_sequence_length
  • kv_input_chunk_length in ComputeAttentionProbs and ComputeVxAttentionScore uses kv_sequence_length (not Q's sequence_length)
  • seqlen_present_kv_cache falls back to parameters.total_sequence_length when present_key is nullptr
  • Guarded memset calls with nullptr checks
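
A minimal sketch of the fallback and the guarded initialization (parameters.seqlen_present_kv_cache and present_kv_bytes are assumed names standing in for however the existing CPU code computes those values):

// With no present cache buffer, size the attention GEMMs from the parameters instead.
const int seqlen_present_kv_cache = (present_key != nullptr)
                                        ? parameters.seqlen_present_kv_cache
                                        : parameters.total_sequence_length;

// Only clear present buffers that actually exist.
if (present_key != nullptr && present_value != nullptr) {
  memset(present_key, 0, present_kv_bytes);
  memset(present_value, 0, present_kv_bytes);
}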

CUDA kernel (group_query_attention.cc, group_query_attention_impl.cu):

  • Allocates internal scratch buffers via GetScratchBuffer<void>() when present outputs are omitted
  • PrepareQKV: when kv_sequence_length != sequence_length, calls LaunchUnpackRoPEAppend twice — once for Q only (sequence_length) and once for K/V only (kv_sequence_length). Standard path unchanged.
  • Scratch freed after kernel execution (no persistent allocation)
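
A sketch of the scratch fallback (GetScratchBuffer is the standard CudaKernel helper; the variable names follow the snippet quoted later in the review and are otherwise assumptions):

// When present outputs are omitted, the flash/MEA/unfused kernels still need a
// BNSH-layout KV workspace, so allocate temporary scratch and keep
// data.present_key / data.present_value non-null for downstream kernels.
IAllocatorUniquePtr<void> present_key_scratch;
IAllocatorUniquePtr<void> present_value_scratch;
if (present_key_output == nullptr || present_value_output == nullptr) {
  const size_t present_kv_bytes = static_cast<size_t>(present_shape.Size()) * sizeof(U);
  present_key_scratch = GetScratchBuffer<void>(present_kv_bytes, context->GetComputeStream());
  present_value_scratch = GetScratchBuffer<void>(present_kv_bytes, context->GetComputeStream());
  data.present_key = reinterpret_cast<U*>(present_key_scratch.get());
  data.present_value = reinterpret_cast<U*>(present_value_scratch.get());
}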

WebGPU kernel (group_query_attention.cc):

  • Rejects omitted present outputs with actionable error message (WebGPU flash attention requires present buffers)
  • Replaced static_cast<int> with onnxruntime::narrow<int>() for checked narrowing of device limits
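
A sketch of the guard (exact placement inside the WebGPU ComputeInternal is assumed; the message follows the wording requested in review):

Tensor* present_key = context.Output(1, present_kv_shape);
Tensor* present_value = context.Output(2, present_kv_shape);
if (present_key == nullptr || present_value == nullptr) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "WebGPU GroupQueryAttention requires present_key and present_value outputs.");
}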

Tests (group_query_attention_op_test.cc):

  • OptionalPresent_OmittingDoesNotChangeOutput: CPU output identical with/without present (deterministic inputs, EXPECT_NEAR(1e-5))
  • OptionalPresent_BatchedOmitMatchesConnected: same check with batch_size > 1
  • OptionalPresent_KVSharedFirstPrompt: prefill with seq_len=8, no past, present omitted
  • OptionalPresent_KVSharedDecode: decode with Q_seq=1, KV_seq=8, no past, present omitted (equivalence test)
  • OptionalPresent_RejectWithPastKey: rejection when past_key is provided with omitted present
  • OptionalPresent_CudaOmitMatchesConnected: CUDA-gated equivalence test (skips when CUDA is unavailable)

Constraints

  • Omitting present outputs requires no past_key/past_value (no KV concatenation needed)
  • present_key and present_value must be both provided or both omitted
  • key.dim[1] must equal value.dim[1] (K/V sequence lengths must match)
  • WebGPU EP requires present outputs (flash attention dependency)
  • CPU and CUDA EPs support both prefill and decode with omitted present
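
For illustration, a KV-shared layer that satisfies these constraints would be wired roughly as below: the source layer's present K/V feed the key/value inputs, past_key/past_value are left empty, and only output 0 is connected. This is a hypothetical sketch; the input order is assumed to follow the documented GQA schema (query, key, value, past_key, past_value, seqlens_k, total_sequence_length, ...), and all tensor names are illustrative.

onnx::NodeProto node;
node.set_op_type("GroupQueryAttention");
node.set_domain("com.microsoft");
node.add_input("query");
node.add_input("source_layer_present_key");    // borrowed K (no concatenation needed)
node.add_input("source_layer_present_value");  // borrowed V
node.add_input("");                            // past_key omitted
node.add_input("");                            // past_value omitted
node.add_input("seqlens_k");
node.add_input("total_sequence_length");
node.add_output("attention_output");           // present_key / present_value outputs omitted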

Contributor

@github-actions github-actions Bot left a comment


You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/contrib_ops/cpu/bert/attention_parameters.h Outdated
@apsonawane apsonawane requested a review from Copilot April 27, 2026 21:05
Contributor

Copilot AI left a comment


Pull request overview

Adds support for KV-shared decoder layers by allowing com.microsoft.GroupQueryAttention to optionally consume pre-computed external K/V tensors (instead of maintaining/updating its own KV cache), enabling architectures like Gemma4-style KV sharing.

Changes:

  • Extended the GroupQueryAttention schema with optional inputs external_key / external_value (indices 14/15).
  • Added new parameters + validation helpers to detect/configure “external KV” mode and enforce do_rotary=0.
  • Updated CPU and CUDA kernels to source KV from external tensors and bypass KV-cache update / RoPE-on-KV paths.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 9 comments.

Summary per file:

  • onnxruntime/core/graph/contrib_ops/bert_defs.cc: Adds schema inputs for external KV (but type/shape inference also needs external-KV awareness).
  • onnxruntime/contrib_ops/cpu/bert/attention_parameters.h: Adds use_external_kv and external_kv_sequence_length to GQA parameters.
  • onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h: Adds Q-only checks and external-KV shape validation/configuration helpers; updates CheckInputs to distinguish packed-QKV vs Q-only.
  • onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc: Plumbs external KV inputs into the CPU kernel and skips K/V transpose + rotary when external KV is used.
  • onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: Updates the CPU attention core to skip KV concatenation and copy external KV to present outputs once per KV head.
  • onnxruntime/contrib_ops/cuda/bert/attention_data.h: Adds external KV pointers to the CUDA attention data struct.
  • onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc: Plumbs external KV inputs into the CUDA kernel and enforces do_rotary=0 for external KV mode.
  • onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu: Adds an external-KV path in PrepareQKV (copy external KV to present and skip append/RoPE).


Comment thread onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h Outdated
Comment thread onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
Comment thread onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc Outdated
Comment thread onnxruntime/core/graph/contrib_ops/bert_defs.cc Outdated
Comment thread onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu Outdated
Comment thread onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc Outdated
Contributor

@tianleiwu tianleiwu left a comment


There is no need to add extra inputs; you can use key/value for that and make past_key/past_value/present_key/present_value optional.

@apsonawane apsonawane requested review from Copilot and tianleiwu April 29, 2026 17:09
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.



Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
Comment thread onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu Outdated
Comment thread docs/ContribOperators.md
Contributor

@tianleiwu tianleiwu left a comment


I found one additional correctness issue on the current head. There are also already-open current-head threads covering the optional-present tests and the CUDA/documentation mismatch, so I am not duplicating those here.

Comment thread onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 7 comments.



Comment thread onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc Outdated
Comment thread onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
Comment thread onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h Outdated
Comment thread onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc Outdated
Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc Outdated
Contributor

@tianleiwu tianleiwu left a comment


Most of my concerns on this head are already covered by existing review threads (CPU GEMM/concat invariant, CUDA hard-error vs. optional schema, mixed-output configurations, scratch buffer sizing on CUDA, seqlen_present_kv_cache initialization, test tolerance). Two items I did not see covered:

1. PR description does not describe this PR. The description discusses adding external_key/external_value inputs (inputs 14/15), Check_Q_Only, CheckExternalKV, use_external_kv, CUDA attention_data.h fields, etc. None of that appears in this diff — the actual change makes existing present_key/present_value outputs optional. Please rewrite the description to match the implementation; downstream tooling and release notes rely on it.

2. CUDA test coverage is missing. All new tests use DefaultCpuExecutionProvider() only. The CUDA path has its own contracts added in this PR — the early guard in group_query_attention.cc (claims first-prompt is supported) and the unconditional rejection in PrepareQKV (group_query_attention_impl.cu). These two messages contradict each other, but no CUDA test exercises omit_present=true, so the contradiction is invisible to CI. Please add at least one CUDA-gated negative test that asserts the kernel rejects omitted present outputs with a stable error message (or, if the intent is to support it on CUDA, a positive equivalence test). Otherwise a future refactor of PrepareQKV can silently change the user-facing error path.

No new inline comments — the existing threads already pinpoint the code locations.

Contributor

@tianleiwu tianleiwu left a comment


Review of head 6cbe62c

The two concerns I raised on 9a14803a have been addressed:

  • PR description body now accurately describes the optional-present-output change.
  • CUDA test (OptionalPresent_CudaOmitMatchesConnected) was added.

The is_first_prompt validation in both CPU and CUDA kernels properly constrains the omitted-present case, and the scratch-buffer approach on CUDA correctly keeps data.present_key/data.present_value non-null for downstream kernels. I resolved my earlier thread on gqa_attention_base.h line 177.

One remaining concern:

WebGPU EP still has no guard for omitted present outputs. The schema relaxation (OpSchema::Optional on outputs 1/2) is global across all EPs. The WebGPU kernel (onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc ~line 240) calls context.Output(1, present_kv_shape) and then passes the result to ApplyFlashAttention/ApplyAttention without a nullptr check. A model that omits these outputs will crash at runtime on WebGPU. Please add a validation that rejects present_key == nullptr || present_value == nullptr with a clear message (e.g., "WebGPU GroupQueryAttention requires present_key and present_value outputs").

Nitpick: The PR title still says "Add external_key/external_value inputs" — please update to match the description body.

@apsonawane apsonawane changed the title Add external_key/external_value inputs to GroupQueryAttention for KV-shared layers [GQA] Make present_key/present_value outputs optional and add Gemma4 support May 1, 2026
Contributor

@tianleiwu tianleiwu left a comment


Review Summary

The approach of making present_key/present_value optional in GQA is the right choice for KV-shared layers (MHA is not viable since it lacks kv_num_heads, seqlens_k, sliding window, rotary, and KV quantization). The CPU null-guards and WebGPU rejection are well done. However, two design issues need addressing:

1. CUDA scratch allocation negates memory savings — Allocating full-size scratch buffers (present_shape.Size() * sizeof(U) × 2) when present outputs are omitted defeats the purpose. MHA's CUDA PrepareQkv_MHA_NoPast path handles the equivalent scenario by using K/V input buffers directly for flash/MEA kernels — zero scratch allocation. GQA should follow the same pattern: skip scratch allocation and modify PrepareQKV in group_query_attention_impl.cu to use K/V input data directly (or the rotary output buffer if rotary is applied).

2. is_first_prompt constraint too restrictive for Gemma 4 — The guard rejects omitting present when sequence_length != total_sequence_length. But in KV-shared layers:

  • During decode: Q has length 1, but borrowed KV has length = total_seq_len → is_first_prompt = false → rejection fires
  • During prefill with different Q/KV lengths: same issue

This means the feature only works for prefill when Q and KV happen to have equal length — blocking the decode use case entirely. The real safety condition is past_key != nullptr (past KV exists and must be concatenated into present), not is_first_prompt. MHA supports query_length != key_length without KV cache for the same pattern.

Other notes:

  • past_present_share_buffer evaluates to true when both past and present are nullptr (no bug, but semantically misleading; consider prefixing the expression with present_key_data != nullptr &&)
  • Please confirm deleted external KV tests (ExternalKV_*) are no longer needed (feature fully removed)

IAllocatorUniquePtr<void> present_key_scratch;
IAllocatorUniquePtr<void> present_value_scratch;
if (present_key_output == nullptr || present_value_output == nullptr) {
size_t present_kv_bytes = present_shape.Size() * sizeof(U);
Contributor


This scratch allocation (present_shape.Size() * sizeof(U) × 2) negates the memory savings of omitting present outputs.

MHA's CUDA PrepareQkv_MHA_NoPast path uses K/V input buffers directly (data.k = const_cast<T*>(data.key)) for flash/MEA kernels when there's no past/present — zero scratch allocation.

Since omitting present is only allowed with no past, GQA should follow the same pattern: skip scratch allocation and modify PrepareQKV in group_query_attention_impl.cu to use the K/V input data directly (or the rotary output buffer if rotary is applied) rather than allocating a separate buffer that duplicates the input.

Contributor Author


GQA's PrepareQKV always runs LaunchUnpackRoPEAppend, which transposes BSNH->BNSH into the destination buffer. Flash attention for GQA reads from the BNSH buffer, so we need a buffer of BNSH size.
To eliminate this, we would need to make GQA's flash attention accept BSNH input directly (which is not relevant to this PR).
I can implement that in another PR.

// When past exists, the attention GEMMs use total_seqlen which requires a
// concatenated past+current KV buffer built by ConcatStateChunkGQA into present.
if ((present_k == nullptr || present_v == nullptr) && !parameters.is_first_prompt) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
Contributor


The is_first_prompt guard (sequence_length == total_sequence_length) is too restrictive for Gemma 4's KV-shared layers.

During decode steps: the KV-shared layer receives the source layer's full KV (length = total_seq_len) while Q has length 1. With this guard, is_first_prompt = false → rejection fires → the feature only works for prefill when Q and KV happen to have equal length.

The real safety condition is whether past KV exists and needs concatenation:

if ((present_k == nullptr || present_v == nullptr) && past_key != nullptr) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "present_key and present_value outputs are required when past_key is provided.");
}

Note: Using MHA instead is not viable (no kv_num_heads, no seqlens_k, no sliding window/rotary/KV quantization). Modifying GQA is correct, but the constraint needs to be relaxed to support decode.

Contributor


The guard is relaxed now, but the intended decode shape still does not get through input validation. Check_Q_K_V still requires query_dims[1] == key_dims[1] and query_dims[1] == value_dims[1], so a KV-shared decode call with Q length 1 and borrowed K/V length total_sequence_length is rejected before it reaches this code. Please either relax the shared validation and thread a separate KV sequence length through transpose/rotary/attention sizing, or explicitly scope this PR to prefill-only behavior.

}
EXPECT_FALSE(all_zero) << "CUDA output should not be all zeros";
}

Contributor

@tianleiwu tianleiwu May 2, 2026


Add a test case with query_sequence_length != kv_sequence_length (i.e., sequence_length != total_sequence_length). That is the decoding case.

Contributor


This decode-coverage test is still needed on the latest head. The current tests all keep total_seq_len == sequence_length; they do not exercise the KV-shared decode shape where Q has length 1 and borrowed K/V has the full cache length. That case should currently fail in Check_Q_K_V before compute, so adding this test would catch the remaining gap.

"kv_sequence_length.",
"T_CACHE")
"T_CACHE",
OpSchema::Optional)
Contributor


What is the desired op schema for the GQA ops in the shared KV cache layers after marking the present KV cache outputs as optional?

  • Option 1: The existing KV caches are passed in directly as the past KV cache with empty KV inputs.
  • Option 2: The existing KV caches are passed in directly as the KV inputs with empty past/present KV caches. This would be similar to how cross-attention is handled in MultiHeadAttention for Whisper where the past KV caches are passed in directly.

Contributor

@tianleiwu tianleiwu May 2, 2026


It shall be option 2. If the KV cache needs concatenation, it shall be passed as past_*. Since there is no concatenation in this case, it can be passed directly as K/V.

Contributor

@tianleiwu tianleiwu left a comment


Thanks for the follow-up changes. I found two remaining current-head correctness issues in the KV-shared decode path. Both come from treating borrowed full-context K/V as though they were newly appended after a local past cache, which can lead to invalid memory access or incomplete KV contents. I am not duplicating the existing decode-coverage thread, but CUDA still needs a q_seq_len != kv_seq_len regression test because the current CUDA test only covers equal Q/K/V lengths.

const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H

if (!past_present_share_buffer) {
if (present_key && !past_present_share_buffer) {
Contributor


This still treats KV-shared decode with no local past as if there were a past prefix whenever present_key is connected. For the new decode shape (sequence_length == 1, kv_sequence_length == total_sequence_length, past_key == nullptr), is_first_prompt is false, so past_seqlen becomes total_seqlen - sequence_length. ConcatStateChunkGQA then copies past_chunk_length from the null past_key pointer and appends the full K input after that synthetic offset, which can read null/invalid memory and overflow the present buffer. Please distinguish actual-past decode from borrowed-KV/no-past decode here; the no-past case should copy K/V from offset 0 (or reject connected present outputs for that mode if omitted-only is the intended contract).

// uses a single sequence_length for its thread grid.
if (kv_sequence_length != sequence_length) {
// Process Q only (sequence_length tokens, no K/V)
ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend<T, U>(
Contributor


The split launch passes null K/V pointers for the Q-only pass, but LaunchUnpackRoPEAppend still builds a grid over num_heads + 2 * kv_num_heads and its kernel unconditionally loads query, key, or value based on head_idx. That means this launch can dereference null key/value, and the second launch can dereference null query. Even with guarded loads, the K/V pass would still append at past_seq_lens[b] == total_len - query_seq_len for borrowed full-context K/V, so it would only populate the tail of the cache. Please add a launcher/kernel mode that dispatches only the requested head types and uses zero append offset for no-past borrowed K/V, or route this case through a path that consumes K/V directly instead of treating them as newly appended cache entries.

@tianleiwu tianleiwu requested a review from Copilot May 4, 2026 22:43
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.



Comment on lines +127 to +158
// When kv_sequence_length differs from sequence_length (KV-shared decode),
// we must call LaunchUnpackRoPEAppend separately for Q and K/V since the kernel
// uses a single sequence_length for its thread grid.
if (kv_sequence_length != sequence_length) {
  // Process Q only (sequence_length tokens, no K/V)
  ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend<T, U>(
      nullptr,                                 // no packed_qkv
      reinterpret_cast<const T*>(data.query),  // Q input
      nullptr,                                 // no K
      nullptr,                                 // no V
      q_out, k, v, data.k_scale, data.v_scale,
      num_heads, kv_num_heads, head_size, sequence_length, batch_size,
      max_cache_length, data.past_seq_lens,
      reinterpret_cast<const T*>(data.cos_cache), reinterpret_cast<const T*>(data.sin_cache),
      parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved,
      is_cache_bnsh, parameters.k_quant_type,
      stream, max_threads_per_block)));

  // Process K/V only (kv_sequence_length tokens, no Q)
  ORT_RETURN_IF_ERROR((LaunchUnpackRoPEAppend<T, U>(
      nullptr,                                    // no packed_qkv
      nullptr,                                    // no Q
      reinterpret_cast<const T*>(data.key),       // K input
      reinterpret_cast<const T*>(data.value),     // V input
      nullptr, k, v, data.k_scale, data.v_scale,  // no Q output
      num_heads, kv_num_heads, head_size, kv_sequence_length, batch_size,
      max_cache_length, data.past_seq_lens,
      reinterpret_cast<const T*>(data.cos_cache), reinterpret_cast<const T*>(data.sin_cache),
      parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved,
      is_cache_bnsh, parameters.k_quant_type,
      stream, max_threads_per_block)));
} else {
Comment on lines 313 to +316
GroupQueryAttentionParameters* output_parameters = reinterpret_cast<GroupQueryAttentionParameters*>(parameters);
output_parameters->batch_size = batch_size;
output_parameters->sequence_length = sequence_length; // sequence length of Q
output_parameters->kv_sequence_length = kv_sequence_length; // sequence length of K/V inputs
Comment on lines +594 to +597
// Regression (CUDA): omitting present outputs on CUDA EP must produce the same
// attention output as when present outputs are connected. The CUDA path allocates
// internal scratch buffers to serve as KV workspace for flash/MEA/unfused kernels.
TEST(GroupQueryAttentionTest, OptionalPresent_CudaOmitMatchesConnected) {
}
}

// KV-shared first-prompt: longer sequence with no past, present omitted.
// KV-shared decode: Q has length 1, K/V have full context length (kv_seq_len=8),
// no past, present omitted. Verifies that the output matches when present outputs
// are connected vs omitted for the decode shape.
TEST(GroupQueryAttentionTest, OptionalPresent_KVSharedDecode) {