93 changes: 56 additions & 37 deletions megatron/core/inference/contexts/dynamic_context.py
@@ -889,6 +889,14 @@ def initialize_all_tensors(self) -> None:
self.active_request_metadata = {
label: torch.empty_like(tensor) for label, tensor in self.request_metadata.items()
}
# Static tensor addresses to make `last_token_logits` graphable with speculative decoding.
max_logit_idxs = self.max_requests * (self.num_speculative_tokens + 1)
self.active_logit_idxs = torch.zeros(
max_logit_idxs, dtype=torch.int32, device=torch.cuda.current_device()
)
self._decode_logit_idxs = torch.arange(
max_logit_idxs, dtype=torch.int32, device=torch.cuda.current_device()
)
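
A brief sizing sketch with made-up values (max_requests and num_speculative_tokens below are illustrative, not read from any config): the buffer covers the worst case in which every active request is a decode request carrying its full set of speculative positions, so any decode/prefill mix fits.

max_requests, num_speculative_tokens = 8, 2
max_logit_idxs = max_requests * (num_speculative_tokens + 1)        # 24 index slots
num_decode, num_prefill = 5, 3                                      # any split with num_decode + num_prefill <= max_requests
needed = num_decode * (num_speculative_tokens + 1) + num_prefill    # 18 <= 24
assert needed <= max_logit_idxs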

# NOTE: Need to build this outside the UVM / TMS context to avoid IMA.
if self.is_hybrid_model:
@@ -1107,7 +1115,30 @@ def build_active_slices(self, batch_size: int):

def pad_active_slices(self):
"""Pad the active slices of specific tensors."""
pass
active_request_count = self.total_request_count - self.paused_request_count
active_decode_count = self.num_decode_requests
active_prefill_count = active_request_count - active_decode_count
active_decode_token_count = active_decode_count * (self.num_speculative_tokens + 1)

# Decode prefix: positions [0, 1, ..., active_decode_token_count - 1].
self.active_logit_idxs[:active_decode_token_count].copy_(
self._decode_logit_idxs[:active_decode_token_count]
)

# Prefill last-token positions: cumsum the prefill query lengths directly into
# the destination slice (a view of `active_logit_idxs`), then shift by
# (active_decode_token_count - 1) to get absolute positions in the packed sequence.
prefill_dst = self.active_logit_idxs[
active_decode_token_count : active_decode_token_count + active_prefill_count
]
prefill_start = self.paused_request_count + active_decode_count
torch.cumsum(
self.request_query_lengths[prefill_start : self.total_request_count],
dim=0,
out=prefill_dst,
)
prefill_dst.add_(active_decode_token_count - 1)

self.active_logit_idxs[active_decode_token_count + active_prefill_count :].zero_()
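
A quick sanity check on the layout above, as a minimal sketch with made-up counts (2 decode requests, num_speculative_tokens=1, prefill query lengths 3 and 5; none of these values come from the diff):

import itertools

num_decode, num_spec = 2, 1
prefill_query_lengths = [3, 5]

decode_token_count = num_decode * (num_spec + 1)            # positions 0..3
prefill_last = [decode_token_count - 1 + c                  # cumsum shifted by 3
                for c in itertools.accumulate(prefill_query_lengths)]
print(list(range(decode_token_count)) + prefill_last)       # [0, 1, 2, 3, 6, 11]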

def append_key_value_cache(self, layer_number: int, key: Tensor, value: Tensor) -> None:
"""Append to KV cache.
@@ -1731,7 +1762,9 @@ def initialize_attention_state(
self.padded_active_request_count = self.padded_batch_dimensions.req_count
self.padding_slice = slice(self.active_token_count, self.padded_active_token_count)

self.build_active_slices(self.padded_active_request_count)
self.build_active_slices(
min(self.padded_active_request_count, self.max_requests - self.paused_request_count)
)
self.pad_active_slices()

# Update token position indexes.
@@ -1923,31 +1956,18 @@ def current_input_and_position_ids(
self.token_to_pos_ids[:num_tokens].unsqueeze(0),
)

def speculative_required_logit_indices(self, device: torch.device) -> Tensor:
def speculative_required_logit_indices(self) -> Tensor:
"""Token-level indices needed for speculative decode verification.

Returns all decode token positions (base + speculative) concatenated
with the last token position of each prefill request.

Args:
device (torch.device): Device on which to create the index tensor.

Return:
(Tensor) 1-D indices into the packed token sequence, length
``num_decode_requests * (num_speculative_tokens + 1) + num_prefill_requests``.
``num_decode_requests * (num_speculative_tokens + 1) + num_prefill_requests``
in eager mode, or the equivalent padded count when stepping under CUDA graphs.
"""
paused = self.paused_request_count
total = self.total_request_count
query_lengths = self.request_query_lengths[paused:total]
num_decode = self.num_decode_requests

decode_token_count = num_decode * (self.num_speculative_tokens + 1)
decode_indices = torch.arange(decode_token_count, device=device)

cumsum = torch.cumsum(query_lengths, dim=0)
prefill_last_indices = cumsum[num_decode:] - 1

return torch.cat([decode_indices, prefill_last_indices])
return self.active_logit_idxs[: self.num_last_token_logits]
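
Returning a slice keeps the tensor's storage address fixed across steps, which is what CUDA graph replay requires. A minimal sketch of the view behaviour this relies on, using a stand-in buffer rather than the real context:

import torch

buf = torch.zeros(8, dtype=torch.int32)      # stands in for active_logit_idxs
view = buf[:5]                               # a view: same storage, no new allocation
assert view.data_ptr() == buf.data_ptr()
buf[:5] = torch.arange(5, dtype=torch.int32)
print(view.tolist())                         # [0, 1, 2, 3, 4]; the view sees the update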

@property
def num_last_token_logits(self) -> int:
@@ -1957,11 +1977,22 @@ def num_last_token_logits(self) -> int:
`(num_speculative_tokens + 1)` rows per decode request when MTP is active.
"""
if self.num_speculative_tokens > 0:
return (
self.num_decode_requests * (self.num_speculative_tokens + 1)
+ self.num_prefill_requests
)
return self.total_request_count - self.paused_request_count
if self._using_cuda_graph_this_step:
return (
self.padded_batch_dimensions.decode_req_count
* (self.num_speculative_tokens + 1)
+ self.padded_batch_dimensions.prefill_req_count
)
else:
return (
self.num_decode_requests * (self.num_speculative_tokens + 1)
+ self.num_prefill_requests
)
else:
if self._using_cuda_graph_this_step:
return self.padded_active_request_count
else:
return self.total_request_count - self.paused_request_count
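
A quick arithmetic sketch of the speculative-path counts with invented values (nothing here is read from a real run): a graphed step uses the padded request counts, an eager step uses the actual ones.

num_spec = 2
graphed = 4 * (num_spec + 1) + 2    # e.g. 4 padded decode slots + 2 padded prefill slots -> 14 rows
eager = 3 * (num_spec + 1) + 1      # e.g. 3 real decode requests + 1 real prefill request -> 10 rows
print(graphed, eager)               # 14 10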

def last_token_logits(self, logits: Tensor) -> Tensor:
"""Select the logit positions needed for token generation.
@@ -1984,19 +2015,7 @@ def last_token_logits(self, logits: Tensor) -> Tensor:
f"logits.size(1) ({tuple(logits.shape)}) != "
f"padded_active_token_count ({self.padded_active_token_count})."
)
logits_2d = logits.squeeze(0)

if self.num_speculative_tokens > 0:
selected = self.speculative_required_logit_indices(logits.device)
assert selected.numel() == self.num_last_token_logits
return logits_2d[selected, :]

paused = self.paused_request_count
total = self.total_request_count
query_lengths = self.request_query_lengths[paused:total]
last_token_idxs = torch.cumsum(query_lengths, dim=0) - 1
assert last_token_idxs.numel() == self.num_last_token_logits
return logits_2d[last_token_idxs, :]
return logits.squeeze(0)[self.active_logit_idxs[: self.num_last_token_logits], :]
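
A minimal selection sketch under assumed shapes (the index values echo the layout example above; nothing here reads the real context): indexing the packed [1, T, V] logits with the first num_last_token_logits entries of the buffer yields one row per required position.

import torch

padded_token_count, vocab = 12, 4
logits = torch.randn(1, padded_token_count, vocab)
active_logit_idxs = torch.tensor([0, 1, 2, 3, 6, 11, 0, 0], dtype=torch.int32)
num_last_token_logits = 6

selected = logits.squeeze(0)[active_logit_idxs[:num_last_token_logits].long(), :]
print(selected.shape)    # torch.Size([6, 4])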

def _compute_prefix_match(
self, req: DynamicInferenceRequest, prefill_chunk_length: int
@@ -1063,7 +1063,7 @@ def _dynamic_step_sample_logits_and_verify_tokens(self, input_ids: Tensor):
nvtx_range_push("mtp-spec-decoding/verify/logit-indices")
# Use pre-allocated buffer for CUDA graph compatibility.
logits = self._all_logits_cuda
required_logit_indices = context.speculative_required_logit_indices(logits.device)
required_logit_indices = context.speculative_required_logit_indices()

if context.config.materialize_only_last_token_logits:
# last_token_logits already selected exactly the required positions.
240 changes: 240 additions & 0 deletions tests/unit_tests/inference/contexts/test_dynamic_context.py
@@ -2970,3 +2970,243 @@ def test_chunked_prefill_meets_prefix_caching(self):
# Verify block references updated appropriately
assert ctx.kv_block_allocator.block_ref_counts[req1_blocks[2]].item() == 2
assert ctx.kv_block_allocator.block_ref_counts[req1_blocks[3]].item() == 2

# ------------------------------------------------------------------ #
# Tests for active_logit_idxs / last_token_logits / pad_active_slices
# ------------------------------------------------------------------ #

def _build_speculative_ctx(self, num_speculative_tokens=2, block_size=256):
"""Build a context configured for speculative decoding."""
model_config = TransformerConfig(
params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2
)
inference_config = InferenceConfig(
max_sequence_length=512,
buffer_size_gb=0.05,
block_size_tokens=block_size,
num_speculative_tokens=num_speculative_tokens,
unified_memory_level=0,
)
return DynamicInferenceContext(model_config=model_config, inference_config=inference_config)

def _add_and_step_decode_requests(self, ctx, num_requests, prompt_length=10):
"""Add prefill requests, then step them into decode state with speculative tokens.

Returns the context in a state with ``num_requests`` decode requests whose
query_lengths equal ``num_speculative_tokens + 1``.
"""
for i in range(num_requests):
req = DynamicInferenceRequest(
request_id=i,
prompt_tokens=torch.arange(0, prompt_length, device='cuda'),
sampling_params=SamplingParams(num_tokens_to_generate=100),
)
ctx.add_request(req)

ctx.initialize_attention_state()

active_mask = torch.ones(num_requests, device='cuda', dtype=torch.int32)
new_tokens = torch.arange(num_requests, device='cuda')
num_spec = ctx.num_speculative_tokens
new_spec = torch.arange(num_spec * num_requests, device='cuda').reshape(
num_spec, num_requests
)
ctx.update_requests(
active_requests_mask=active_mask, new_tokens=new_tokens, new_speculative_tokens=new_spec
)
return ctx

@pytest.mark.internal
@rounder_override(64)
def test_pad_active_slices_speculative_decode_only(self):
"""Verify active_logit_idxs for a decode-only batch with speculative tokens."""
num_decode = 3
num_spec = 2
ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec)
self._add_and_step_decode_requests(ctx, num_decode)

assert ctx.num_prefill_requests == 0
assert ctx.num_decode_requests == num_decode
tokens_per_decode = num_spec + 1

ctx.initialize_attention_state()

decode_token_count = num_decode * tokens_per_decode
expected_decode = torch.arange(decode_token_count, dtype=torch.int32, device='cuda')
actual = ctx.active_logit_idxs[:decode_token_count]
assert torch.equal(
actual, expected_decode
), f"decode indices mismatch: {actual.tolist()} vs {expected_decode.tolist()}"

assert ctx.num_last_token_logits == decode_token_count
assert ctx.active_logit_idxs[decode_token_count:].sum().item() == 0

@pytest.mark.internal
@rounder_override(64)
def test_pad_active_slices_speculative_mixed_batch(self):
"""Verify active_logit_idxs for a mixed decode+prefill batch with speculative tokens."""
num_decode = 2
num_spec = 2
ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec)
self._add_and_step_decode_requests(ctx, num_decode)

prefill_lengths = [15, 20]
for i, pl in enumerate(prefill_lengths):
req = DynamicInferenceRequest(
request_id=100 + i,
prompt_tokens=torch.arange(0, pl, device='cuda'),
sampling_params=SamplingParams(num_tokens_to_generate=50),
)
ctx.add_request(req)

assert ctx.num_decode_requests == num_decode
assert ctx.num_prefill_requests == len(prefill_lengths)
tokens_per_decode = num_spec + 1

ctx.initialize_attention_state()

decode_token_count = num_decode * tokens_per_decode
expected_decode = torch.arange(decode_token_count, dtype=torch.int32, device='cuda')
actual_decode = ctx.active_logit_idxs[:decode_token_count]
assert torch.equal(actual_decode, expected_decode)

cumulative = 0
for i, pl in enumerate(prefill_lengths):
cumulative += pl
expected_prefill_idx = decode_token_count + cumulative - 1
actual_prefill_idx = ctx.active_logit_idxs[decode_token_count + i].item()
assert (
actual_prefill_idx == expected_prefill_idx
), f"prefill request {i}: expected idx {expected_prefill_idx}, got {actual_prefill_idx}"

expected_num_logits = decode_token_count + len(prefill_lengths)
assert ctx.num_last_token_logits == expected_num_logits

@pytest.mark.internal
@rounder_override(64)
def test_pad_active_slices_speculative_all_prefill(self):
"""Verify active_logit_idxs with only prefill requests (no decode) and speculative tokens."""
num_spec = 2
ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec)

prefill_lengths = [12, 8, 25]
for i, pl in enumerate(prefill_lengths):
req = DynamicInferenceRequest(
request_id=i,
prompt_tokens=torch.arange(0, pl, device='cuda'),
sampling_params=SamplingParams(num_tokens_to_generate=50),
)
ctx.add_request(req)

assert ctx.num_decode_requests == 0
assert ctx.num_prefill_requests == len(prefill_lengths)

ctx.initialize_attention_state()

cumulative = 0
for i, pl in enumerate(prefill_lengths):
cumulative += pl
expected_idx = cumulative - 1
actual_idx = ctx.active_logit_idxs[i].item()
assert (
actual_idx == expected_idx
), f"prefill request {i}: expected idx {expected_idx}, got {actual_idx}"

expected_num_logits = len(prefill_lengths)
assert ctx.num_last_token_logits == expected_num_logits

@pytest.mark.internal
@rounder_override(64)
def test_pad_active_slices_no_speculative_tokens(self):
"""Verify active_logit_idxs without speculative tokens matches cumsum - 1."""
ctx = self._build_speculative_ctx(num_speculative_tokens=0)

req0 = DynamicInferenceRequest(
request_id=0,
prompt_tokens=torch.arange(0, 10, device='cuda'),
sampling_params=SamplingParams(num_tokens_to_generate=50),
)
ctx.add_request(req0)
ctx.initialize_attention_state()
active_mask = torch.ones(1, device='cuda', dtype=torch.int32)
new_tokens = torch.tensor([42], device='cuda')
ctx.update_requests(active_requests_mask=active_mask, new_tokens=new_tokens)

prefill_lengths = [20, 30]
for i, pl in enumerate(prefill_lengths):
req = DynamicInferenceRequest(
request_id=10 + i,
prompt_tokens=torch.arange(0, pl, device='cuda'),
sampling_params=SamplingParams(num_tokens_to_generate=50),
)
ctx.add_request(req)

assert ctx.num_decode_requests == 1
assert ctx.num_prefill_requests == 2

ctx.initialize_attention_state()

all_query_lengths = ctx.request_query_lengths[
ctx.paused_request_count : ctx.total_request_count
]
expected_idxs = torch.cumsum(all_query_lengths, dim=0) - 1
num_logits = ctx.num_last_token_logits
actual_idxs = ctx.active_logit_idxs[:num_logits]
assert torch.equal(
actual_idxs, expected_idxs.to(torch.int32)
), f"non-speculative mismatch: {actual_idxs.tolist()} vs {expected_idxs.tolist()}"

@pytest.mark.internal
@rounder_override(64)
def test_last_token_logits_selects_correct_values_speculative(self):
"""Verify last_token_logits returns logits at the correct token positions."""
num_decode = 2
num_spec = 2
ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec)
self._add_and_step_decode_requests(ctx, num_decode)

prefill_lengths = [10, 15]
for i, pl in enumerate(prefill_lengths):
req = DynamicInferenceRequest(
request_id=100 + i,
prompt_tokens=torch.arange(0, pl, device='cuda'),
sampling_params=SamplingParams(num_tokens_to_generate=50),
)
ctx.add_request(req)

ctx.initialize_attention_state()

vocab_size = 32
logits = torch.arange(
ctx.padded_active_token_count * vocab_size, dtype=torch.float32, device='cuda'
).reshape(1, ctx.padded_active_token_count, vocab_size)

result = ctx.last_token_logits(logits)
expected_num_logits = ctx.num_last_token_logits
assert result.shape == (expected_num_logits, vocab_size)

idxs = ctx.active_logit_idxs[:expected_num_logits].long()
expected = logits.squeeze(0)[idxs, :]
assert torch.equal(result, expected)

@pytest.mark.internal
@rounder_override(64)
def test_speculative_required_logit_indices_matches_active_logit_idxs(self):
"""speculative_required_logit_indices returns a slice of active_logit_idxs."""
num_decode = 2
num_spec = 2
ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec)
self._add_and_step_decode_requests(ctx, num_decode)

req = DynamicInferenceRequest(
request_id=100,
prompt_tokens=torch.arange(0, 20, device='cuda'),
sampling_params=SamplingParams(num_tokens_to_generate=50),
)
ctx.add_request(req)
ctx.initialize_attention_state()

indices = ctx.speculative_required_logit_indices()
expected_len = ctx.num_last_token_logits
assert indices.numel() == expected_len
assert indices.data_ptr() == ctx.active_logit_idxs.data_ptr()