From 5bd4899b006db79b3cfd3babc3456b3d7eecf818 Mon Sep 17 00:00:00 2001
From: Jim Burtoft
Date: Wed, 11 Mar 2026 10:20:00 -0400
Subject: [PATCH] Optimize Whisper decoder cross-attention to skip redundant
 K/V projections during decode

---
 .../models/whisper/modeling_whisper.py | 33 +++++++++++++------
 .../models/whisper/utils/decoding.py   | 10 ++++--
 2 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/src/neuronx_distributed_inference/models/whisper/modeling_whisper.py b/src/neuronx_distributed_inference/models/whisper/modeling_whisper.py
index 4dc6024a..6167762d 100644
--- a/src/neuronx_distributed_inference/models/whisper/modeling_whisper.py
+++ b/src/neuronx_distributed_inference/models/whisper/modeling_whisper.py
@@ -219,15 +219,16 @@ def forward(
         is_prefill: bool = True,
     ):
         bsz, seq_len, hidden_dim = x.shape
-        kv_seq_len = xa.shape[1]
 
-        # bs, head, seqlen, head_dim
+        # Q projection (always needed for both prefill and decode)
         q = self.query(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
-        k = self.key(xa).view(bsz, kv_seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
-        v = self.value(xa).view(bsz, kv_seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
 
-        # Save KV Cache
         if is_prefill:
+            # Prefill: compute K/V from encoder output and populate cache
+            kv_seq_len = xa.shape[1]
+            k = self.key(xa).view(bsz, kv_seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
+            v = self.value(xa).view(bsz, kv_seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
+
             indices = torch.arange(start=0, end=kv_seq_len, dtype=torch.int64, device=q.device)
             indices = indices.view(1, 1, kv_seq_len, 1)
             indices = indices.expand(bsz, self.n_kv_heads, kv_seq_len, self.head_dim)
@@ -235,6 +236,7 @@ def forward(
             updated_kcache = torch.scatter(self.cache_k, 2, indices, k)
             updated_vcache = torch.scatter(self.cache_v, 2, indices, v)
         else:
+            # Decode: use cached K/V directly (no K/V projection needed, xa is unused)
             updated_kcache = self.cache_k
             updated_vcache = self.cache_v
 
@@ -546,10 +548,13 @@ def __init__(self, config, model_cls, tag="", compiler_args=None, priority_model
         self.bucket_config = None  # Set to None if no bucketing needed
 
     def input_generator(self) -> List[Tuple[torch.Tensor]]:
-        # Generate example inputs for tracing
+        # Generate example inputs for tracing.
+        # Use minimal dummy xa (1 token instead of n_audio_ctx) since decode reads
+        # cross-attention K/V from cache, not from xa. The xa tensor must be present
+        # for forward signature compatibility but is unused in the decode graph.
         audio_embed = torch.randn(
             self.neuron_config.batch_size,
-            self.config.dims.n_audio_ctx,
+            1,
             self.config.dims.n_audio_state,
             dtype=self.neuron_config.torch_dtype,
         )
@@ -700,9 +705,17 @@ def load(self, compiled_model_path, *args, **kwargs):
     def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
         tokens = tokens.to(torch.int32)
         padded_tokens, last_pos, pad_mask = self._prepare_decoder_inputs(tokens)
-        return self.decoder(
-            padded_tokens, audio_features.to(self.config.neuron_config.torch_dtype), last_pos, pad_mask
-        )[:, : last_pos + 1]
+        is_prefill = padded_tokens.shape[1] > 1
+        if is_prefill:
+            xa = audio_features.to(self.config.neuron_config.torch_dtype)
+        else:
+            # Decode: pass minimal dummy xa since cross-attention K/V caches
+            # were populated during prefill. xa is unused in the decode graph.
+            xa = torch.zeros(
+                audio_features.shape[0], 1, audio_features.shape[2],
+                dtype=self.config.neuron_config.torch_dtype,
+            )
+        return self.decoder(padded_tokens, xa, last_pos, pad_mask)[:, : last_pos + 1]
 
     def _prepare_decoder_inputs(self, tokens: torch.Tensor):
         pad_token = -1
diff --git a/src/neuronx_distributed_inference/models/whisper/utils/decoding.py b/src/neuronx_distributed_inference/models/whisper/utils/decoding.py
index fc740b1d..5b1594d1 100644
--- a/src/neuronx_distributed_inference/models/whisper/utils/decoding.py
+++ b/src/neuronx_distributed_inference/models/whisper/utils/decoding.py
@@ -22,10 +22,16 @@ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
 
         padded_tokens, last_pos, pad_mask = self.model._prepare_decoder_inputs(tokens)
         if tokens.shape[-1] > self.initial_token_length:
-            # only need to use the last token except in the first forward pass
+            # Decode: only need the last token, pass dummy xa since
+            # cross-attention K/V caches were populated during prefill
             tokens = tokens[:, -1:]
-            return self.model.decoder(tokens, audio_features, last_pos, pad_mask)
+            dummy_audio = torch.zeros(
+                audio_features.shape[0], 1, audio_features.shape[2],
+                dtype=audio_features.dtype,
+            )
+            return self.model.decoder(tokens, dummy_audio, last_pos, pad_mask)
         else:
+            # Prefill: pass full audio features
             tokens = padded_tokens
             return self.model.decoder(tokens, audio_features, last_pos, pad_mask)[:, : last_pos + 1]
 