[GQA] Make present_key/present_value outputs optional and add Gemma4 support #28242
base: main
Changes from all commits
First changed file (`class GQAAttentionBase`):

```diff
@@ -74,6 +74,7 @@ class GQAAttentionBase {
   const bool is_prompt = parameters.is_first_prompt;
   const int batch_size = parameters.batch_size;
   const int sequence_length = parameters.sequence_length;
+  const int kv_sequence_length = parameters.kv_sequence_length;
   const int total_sequence_length = parameters.total_sequence_length;
   const int head_size = parameters.head_size;
   const int hidden_size = parameters.hidden_size;
```
```diff
@@ -85,7 +86,9 @@ class GQAAttentionBase {
   if (past_key != nullptr && past_value != nullptr) {
     seqlen_past_kv_cache = static_cast<int>(past_key->Shape().GetDims()[2]);
   }
-  int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);
+  int seqlen_present_kv_cache = present_key != nullptr
+                                    ? static_cast<int>(present_key->Shape().GetDims()[2])
+                                    : parameters.total_sequence_length;

   // Compute the attention score.
   bool gqa_mlas_supported = MlasGQASupported<T>(CblasNoTrans, CblasTrans) &&
```
```diff
@@ -110,28 +113,28 @@ class GQAAttentionBase {
   if (gqa_mlas_supported) {
     ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, head_sink, seqlens_k->Data<int32_t>(), attention_bias_data,
-                          batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache,
+                          batch_size, sequence_length, kv_sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache,
                           seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer,
                           past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);

     // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
     const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
     ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v,
                             seqlens_k->Data<int32_t>(),
-                            batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
+                            batch_size, sequence_length, kv_sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
                             hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
                             is_prompt, tp, allocator);
   } else {
     ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, head_sink, seqlens_k->Data<int32_t>(), attention_bias_data,
-                          batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache,
+                          batch_size, sequence_length, kv_sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache,
                           seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer,
                           past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);

     // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
     const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
     ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
                             seqlens_k->Data<int32_t>(),
-                            batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
+                            batch_size, sequence_length, kv_sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
                             hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
                             is_prompt, tp, allocator);
   }
```
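The comment repeated in both branches gives the per-head contraction, out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v). As a reference for the shapes involved, here is a minimal single-head sketch with toy sizes (plain loops only; the kernel itself dispatches the batched version to MLAS when `gqa_mlas_supported` is true):

```cpp
// Single-head sketch of out(S, H_v) = probs(S, T) x V(T, H_v).
// Toy sizes; the real kernel batches this over B x N heads.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const std::size_t S = 2, T = 3, Hv = 4;     // query len, total KV len, value head size
  std::vector<float> probs(S * T, 1.0f / T);  // uniform attention weights
  std::vector<float> V(T * Hv, 2.0f);
  std::vector<float> out(S * Hv, 0.0f);

  for (std::size_t s = 0; s < S; ++s)
    for (std::size_t t = 0; t < T; ++t)
      for (std::size_t h = 0; h < Hv; ++h)
        out[s * Hv + h] += probs[s * T + t] * V[t * Hv + h];

  std::printf("out[0][0] = %g\n", out[0]);  // sum over T of (1/3) * 2 = 2
  return 0;
}
```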
```diff
@@ -145,15 +148,16 @@ class GQAAttentionBase {
   // attention_probs(B, N, S, T) = Softmax(attention_probs)
   // If T is float32, U is float32. If T is float16, U could be float16 or float32.
   template <typename T, typename U>
-  void ComputeAttentionProbs(U* attention_probs,       // output buffer with size BxNxSxT
-                             const T* Q,                // Q data. Its size is BxNxSxH
-                             const T* K,                // k data. Its size is BxNxLxH
-                             const T* head_sink,        // for smooth softmax. Its size is N.
-                             const int32_t* seqlens_k,  // total - 1 sequence lengths tensor
-                             const T* attention_bias,   // optional attention bias
-                             const size_t batch_size,   // batch size of self-attention
-                             const size_t sequence_length,        // sequence length of self-attention (S)
-                             const size_t total_sequence_length,  // total sequence length (T)
+  void ComputeAttentionProbs(U* attention_probs,
+                             const T* Q,
+                             const T* K,
+                             const T* head_sink,
+                             const int32_t* seqlens_k,
+                             const T* attention_bias,
+                             const size_t batch_size,
+                             const size_t sequence_length,
+                             const size_t kv_sequence_length,
+                             const size_t total_sequence_length,
                              const gsl::span<const int64_t> attention_bias_shape,  // shape of the attention bias
                              const size_t past_buffer_sequence_length,             // sequence length of past state
                              const size_t present_buffer_sequence_length,          // sequence length of present state
```
```diff
@@ -171,11 +175,11 @@ class GQAAttentionBase {
                                        : SafeInt<ptrdiff_t>(0);
     const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_;
     const size_t q_input_chunk_length = sequence_length * head_size;  // S x H
-    const size_t kv_input_chunk_length = sequence_length * head_size;  // L x H
+    const size_t kv_input_chunk_length = kv_sequence_length * head_size;  // L x H
     const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size;        // L x H
     const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size;  // T x H

-    if (!past_present_share_buffer) {
+    if (present_key && !past_present_share_buffer) {
       memset((void*)present_key,
              0,
              batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
```

> **Contributor** (review on the `present_key` guard; conversation marked resolved by tianleiwu): This still treats KV-shared decode with no local past as if there were a past prefix whenever …
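The constants in this hunk encode GQA's head sharing: `kv_num_heads_factor` is how many query heads are mapped onto each KV head, and `kv_input_chunk_length` (now L x H) is the stride between per-KV-head chunks. A standalone sketch of that mapping, with assumed toy sizes rather than the kernel's real buffers:

```cpp
// Sketch: GQA head mapping and chunk offsets, matching the diff's
// kv_input_chunk_length = kv_sequence_length * head_size (L x H).
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t num_heads = 8, kv_num_heads = 2;        // 4 Q heads share each KV head
  const std::size_t kv_sequence_length = 5, head_size = 4;  // L, H
  const std::size_t kv_num_heads_factor = num_heads / kv_num_heads;
  const std::size_t kv_input_chunk_length = kv_sequence_length * head_size;  // L x H

  for (std::size_t q_head = 0; q_head < num_heads; ++q_head) {
    const std::size_t kv_head = q_head / kv_num_heads_factor;  // shared KV head index
    std::printf("Q head %zu -> KV head %zu, chunk offset %zu\n",
                q_head, kv_head, kv_head * kv_input_chunk_length);
  }
  return 0;
}
```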
```diff
@@ -382,9 +386,10 @@ class GQAAttentionBase {
                               const T* V,                // V value with size BxN_kvxSxH
                               const int32_t* seqlens_k,  // total - 1 sequence lengths tensor
                               const size_t batch_size,   // batch size
-                              const size_t sequence_length,  // sequence length
+                              const size_t sequence_length,     // sequence length of Q
+                              const size_t kv_sequence_length,  // sequence length of K/V input
                               const size_t past_buffer_sequence_length,     // sequence length in past state
-                              const size_t present_buffer_sequence_length,  // sequence length in past state
+                              const size_t present_buffer_sequence_length,  // sequence length in present state
                               const size_t head_size,    // head size of Q, K, V
                               const size_t hidden_size,  // hidden size of Output
                               const T* past_value,       // past value only
```
```diff
@@ -398,11 +403,11 @@ class GQAAttentionBase {
         packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
                    : SafeInt<ptrdiff_t>(0);
     const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_;
-    const size_t kv_input_chunk_length = sequence_length * head_size;  // L x H
+    const size_t kv_input_chunk_length = kv_sequence_length * head_size;  // L x H
     const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size;        // L x H
     const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size;  // T x H

-    if (!past_present_share_buffer) {
+    if (present_value && !past_present_share_buffer) {
       memset((void*)present_value,
              0,
              batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
```
Second changed file (`GroupQueryAttention<T>::Compute`):
```diff
@@ -113,6 +113,21 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
   Tensor* present_k = context->Output(1, present_k_shape);
   Tensor* present_v = context->Output(2, present_v_shape);

+  // present_key and present_value must be both present or both absent.
+  if ((present_k == nullptr) != (present_v == nullptr)) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "present_key and present_value must be both provided or both omitted.");
+  }
+
+  // Omitting present outputs is only safe when past_key is not provided.
+  // When past_key exists, ConcatStateChunkGQA must build a concatenated
+  // past+current KV buffer in present_key/present_value for the attention GEMMs.
+  if ((present_k == nullptr || present_v == nullptr) && past_key != nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "present_key and present_value outputs are required when past_key is provided. "
+                           "Omitting present outputs is only supported when there is no past KV cache.");
+  }
+
   std::vector<int64_t> output_qk_shape{static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(parameters.sequence_length), static_cast<int64_t>(parameters.total_sequence_length)};
   Tensor* output_qk = context->Output(3, output_qk_shape);
```

> **Contributor** (on the `past_key` guard): … During decode steps, the KV-shared layer receives the source layer's full KV (length = …). The real safety condition is whether past KV exists and needs concatenation:
>
> ```cpp
> if ((present_k == nullptr || present_v == nullptr) && past_key != nullptr) {
>   return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
>                          "present_key and present_value outputs are required when past_key is provided.");
> }
> ```
>
> Note: Using MHA instead is not viable (no …

> **Contributor:** The guard is relaxed now, but the intended decode shape still does not get through input validation.
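Since the review discussion centers on exactly when present outputs may be omitted, here is a hedged standalone sketch of the two checks. The `ValidatePresentOutputs` helper is hypothetical; plain bools and stderr stand in for ORT's `Status` machinery:

```cpp
// Standalone sketch of the two validation rules added above.
// Assumption: only the nullness of the four tensors matters here.
#include <cstdio>

bool ValidatePresentOutputs(bool has_present_k, bool has_present_v, bool has_past_key) {
  // Rule 1: present_key and present_value must be both present or both absent.
  if (has_present_k != has_present_v) {
    std::fprintf(stderr, "present_key and present_value must be both provided or both omitted.\n");
    return false;
  }
  // Rule 2: with a past KV cache, the kernel must concatenate past+current
  // KV into the present buffers, so present outputs cannot be omitted.
  if (!has_present_k && has_past_key) {
    std::fprintf(stderr, "present outputs are required when past_key is provided.\n");
    return false;
  }
  return true;
}

int main() {
  std::printf("%d\n", ValidatePresentOutputs(true, true, true));    // 1: normal cached path
  std::printf("%d\n", ValidatePresentOutputs(false, false, false)); // 1: no cache at all
  std::printf("%d\n", ValidatePresentOutputs(true, false, false));  // 0: mismatched outputs
  std::printf("%d\n", ValidatePresentOutputs(false, false, true));  // 0: past without present
  return 0;
}
```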
```diff
@@ -125,16 +140,17 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
   OrtValue Q;
   OrtValue K;
   OrtValue V;
+  const int kv_sequence_length = parameters.kv_sequence_length;
   if (packed_qkv) {
     ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
         allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q));
   } else {
     ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
         allocator, batch_size, num_heads_, sequence_length, head_size, query, Q));
     ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
-        allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K));
+        allocator, batch_size, kv_num_heads_, kv_sequence_length, head_size, key, K));
     ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
-        allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V));
+        allocator, batch_size, kv_num_heads_, kv_sequence_length, head_size, value, V));
   }

   OrtValue RotaryQKV;
```
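The substantive change here is that K and V are relaid out using `kv_sequence_length` instead of the query's `sequence_length`. A minimal sketch of the BSNH-to-BNSH relayout that `MaybeTransposeToBNSH` is assumed to perform, using contiguous floats instead of ORT tensor types:

```cpp
// Sketch: relayout (B, S, N, H) -> (B, N, S, H). For K/V in this PR the
// S passed in is kv_sequence_length, which may differ from the Q length.
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<float> ToBNSH(const std::vector<float>& bsnh,
                          std::size_t B, std::size_t S, std::size_t N, std::size_t H) {
  std::vector<float> bnsh(B * N * S * H);
  for (std::size_t b = 0; b < B; ++b)
    for (std::size_t s = 0; s < S; ++s)
      for (std::size_t n = 0; n < N; ++n)
        for (std::size_t h = 0; h < H; ++h)
          bnsh[((b * N + n) * S + s) * H + h] = bsnh[((b * S + s) * N + n) * H + h];
  return bnsh;
}

int main() {
  // Key tensor example: batch 1, kv_sequence_length 3, kv_num_heads 2, head_size 4.
  std::vector<float> key(1 * 3 * 2 * 4);
  for (std::size_t i = 0; i < key.size(); ++i) key[i] = static_cast<float>(i);
  const auto k_bnsh = ToBNSH(key, /*B=*/1, /*S=*/3, /*N=*/2, /*H=*/4);
  std::printf("element (n=0, s=1, h=0): %g\n", k_bnsh[4]);  // was BSNH index 8
  return 0;
}
```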
```diff
@@ -233,7 +249,9 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
   const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data<T>() : nullptr;

   // Compute the attention score and apply the score to V
-  return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
+  const T* k_data = packed_qkv ? nullptr : k_rotary;
+  const T* v_data = packed_qkv ? nullptr : V.Get<Tensor>().Data<T>();
+  return ApplyAttention(q_rotary, k_data, v_data,
                         head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v,
                         output_qk, seqlens_k, parameters, allocator, context);
 }
```
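In the packed-QKV branch, `k_data` and `v_data` stay null and the kernel later derives the K/V pointers from the Q buffer, as in the `Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size` expression earlier in the diff. A toy sketch of that offset arithmetic, assuming a Q | K | V packing order per batch:

```cpp
// Sketch: pointer offsets into a packed QKV buffer laid out as
// [num_heads | kv_num_heads | kv_num_heads] chunks of S x H each.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const std::size_t num_heads = 8, kv_num_heads = 2;
  const std::size_t sequence_length = 3, head_size = 4;
  const std::size_t SH = sequence_length * head_size;

  std::vector<float> packed((num_heads + 2 * kv_num_heads) * SH, 0.0f);
  const float* Q = packed.data();
  const float* K = Q + num_heads * SH;                   // after all Q heads
  const float* V = Q + (num_heads + kv_num_heads) * SH;  // after Q and K heads

  std::printf("K offset: %td, V offset: %td\n", K - Q, V - Q);
  return 0;
}
```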