## Summary
Gemma4 (and likely future architectures) uses KV-shared layers where a subset of decoder layers borrow Key/Value tensors from earlier "source" layers instead of computing their own K,V projections. For example, Gemma4 E2B has 35 layers total — 15 non-shared layers (with own K,V projections and KV cache) and 20 KV-shared layers that reuse K,V from the last matching non-shared layer.
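To make the borrowing pattern concrete, here is a minimal Python sketch of the source-layer mapping. The function name and the `is_sliding` layout are illustrative, not taken from the HF implementation:

```python
# Minimal sketch of the KV-sharing layout (illustrative names, not HF code).
# Each KV-shared layer reads the present K,V produced by the last non-shared
# layer of the same attention type (sliding vs. full).
NUM_LAYERS = 35
NUM_KV_SHARED = 20
FIRST_SHARED = NUM_LAYERS - NUM_KV_SHARED  # layer 15

def kv_source_layer(layer_idx: int, is_sliding: list[bool]) -> int:
    """Return the layer whose present_key/present_value this layer consumes."""
    if layer_idx < FIRST_SHARED:
        return layer_idx  # non-shared: uses its own K,V projections and cache
    # KV-shared: borrow from the last non-shared layer of the same type
    # (sliding-window layers borrow from the last sliding non-shared layer,
    # full-attention layers from the last full-attention non-shared layer).
    for src in range(FIRST_SHARED - 1, -1, -1):
        if is_sliding[src] == is_sliding[layer_idx]:
            return src
    raise ValueError("no matching non-shared source layer")
```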
Currently, `com.microsoft.GroupQueryAttention` cannot serve KV-shared layers because:
- GQA assumes its own KV cache: it takes `past_key`/`past_value` and produces `present_key`/`present_value`. KV-shared layers have no cache of their own; they reference another layer's present KV output.
- No external KV input mode: there is no way to pass pre-computed, already-cached K,V tensors to GQA without also running its internal KV cache update logic.
This forces KV-shared layers to fall back to the standard `Attention` op, which:
- Lacks the `local_window_size` attribute (needed for sliding-window KV-shared layers)
- Requires an explicit attention bias mask built with `CumSum`/`GreaterOrEqual`/`Where`, all CPU-only ops on the CUDA EP that cause `Memcpy` nodes
## Current workaround (in the mobius ONNX model builder)
We emit a mixed graph: GQA for non-shared layers, standard `Attention` for KV-shared layers:
- Layers 0-14 (non-shared): `GroupQueryAttention` (with `local_window_size` for sliding layers)
- Layers 15-34 (KV-shared): `Attention` + explicit bool mask (`create_sliding_window_mask` / `create_padding_mask`)
For full-attention KV-shared layers, we use `create_padding_mask()` with `is_causal=1`, which avoids most CPU ops. But for sliding-window KV-shared layers, the mask still requires `CumSum` on the INT64 `attention_mask`, which is CPU-only on the CUDA EP and causes `Memcpy` nodes.
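For reference, this is roughly the logic that `CumSum`-based subgraph encodes, written as numpy; the exact op sequence emitted by the model builder may differ:

```python
import numpy as np

# Illustrative numpy equivalent of the CumSum-based sliding-window mask.
def sliding_window_mask(attention_mask: np.ndarray, window: int) -> np.ndarray:
    # attention_mask: (batch, total_seq) int64, 1 = real token, 0 = padding
    positions = np.cumsum(attention_mask, axis=-1) - 1  # CumSum (CPU-only on CUDA EP)
    q_pos = positions[:, None, :, None]                 # (B, 1, S, 1)
    k_pos = positions[:, None, None, :]                 # (B, 1, 1, S)
    causal = k_pos <= q_pos                             # causal comparison
    in_window = (q_pos - k_pos) < window                # Less
    valid = attention_mask[:, None, None, :].astype(bool)  # mask out padding
    return causal & in_window & valid                   # And → (B, 1, S, S) bool
```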
## Concrete Memcpy impact
For Gemma4 E2B (`google/gemma-4-E2B-it`, 35 layers, 20 KV-shared):
| Source | CPU ops | Memcpy cause |
| --- | --- | --- |
| Sliding-window mask for 16 KV-shared sliding layers | `CumSum`, `Less`, `And` | INT64 `CumSum` is CPU-only |
| KV tensor reshape (4D→3D for `Attention` op) | Previously `Shape` ops (now fixed with `[0,0,-1]`) | Eliminated |
## Suggestions
### Option A: Add `external_key`/`external_value` inputs to GQA
Allow GQA to accept pre-computed K,V tensors (from another layer's `present_key`/`present_value`) instead of its own `past_key`/`past_value`. In this mode, GQA would:
- Skip its internal KV cache concatenation
- Apply RoPE only to Q (K already has RoPE from the source layer)
- Use `local_window_size` for sliding-window masking
- Not produce `present_key`/`present_value` outputs (or pass the external ones through)
This would let KV-shared layers use GQA with full sliding-window support and zero CPU `Memcpy`.
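A node-level sketch of what this could look like. The `external_key`/`external_value` inputs are hypothetical (appended after the nine existing GQA inputs), and tensor names and head counts are illustrative:

```python
from onnx import helper

# Hypothetical node layout for Option A. external_key/external_value do not
# exist in today's GQA schema; their names and positions are only a sketch
# of this proposal.
gqa_shared = helper.make_node(
    "GroupQueryAttention",
    inputs=[
        "layer20_query",           # Q from this layer's own projection
        "", "",                    # key/value: empty, no own K,V projections
        "", "",                    # past_key/past_value: empty, no own cache
        "seqlens_k",
        "total_sequence_length",
        "cos_cache", "sin_cache",  # RoPE applied to Q only in this mode
        "layer13_present_key",     # proposed external_key input
        "layer13_present_value",   # proposed external_value input
    ],
    outputs=["layer20_attn_out"],  # no present_key/present_value outputs
    domain="com.microsoft",
    num_heads=8,                   # illustrative head counts
    kv_num_heads=2,
    local_window_size=512,
)
```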
### Option B: Add `local_window_size` to the standard Attention op
If GQA modification is too invasive, adding a `local_window_size` attribute to the standard `Attention` op would let KV-shared sliding layers avoid the explicit `CumSum`-based mask. Combined with `is_causal=1`, this would eliminate most CPU ops.
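A sketch of the resulting node, assuming the attribute were added to the standard `Attention` schema (`local_window_size` is hypothetical; `is_causal` already exists; tensor names are illustrative):

```python
from onnx import helper

# Hypothetical Option B node: standard ONNX Attention with the proposed
# local_window_size attribute.
attn = helper.make_node(
    "Attention",
    inputs=[
        "layer20_q_rot",         # Q, RoPE pre-applied by the caller
        "layer13_present_key",   # K borrowed from the source layer
        "layer13_present_value", # V borrowed from the source layer
    ],
    outputs=["layer20_attn_out"],
    is_causal=1,             # existing attribute: implicit causal mask
    local_window_size=512,   # proposed attribute: would replace the
                             # explicit CumSum-based sliding-window mask
)
```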
### Option C: Add `do_rotary=0` mode to GQA with external KV
A simpler variant of Option A: GQA with `do_rotary=0` would skip RoPE entirely (the caller pre-applies it to Q; K already has RoPE from the source layer). The KV-shared layer would (see the sketch after this list):
- Compute Q and apply RoPE manually
- Read the source layer's `present_key`/`present_value` (already RoPE-applied, BNSH format)
- Pass Q + external K,V to GQA with `do_rotary=0` and `local_window_size`
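Sketched as ONNX nodes. `RotaryEmbedding` and `do_rotary` exist in `com.microsoft` today; the external K,V inputs to GQA are the proposed part, and their positions are assumed:

```python
from onnx import helper

# Sketch of the Option C dataflow for one KV-shared layer (layer 20,
# borrowing from layer 13). Names and head counts are illustrative.
nodes = [
    # 1. Caller applies RoPE to Q explicitly (GQA skips it via do_rotary=0).
    helper.make_node(
        "RotaryEmbedding",
        inputs=["layer20_q_proj", "position_ids", "cos_cache", "sin_cache"],
        outputs=["layer20_q_rot"],
        domain="com.microsoft",
    ),
    # 2. Rotated Q plus the source layer's already-rotated K,V go into GQA.
    #    The trailing external K,V inputs are hypothetical proposal inputs.
    helper.make_node(
        "GroupQueryAttention",
        inputs=["layer20_q_rot", "", "", "", "",
                "seqlens_k", "total_sequence_length", "", "",
                "layer13_present_key", "layer13_present_value"],
        outputs=["layer20_attn_out"],
        domain="com.microsoft",
        do_rotary=0,           # existing attribute: no RoPE inside GQA
        local_window_size=512,
        num_heads=8,
        kv_num_heads=2,
    ),
]
```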
## Architecture reference
```text
Gemma4 E2B: 35 layers
├── Layers 0-14: Non-shared (own Q,K,V projections + KV cache)
│   ├── 12 sliding-attention layers (local_window_size=512)
│   └── 3 full-attention layers
└── Layers 15-34: KV-shared (own Q projection only, borrow K,V from source)
    ├── 16 sliding-attention layers → borrow from layer 13
    └── 4 full-attention layers → borrow from layer 14
```
HuggingFace config field: `num_kv_shared_layers` controls how many trailing layers share KV.
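For illustration, the split can be derived from the config like this (the model id is the one above; exactly where the field lives on the config object is an assumption):

```python
from transformers import AutoConfig

# Illustrative: derive the shared/non-shared split from the HF config.
config = AutoConfig.from_pretrained("google/gemma-4-E2B-it")
first_shared = config.num_hidden_layers - config.num_kv_shared_layers
# first_shared == 15: layers 0-14 keep their own KV cache,
# layers 15-34 borrow K,V from their per-type source layer.
```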
## Related