Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -219,22 +219,24 @@ def forward(
is_prefill: bool = True,
):
bsz, seq_len, hidden_dim = x.shape
kv_seq_len = xa.shape[1]

# bs, head, seqlen, head_dim
# Q projection (always needed for both prefill and decode)
q = self.query(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
k = self.key(xa).view(bsz, kv_seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
v = self.value(xa).view(bsz, kv_seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

# Save KV Cache
if is_prefill:
# Prefill: compute K/V from encoder output and populate cache
kv_seq_len = xa.shape[1]
k = self.key(xa).view(bsz, kv_seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
v = self.value(xa).view(bsz, kv_seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

indices = torch.arange(start=0, end=kv_seq_len, dtype=torch.int64, device=q.device)
indices = indices.view(1, 1, kv_seq_len, 1)
indices = indices.expand(bsz, self.n_kv_heads, kv_seq_len, self.head_dim)

updated_kcache = torch.scatter(self.cache_k, 2, indices, k)
updated_vcache = torch.scatter(self.cache_v, 2, indices, v)
else:
# Decode: use cached K/V directly (no K/V projection needed, xa is unused)
updated_kcache = self.cache_k
updated_vcache = self.cache_v

Expand Down Expand Up @@ -546,10 +548,13 @@ def __init__(self, config, model_cls, tag="", compiler_args=None, priority_model
self.bucket_config = None # Set to None if no bucketing needed

def input_generator(self) -> List[Tuple[torch.Tensor]]:
# Generate example inputs for tracing
# Generate example inputs for tracing.
# Use minimal dummy xa (1 token instead of n_audio_ctx) since decode reads
# cross-attention K/V from cache, not from xa. The xa tensor must be present
# for forward signature compatibility but is unused in the decode graph.
audio_embed = torch.randn(
self.neuron_config.batch_size,
self.config.dims.n_audio_ctx,
1,
self.config.dims.n_audio_state,
dtype=self.neuron_config.torch_dtype,
)
Expand Down Expand Up @@ -700,9 +705,17 @@ def load(self, compiled_model_path, *args, **kwargs):
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
    """Run the decoder and return logits for the valid token positions.

    Reconstructed from the diff hunk: the span interleaved the removed
    single-call return (old code) with the added prefill/decode branch;
    only the post-change implementation is kept here.

    Args:
        tokens: (batch, seq) token ids; cast to int32 before padding.
        audio_features: encoder output — presumably
            (batch, n_audio_ctx, n_audio_state); TODO confirm. Only its
            values are consumed on the prefill pass; on decode only its
            batch and feature dims are read to shape the dummy tensor.

    Returns:
        Decoder output sliced to positions [0, last_pos].
    """
    tokens = tokens.to(torch.int32)
    padded_tokens, last_pos, pad_mask = self._prepare_decoder_inputs(tokens)
    # More than one padded token means this is the first (prefill) pass.
    is_prefill = padded_tokens.shape[1] > 1
    if is_prefill:
        xa = audio_features.to(self.config.neuron_config.torch_dtype)
    else:
        # Decode: pass minimal dummy xa since cross-attention K/V caches
        # were populated during prefill. xa is unused in the decode graph.
        xa = torch.zeros(
            audio_features.shape[0], 1, audio_features.shape[2],
            dtype=self.config.neuron_config.torch_dtype,
        )
    return self.decoder(padded_tokens, xa, last_pos, pad_mask)[:, : last_pos + 1]

def _prepare_decoder_inputs(self, tokens: torch.Tensor):
pad_token = -1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,16 @@ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
padded_tokens, last_pos, pad_mask = self.model._prepare_decoder_inputs(tokens)

if tokens.shape[-1] > self.initial_token_length:
# only need to use the last token except in the first forward pass
# Decode: only need the last token, pass dummy xa since
# cross-attention K/V caches were populated during prefill
tokens = tokens[:, -1:]
return self.model.decoder(tokens, audio_features, last_pos, pad_mask)
dummy_audio = torch.zeros(
audio_features.shape[0], 1, audio_features.shape[2],
dtype=audio_features.dtype,
)
return self.model.decoder(tokens, dummy_audio, last_pos, pad_mask)
else:
# Prefill: pass full audio features
tokens = padded_tokens
return self.model.decoder(tokens, audio_features, last_pos, pad_mask)[:, : last_pos + 1]

Expand Down