
Feature request: GroupQueryAttention support for KV-shared layers (Gemma4-style) #28188

@justinchuby

Description

Summary

Gemma4 (and likely future architectures) uses KV-shared layers where a subset of decoder layers borrow Key/Value tensors from earlier "source" layers instead of computing their own K,V projections. For example, Gemma4 E2B has 35 layers total — 15 non-shared layers (with own K,V projections and KV cache) and 20 KV-shared layers that reuse K,V from the last matching non-shared layer.

Currently, com.microsoft.GroupQueryAttention cannot serve KV-shared layers because:

  1. GQA assumes own KV cache: It takes past_key/past_value and produces present_key/present_value. KV-shared layers have no cache of their own — they reference another layer's present KV output.
  2. No external KV input mode: There is no way to pass pre-computed, already-cached K,V tensors to GQA without also running its internal KV cache update logic.

This forces KV-shared layers to fall back to the standard Attention op, which:

  • Lacks the local_window_size attribute (needed for sliding-window KV-shared layers)
  • Requires an explicit attention bias mask built with CumSum/GreaterOrEqual/Where — all CPU-only ops on CUDA EP that cause Memcpy nodes
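For context, this is roughly the node a non-shared layer emits today (input and attribute names follow my reading of the com.microsoft.GroupQueryAttention contrib-op documentation; please verify against the actual schema). The mandatory past_*/present_* cache plumbing is exactly what a KV-shared layer does not have:

```python
import onnx.helper as oh

# Sketch of a non-shared layer's GQA node as I understand the current
# contrib-op schema (names may differ slightly).  The op both consumes and
# produces a per-layer KV cache, which a KV-shared layer cannot provide.
gqa = oh.make_node(
    "GroupQueryAttention",
    inputs=[
        "query",                  # [B, S, num_heads * head_dim]
        "key",                    # [B, S, kv_num_heads * head_dim]
        "value",                  # [B, S, kv_num_heads * head_dim]
        "past_key",               # this layer's own cache -- KV-shared layers have none
        "past_value",
        "seqlens_k",
        "total_sequence_length",
        "cos_cache",              # RoPE tables (do_rotary=1)
        "sin_cache",
    ],
    outputs=["attn_out", "present_key", "present_value"],  # cache update is not optional
    domain="com.microsoft",
    num_heads=8,
    kv_num_heads=2,
    do_rotary=1,
    local_window_size=512,        # sliding-window layers only
)
```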

Current workaround (in mobius ONNX model builder)

We emit a mixed graph: GQA for non-shared layers, standard Attention for KV-shared layers.

Layers 0-14 (non-shared): GroupQueryAttention (with local_window_size for sliding)
Layers 15-34 (KV-shared): Attention + explicit bool mask (create_sliding_window_mask / create_padding_mask)

For full-attention KV-shared layers, we use create_padding_mask() with is_causal=1 — this avoids most CPU ops. But for sliding-window KV-shared layers, the mask still requires CumSum on INT64 attention_mask, which is CPU-only in CUDA EP and causes Memcpy.
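For reference, a minimal numpy sketch of what that explicit mask computes; the actual graph builds the same thing with CumSum, comparison, and And nodes over the int64 attention_mask (helper names and the exact op sequence in mobius may differ):

```python
import numpy as np

def sliding_window_mask(attention_mask: np.ndarray, window: int) -> np.ndarray:
    """Reference computation for the bool mask fed to the standard Attention
    op on KV-shared sliding layers.  attention_mask is the int64 [B, T]
    padding mask.  The ONNX graph expresses this with CumSum plus
    comparison/And nodes, and CumSum over int64 is CPU-only on the CUDA EP."""
    positions = np.cumsum(attention_mask, axis=-1) - 1      # CumSum -> per-token positions
    q_pos = positions[:, None, :, None]                     # [B, 1, T, 1]
    k_pos = positions[:, None, None, :]                     # [B, 1, 1, T]
    causal = k_pos <= q_pos                                  # causal constraint
    in_window = (q_pos - k_pos) < window                     # sliding-window constraint
    padded = attention_mask[:, None, None, :].astype(bool)   # mask out padding keys
    return causal & in_window & padded                       # And
```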

Concrete Memcpy impact

For Gemma4 E2B (google/gemma-4-E2B-it, 35 layers, 20 KV-shared):

| Source | CPU ops | Memcpy cause |
| --- | --- | --- |
| Sliding-window mask for 16 KV-shared sliding layers | CumSum, Less, And | INT64 CumSum is CPU-only |
| KV tensor reshape (4D→3D for Attention op) | Previously Shape ops (now fixed with [0,0,-1]) | Eliminated |

Suggestions

Option A: Add external_key/external_value inputs to GQA

Allow GQA to accept pre-computed K,V tensors (from another layer's present_key/present_value) instead of its own past_key/past_value. In this mode, GQA would:

  • Skip its internal KV cache concatenation
  • Apply RoPE only to Q (K already has RoPE from the source layer)
  • Use local_window_size for sliding-window masking
  • Not produce present_key/present_value outputs (or pass through the external ones)

This would let KV-shared layers use GQA with full sliding-window support and zero CPU Memcpy.
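A hypothetical sketch of what an Option A node could look like. The trailing external_key/external_value inputs do not exist in the current schema; the names and input positions are only illustrative of the requested contract:

```python
import onnx.helper as oh

# Hypothetical Option A node for a KV-shared sliding layer.  The
# "source_present_*" inputs stand in for the proposed external_key /
# external_value; nothing below is part of the current GQA schema.
gqa_shared = oh.make_node(
    "GroupQueryAttention",
    inputs=[
        "query",
        "", "",                       # no fresh K/V projections
        "", "",                       # no past_key/past_value of its own
        "seqlens_k",
        "total_sequence_length",
        "cos_cache",
        "sin_cache",
        "source_present_key",         # proposed external_key (layer 13/14 present_key)
        "source_present_value",       # proposed external_value
    ],
    outputs=["attn_out"],             # no present_* outputs in this mode
    domain="com.microsoft",
    num_heads=8,
    kv_num_heads=2,
    do_rotary=1,                      # RoPE applied to Q only; K is already rotated
    local_window_size=512,
)
```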

Option B: Add local_window_size to the standard Attention op

If GQA modification is too invasive, adding a local_window_size attribute to the standard Attention op would let KV-shared sliding layers avoid the explicit CumSum-based mask. Combined with is_causal=1, this would eliminate most CPU ops.
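A sketch of Option B for a KV-shared sliding layer, assuming the standard (opset-23) Attention op; attribute names follow that schema as I recall it, and local_window_size is the proposed addition, not something that exists today:

```python
import onnx.helper as oh

# Option B sketch: standard Attention fed the source layer's present K/V,
# with is_causal (exists today) plus a hypothetical local_window_size
# attribute that would replace the explicit CumSum-based mask.
attn_shared = oh.make_node(
    "Attention",
    inputs=["q_rotated", "source_present_key", "source_present_value"],
    outputs=["attn_out"],
    is_causal=1,
    q_num_heads=8,                 # per the opset-23 Attention schema, as I recall it
    kv_num_heads=2,
    local_window_size=512,         # proposed attribute, not in the current schema
)
```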

Option C: Add do_rotary=0 mode to GQA with external KV

A simpler variant of Option A: GQA with do_rotary=0 would skip RoPE entirely (caller pre-applies it to Q, K already has RoPE from source). The KV-shared layer would:

  1. Compute Q, apply RoPE manually
  2. Read source layer's present_key/present_value (already RoPE-applied, BNSH format)
  3. Pass Q + external K,V to GQA with do_rotary=0 and local_window_size
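A sketch of that flow for one KV-shared sliding layer. RotaryEmbedding is the existing com.microsoft contrib op; the external-KV inputs on GQA remain hypothetical (they presuppose Option A's mode), and weight/tensor names are illustrative:

```python
import onnx.helper as oh

# Option C sketch: Q projection + manual RoPE, then GQA with do_rotary=0
# and borrowed K/V.  The last two GQA inputs are the hypothetical
# external K/V from the source layer (already RoPE-applied, BNSH).
nodes = [
    oh.make_node("MatMul", ["hidden", "q_proj_weight"], ["q"]),          # step 1: Q projection
    oh.make_node("RotaryEmbedding",                                      # step 1: RoPE on Q
                 ["q", "position_ids", "cos_cache", "sin_cache"],
                 ["q_rotated"], domain="com.microsoft"),
    oh.make_node("GroupQueryAttention",                                  # step 3: attention, no internal RoPE
                 ["q_rotated", "", "", "", "",
                  "seqlens_k", "total_sequence_length", "", "",
                  "source_present_key", "source_present_value"],         # step 2: borrowed K/V
                 ["attn_out"],
                 domain="com.microsoft",
                 num_heads=8, kv_num_heads=2,
                 do_rotary=0, local_window_size=512),
]
```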

Architecture reference

Gemma4 E2B: 35 layers
├── Layers 0-14: Non-shared (own Q,K,V projections + KV cache)
│   ├── 12 sliding-attention layers (local_window_size=512)
│   └── 3 full-attention layers
└── Layers 15-34: KV-shared (own Q projection only, borrow K,V from source)
    ├── 16 sliding-attention layers → borrow from layer 13
    └── 4 full-attention layers → borrow from layer 14

HuggingFace config field: num_kv_shared_layers controls how many trailing layers share KV.
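A small sketch of the sharing map this implies. The per-layer sliding/full pattern comes from the HF config (the field and helper names here are illustrative); with num_kv_shared_layers=20 and the 35-layer layout above, sliding KV-shared layers resolve to layer 13 and full-attention ones to layer 14:

```python
# Map each KV-shared layer to the non-shared layer it borrows K/V from:
# the last non-shared layer with the same attention type.
def kv_source_map(layer_types: list[str], num_kv_shared_layers: int) -> dict[int, int]:
    n = len(layer_types)                          # e.g. 35 entries of "sliding" / "full"
    first_shared = n - num_kv_shared_layers       # e.g. 35 - 20 = 15
    last_source: dict[str, int] = {}              # attention type -> last non-shared layer
    for i in range(first_shared):
        last_source[layer_types[i]] = i
    return {i: last_source[layer_types[i]] for i in range(first_shared, n)}
```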
