Optimize Whisper decoder cross-attention to skip redundant K/V projections during decode#70

Open
jimburtoft wants to merge 1 commit into aws-neuron:main from jimburtoft:fix/whisper-decode-cross-attn-optimization
Conversation

@jimburtoft
Description

The Whisper decoder's cross-attention layers recompute K and V projections from the full encoder output (audio_features) on every decode token, even though the cross-attention KV cache is already populated during prefill and the values never change. This wastes ~19.7B FLOPs per decode step (for large-v3-turbo: 4 layers x 2 projections x 1500 x 1280 x 1280) and transfers the full audio_features tensor (3.84 MB at fp32) from CPU to device on each step.
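The FLOPs figure quoted above follows directly from the architecture numbers; a quick check, using the PR's own counting convention (one FLOP per multiply-accumulate in each 1280x1280 projection applied to 1500 encoder positions):

```python
# Reproduce the ~19.7B FLOPs estimate for large-v3-turbo's decode-time
# cross-attention K/V projections (counting as the PR does: layers x
# projections x seq_len x d_model x d_model).
layers = 4          # decoder layers in large-v3-turbo
projections = 2     # K and V
seq_len = 1500      # encoder output positions
d_model = 1280

flops = layers * projections * seq_len * d_model * d_model
print(flops)  # 19,660,800,000 ~= 19.7B per decode step
```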
This PR moves the K/V projection inside the is_prefill branch so the decode path reads directly from the KV cache. During decode, a minimal 1-token dummy xa tensor is passed instead of the full encoder output, reducing CPU-to-device transfer by 1500x per step.
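The pattern can be sketched as follows. This is a minimal illustration of moving the K/V projection under the prefill branch, not the actual `NeuronCrossAttention` code; the class name, cache dict, and `is_prefill` flag here are simplified stand-ins:

```python
import torch
import torch.nn as nn

class CrossAttentionSketch(nn.Module):
    """Illustrative sketch of the is_prefill K/V gating described above."""

    def __init__(self, dim: int = 1280):
        super().__init__()
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim, bias=False)
        self.value = nn.Linear(dim, dim)

    def forward(self, x, xa, kv_cache: dict, is_prefill: bool):
        q = self.query(x)
        if is_prefill:
            # Project the encoder output once; the cache never changes after this.
            kv_cache["k"] = self.key(xa)
            kv_cache["v"] = self.value(xa)
        # Decode path: only Q is computed; xa is a 1-token dummy and unused.
        k, v = kv_cache["k"], kv_cache["v"]
        w = torch.softmax(q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5), dim=-1)
        return w @ v
```

During decode, `xa` still appears in the signature (for trace compatibility) but is never read, which is why the compiler reports it as an unused input.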

Changes (2 files, 4 edits):

  1. NeuronCrossAttention.forward() — Move K/V projection (self.key(xa), self.value(xa)) inside if is_prefill:. Decode path computes only Q and uses cached K/V.
  2. ModelWrapperWhisperDecoderDecode.input_generator() — Change trace-time dummy xa shape from (B, 1500, 1280) to (B, 1, 1280). The tensor is unused in the decode graph but must be present for forward signature compatibility.
  3. NeuronApplicationWhisper.logits() — During decode, create a torch.zeros(B, 1, dim) dummy instead of passing the full audio_features.
  4. NeuronInference.logits() (decoding.py) — Same dummy audio_features pattern for the decoding loop entry point.
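Edits 3 and 4 share the same dummy-tensor pattern, sketched below. The function name and signature are illustrative, not the actual NxDI entry points:

```python
import torch

def select_xa(audio_features: torch.Tensor, is_prefill: bool) -> torch.Tensor:
    """Sketch of the dummy-xa pattern used in the decode-path entry points."""
    if is_prefill:
        # Prefill needs the real encoder output: (B, 1500, 1280).
        return audio_features
    # Decode: pass a 1-token zero tensor for signature compatibility only,
    # shrinking the per-step CPU-to-device transfer by 1500x.
    B, _, dim = audio_features.shape
    return torch.zeros(B, 1, dim, dtype=audio_features.dtype)
```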

Model Information

Model: Whisper large-v3-turbo (also applies to all Whisper variants)
Architecture: Encoder-decoder transformer (encoder: 32 layers, decoder: 4 layers, 20 heads, head_dim=64, cross-attention seq_len=1500)

Testing

Tested on trn2.3xlarge (LNC=2, TP=1) with openai/whisper-large-v3-turbo, batch_size=1, fp32.
Benchmark Results (measured metrics; values not reproduced here): per-token decode latency, encoder latency, prefill latency.
Accuracy: 0% WER on test audio vs. CPU reference (exact token match).
The Neuron compiler confirms xa is unused in the decode graph: "Received an input tensor that was unused... tensor will be ignored".

Compatibility

  • Instance Type(s): trn2.3xlarge
  • Neuron SDK: 2.27 (DLAMI 20260126)
  • Configuration: TP=1, batch_size=1, fp32

@jimburtoft
Author

I have three additional Whisper optimizations built on top of this change, available on a separate branch if you're interested:

  • Fused QKV projections — replaces 3 separate Q/K/V linear layers with a single fused qkv_proj in self-attention, with state_dict conversion
  • NKI flash attention in encoder — replaces matmul-based attention with scaled_dot_product_attention_kernel (NxDI's existing attention_isa_kernel wrapper) across all 32 encoder layers. 45% faster compilation.
  • NKI fused Conv1D+GELU in encoder — replaces nn.Conv1d + F.gelu with nkilib.experimental.conv.conv1d fused kernel, with graceful fallback if nkilib isn't installed
All use existing NKI kernels from the SDK and nki-library (no custom kernels). Combined branch: https://github.com/jimburtoft/neuronx-distributed-inference/tree/fix/whisper-all-optimizations

I can either squash these into this PR or open a separate one — let me know your preference.
