Optimize Whisper decoder cross-attention to skip redundant K/V projections during decode (#70)
Open
jimburtoft wants to merge 1 commit into aws-neuron:main
Conversation
Author
I have three additional Whisper optimizations built on top of this change, available on a separate branch if you're interested:
Description
The Whisper decoder's cross-attention layers recompute K and V projections from the full encoder output (audio_features) on every decode token, even though the cross-attention KV cache is already populated during prefill and the values never change. This wastes ~19.7B FLOPs per decode step (for large-v3-turbo: 4 layers x 2 projections x 1500 x 1280 x 1280) and transfers the full audio_features tensor (3.84 MB at fp32) from CPU to device on each step.
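The ~19.7B figure follows directly from the stated dimensions. A quick sanity check (dimensions taken from the description above; the product counts multiply-accumulates in the K/V projection matmuls):

```python
# Redundant cross-attention K/V projection work per decode step for
# whisper-large-v3-turbo, per the description above:
# 4 decoder layers x 2 projections (K and V) x seq_len 1500 x 1280 x 1280.
layers, projections, seq_len, d_model = 4, 2, 1500, 1280

macs = layers * projections * seq_len * d_model * d_model
print(f"{macs / 1e9:.1f}B per decode step")  # -> 19.7B per decode step
```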
This PR moves the K/V projection inside the is_prefill branch so the decode path reads directly from the KV cache. During decode, a minimal 1-token dummy xa tensor is passed instead of the full encoder output, reducing CPU-to-device transfer by 1500x per step.
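The change described above can be sketched as follows. This is a minimal illustration, not the actual aws-neuron code: the module shape follows OpenAI Whisper's cross-attention (K projection without bias, V projection with bias), and the names `is_prefill`, `kv_cache`, and `CrossAttentionKV` are illustrative.

```python
# Sketch: compute cross-attention K/V only on prefill; decode reads the cache.
import torch
import torch.nn as nn

class CrossAttentionKV(nn.Module):
    def __init__(self, d_model: int = 1280):
        super().__init__()
        # Whisper-style projections: K has no bias, V does.
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model)

    def forward(self, xa: torch.Tensor, kv_cache: dict, is_prefill: bool):
        if is_prefill:
            # Project the full encoder output once and populate the cache.
            kv_cache["k"] = self.key(xa)
            kv_cache["v"] = self.value(xa)
        # During decode, xa is never touched: the projections (and the full
        # audio_features host-to-device transfer) are skipped entirely.
        return kv_cache["k"], kv_cache["v"]

# Usage: prefill with real encoder output, then decode with a 1-token dummy.
attn = CrossAttentionKV()
cache: dict = {}
audio_features = torch.randn(1, 1500, 1280)
k, v = attn(audio_features, cache, is_prefill=True)
dummy_xa = torch.zeros(1, 1, 1280)  # 1500x smaller transfer per decode step
k2, v2 = attn(dummy_xa, cache, is_prefill=False)
assert torch.equal(k, k2) and torch.equal(v, v2)
```

In a traced Neuron graph, the unused `xa` input on the decode path is exactly what produces the compiler's "input tensor that was unused" notice quoted below in Testing.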
Changes (2 files, 4 edits):
Model Information
Model: Whisper large-v3-turbo (also applies to all Whisper variants)
Architecture: Encoder-decoder transformer (encoder: 32 layers, decoder: 4 layers, 20 heads, head_dim=64, cross-attention seq_len=1500)
Testing
Tested on trn2.3xlarge (LNC=2, TP=1) with openai/whisper-large-v3-turbo, batch_size=1, fp32.
Benchmark Results (metrics measured):
- Per-token decode latency
- Encoder latency
- Prefill latency
Accuracy: 0% WER on test audio vs CPU reference (exact token match).
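The 0% WER claim reduces to an exact token-sequence match against the CPU reference. A small self-contained sketch of such a check (the token IDs and variable names here are placeholders, not the actual test data):

```python
# Token-level WER: Levenshtein distance normalized by reference length.
def word_error_rate(ref: list, hyp: list) -> float:
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,      # deletion
                          d[i][j - 1] + 1,      # insertion
                          d[i - 1][j - 1] + cost)  # substitution
    return d[len(ref)][len(hyp)] / max(len(ref), 1)

cpu_tokens = [50258, 50259, 50359, 2425, 11, 1002, 13]  # placeholder IDs
neuron_tokens = list(cpu_tokens)                        # exact match
print(word_error_rate(cpu_tokens, neuron_tokens))       # -> 0.0
```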
The Neuron compiler confirms xa is unused in the decode graph: "Received an input tensor that was unused... tensor will be ignored".
Compatibility