6 changes: 3 additions & 3 deletions docs/ContribOperators.md
@@ -2671,14 +2671,14 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Scale tensor for past_value.</dd>
</dl>

#### Outputs (3 - 4)
#### Outputs (1 - 4)

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>present_key</tt> : T_CACHE</dt>
<dt><tt>present_key</tt> (optional) : T_CACHE</dt>
<dd>present state key with support for format BNSH. When past_key uses the same tensor as present_key (k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length + kv_sequence_length.</dd>
<dt><tt>present_value</tt> : T_CACHE</dt>
<dt><tt>present_value</tt> (optional) : T_CACHE</dt>
<dd>present state value with support for format BNSH. When past_value uses the same tensor as present_value (k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length + kv_sequence_length.</dd>
<dt><tt>output_qk</tt> (optional) : T</dt>
<dd>Values of QK matrix multiplication, either before or after softmax normalization</dd>
45 changes: 25 additions & 20 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -74,6 +74,7 @@ class GQAAttentionBase {
const bool is_prompt = parameters.is_first_prompt;
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int total_sequence_length = parameters.total_sequence_length;
const int head_size = parameters.head_size;
const int hidden_size = parameters.hidden_size;
@@ -85,7 +86,9 @@
if (past_key != nullptr && past_value != nullptr) {
seqlen_past_kv_cache = static_cast<int>(past_key->Shape().GetDims()[2]);
}
int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);
int seqlen_present_kv_cache = present_key != nullptr
? static_cast<int>(present_key->Shape().GetDims()[2])
: parameters.total_sequence_length;

// Compute the attention score.
bool gqa_mlas_supported = MlasGQASupported<T>(CblasNoTrans, CblasTrans) &&
@@ -110,28 +113,28 @@

if (gqa_mlas_supported) {
ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, head_sink, seqlens_k->Data<int32_t>(), attention_bias_data,
batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache,
batch_size, sequence_length, kv_sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache,
seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer,
past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);

// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v,
seqlens_k->Data<int32_t>(),
batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
batch_size, sequence_length, kv_sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
is_prompt, tp, allocator);
} else {
ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, head_sink, seqlens_k->Data<int32_t>(), attention_bias_data,
batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache,
batch_size, sequence_length, kv_sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache,
seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer,
past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);

// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
seqlens_k->Data<int32_t>(),
batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
batch_size, sequence_length, kv_sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
is_prompt, tp, allocator);
}
@@ -145,15 +148,16 @@
// attention_probs(B, N, S, T) = Softmax(attention_probs)
// If T is float32, U is float32. If T is float16, U could be float16 or float32.
template <typename T, typename U>
void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
const T* head_sink, // for smooth softmax. Its size is N.
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
const T* attention_bias, // optional attention bias
const size_t batch_size, // batch size of self-attention
const size_t sequence_length, // sequence length of self-attention (S)
const size_t total_sequence_length, // total sequence length (T)
void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
const T* head_sink, // for smooth softmax. Its size is N.
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
const T* attention_bias, // optional attention bias
const size_t batch_size, // batch size of self-attention
const size_t sequence_length, // sequence length of self-attention (S)
const size_t kv_sequence_length, // sequence length of K/V input (L)
const size_t total_sequence_length, // total sequence length (T)
const gsl::span<const int64_t> attention_bias_shape, // shape of the attention bias
const size_t past_buffer_sequence_length, // sequence length of past state
const size_t present_buffer_sequence_length, // sequence length of present state
@@ -171,11 +175,11 @@
: SafeInt<ptrdiff_t>(0);
const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_;
const size_t q_input_chunk_length = sequence_length * head_size; // S x H
const size_t kv_input_chunk_length = sequence_length * head_size; // L x H
const size_t kv_input_chunk_length = kv_sequence_length * head_size; // L x H
const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H

if (!past_present_share_buffer) {
if (present_key && !past_present_share_buffer) {
Contributor:
This still treats KV-shared decode with no local past as if there were a past prefix whenever present_key is connected. For the new decode shape (sequence_length == 1, kv_sequence_length == total_sequence_length, past_key == nullptr), is_first_prompt is false, so past_seqlen becomes total_seqlen - sequence_length. ConcatStateChunkGQA then copies past_chunk_length from the null past_key pointer and appends the full K input after that synthetic offset, which can read null/invalid memory and overflow the present buffer. Please distinguish actual-past decode from borrowed-KV/no-past decode here; the no-past case should copy K/V from offset 0 (or reject connected present outputs for that mode if omitted-only is the intended contract).
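To make the requested distinction concrete, here is a minimal self-contained sketch; the helper name, its signature, and the exact call site are assumptions, and only the offset arithmetic follows the description above:

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical helper: compute the element offset at which new K/V rows are
// appended into the present buffer for one batch entry. A borrowed-KV/no-past
// decode (has_local_past == false) must behave like a first prompt: offset 0.
inline size_t PastChunkLength(bool is_prompt, bool has_local_past,
                              int32_t seqlen_k,  // total - 1 for this batch
                              size_t kv_sequence_length, size_t head_size) {
  const size_t total_seqlen = static_cast<size_t>(seqlen_k) + 1;
  // Only a real past prefix shifts the copy offset; a first prompt or a
  // full-length borrowed K/V input starts at the beginning of the buffer.
  const size_t past_seqlen = (is_prompt || !has_local_past)
                                 ? 0
                                 : total_seqlen - kv_sequence_length;
  return past_seqlen * head_size;
}
```

Passing the result as past_chunk_length to ConcatStateChunkGQA would leave the actual-past decode path unchanged while making the no-past case copy K/V from offset 0.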

memset((void*)present_key,
0,
batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
@@ -382,9 +386,10 @@ class GQAAttentionBase {
const T* V, // V value with size BxN_kvxSxH
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
const size_t batch_size, // batch size
const size_t sequence_length, // sequence length
const size_t sequence_length, // sequence length of Q
const size_t kv_sequence_length, // sequence length of K/V input
const size_t past_buffer_sequence_length, // sequence length in past state
const size_t present_buffer_sequence_length, // sequence length in past state
const size_t present_buffer_sequence_length, // sequence length in present state
const size_t head_size, // head size of Q, K, V
const size_t hidden_size, // hidden size of Output
const T* past_value, // past value only
@@ -398,11 +403,11 @@
packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
: SafeInt<ptrdiff_t>(0);
const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_;
const size_t kv_input_chunk_length = sequence_length * head_size; // L x H
const size_t kv_input_chunk_length = kv_sequence_length * head_size; // L x H
const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H

if (!past_present_share_buffer) {
if (present_value && !past_present_share_buffer) {
memset((void*)present_value,
0,
batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
24 changes: 21 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -113,6 +113,21 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
Tensor* present_k = context->Output(1, present_k_shape);
Tensor* present_v = context->Output(2, present_v_shape);

// present_key and present_value must be both present or both absent.
if ((present_k == nullptr) != (present_v == nullptr)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"present_key and present_value must be both provided or both omitted.");
}

// Omitting present outputs is only safe when past_key is not provided.
// When past_key exists, ConcatStateChunkGQA must build a concatenated
// past+current KV buffer in present_key/present_value for the attention GEMMs.
if ((present_k == nullptr || present_v == nullptr) && past_key != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
Contributor:
The is_first_prompt guard (sequence_length == total_sequence_length) is too restrictive for Gemma 4's KV-shared layers.

During decode steps, the KV-shared layer receives the source layer's full KV (length = total_seq_len) while Q has length 1. With this guard, is_first_prompt = false, the rejection fires, and the feature only works for prefill, where Q and KV happen to have equal length.

The real safety condition is whether past KV exists and needs concatenation:

```cpp
if ((present_k == nullptr || present_v == nullptr) && past_key != nullptr) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "present_key and present_value outputs are required when past_key is provided.");
}
```

Note: Using MHA instead is not viable (no kv_num_heads, no seqlens_k, no sliding window/rotary/KV quantization). Modifying GQA is correct, but the constraint needs to be relaxed to support decode.

Contributor:
The guard is relaxed now, but the intended decode shape still does not get through input validation. Check_Q_K_V still requires query_dims[1] == key_dims[1] and query_dims[1] == value_dims[1], so a KV-shared decode call with Q length 1 and borrowed K/V length total_sequence_length is rejected before it reaches this code. Please either relax the shared validation and thread a separate KV sequence length through transpose/rotary/attention sizing, or explicitly scope this PR to prefill-only behavior.
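For concreteness, a standalone sketch of the relaxed shape contract described here; the struct, function name, and example dimensions are hypothetical and not part of the PR:

```cpp
#include <cstdint>
#include <stdexcept>

struct Dims3 { int64_t batch, seq, hidden; };  // (batch, sequence, hidden)

// Relaxed Q/K/V validation: all three share dim 0 (batch); K and V must agree
// on dims 1 and 2; Q's sequence length may differ from K/V's, e.g. Q of
// length 1 during a KV-shared decode step with borrowed full-length K/V.
inline void CheckQKVRelaxed(const Dims3& q, const Dims3& k, const Dims3& v) {
  if (q.batch != k.batch || q.batch != v.batch)
    throw std::invalid_argument("query/key/value must share dim 0 (batch size)");
  if (k.seq != v.seq)
    throw std::invalid_argument("key/value must share dim 1 (sequence length)");
  if (k.hidden != v.hidden)
    throw std::invalid_argument("key/value must share dim 2 (hidden size)");
}

// Example: the decode shape the old query_dims[1] == key_dims[1] check
// rejected -- Q of length 1, borrowed K/V of length total_sequence_length.
// CheckQKVRelaxed({1, 1, 512}, {1, 128, 256}, {1, 128, 256});  // accepted
```

Threading key_dims[1] through as a separate kv_sequence_length (as the helper changes below do) covers the shape check; the reviewer's point is that transpose, rotary, and attention sizing must consume that length as well.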

"present_key and present_value outputs are required when past_key is provided. "
"Omitting present outputs is only supported when there is no past KV cache.");
}

std::vector<int64_t> output_qk_shape{static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(parameters.sequence_length), static_cast<int64_t>(parameters.total_sequence_length)};
Tensor* output_qk = context->Output(3, output_qk_shape);

@@ -125,16 +140,17 @@
OrtValue Q;
OrtValue K;
OrtValue V;
const int kv_sequence_length = parameters.kv_sequence_length;
if (packed_qkv) {
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q));
} else {
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, num_heads_, sequence_length, head_size, query, Q));
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K));
allocator, batch_size, kv_num_heads_, kv_sequence_length, head_size, key, K));
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V));
allocator, batch_size, kv_num_heads_, kv_sequence_length, head_size, value, V));
}

OrtValue RotaryQKV;
@@ -233,7 +249,9 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data<T>() : nullptr;

// Compute the attention score and apply the score to V
return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
const T* k_data = packed_qkv ? nullptr : k_rotary;
const T* v_data = packed_qkv ? nullptr : V.Get<Tensor>().Data<T>();
return ApplyAttention(q_rotary, k_data, v_data,
head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v,
output_qk, seqlens_k, parameters, allocator, context);
}
18 changes: 10 additions & 8 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -14,7 +14,8 @@ namespace group_query_attention_helper {

template <typename T>
Status Check_Q_K_V(const T* query, const T* key, const T* value, const int num_heads, const int kv_num_heads,
int& batch_size, int& sequence_length, int& q_hidden_size, int& kv_hidden_size, int& head_size) {
int& batch_size, int& sequence_length, int& kv_sequence_length,
int& q_hidden_size, int& kv_hidden_size, int& head_size) {
const auto& query_dims = query->Shape().GetDims();
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
@@ -40,10 +41,8 @@
} else if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
} else if (query_dims[1] != key_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 1 (sequence length)");
}
kv_sequence_length = static_cast<int>(key_dims[1]);
kv_hidden_size = static_cast<int>(key_dims[2]);
if (kv_hidden_size % kv_num_heads != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -61,9 +60,9 @@
} else if (query_dims[0] != value_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 0 (batch size)");
} else if (query_dims[1] != value_dims[1]) {
} else if (key_dims[1] != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 1 (sequence length)");
"Input 'key' and 'value' shall have same dim 1 (sequence length)");
} else if (value_dims[2] != kv_hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
}
@@ -239,17 +238,19 @@ Status CheckInputs(const T* query,

int batch_size = 0;
int sequence_length = 0;
int kv_sequence_length = 0;
int q_hidden_size = 0;
int kv_hidden_size = 0;
int head_size = 0;
const bool is_packed_qkv = key == nullptr;
const bool is_packed_qkv = (key == nullptr);
if (!is_packed_qkv) {
ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, kv_num_heads, batch_size, sequence_length,
q_hidden_size, kv_hidden_size, head_size));
kv_sequence_length, q_hidden_size, kv_hidden_size, head_size));
} else {
qkv_format = QKV_BS3NH;
ORT_RETURN_IF_ERROR(Check_QKV(query, value, num_heads, kv_num_heads, batch_size, sequence_length, q_hidden_size,
kv_hidden_size, head_size));
kv_sequence_length = sequence_length;
}

// Check past-present KV
@@ -312,6 +313,7 @@
GroupQueryAttentionParameters* output_parameters = reinterpret_cast<GroupQueryAttentionParameters*>(parameters);
output_parameters->batch_size = batch_size;
output_parameters->sequence_length = sequence_length; // sequence length of Q
output_parameters->kv_sequence_length = kv_sequence_length; // sequence length of K/V inputs
output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors
output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors
output_parameters->total_sequence_length = total_sequence_length; // total sequence length