diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py
index c6f69df2b20..9765758eea7 100644
--- a/megatron/core/inference/contexts/dynamic_context.py
+++ b/megatron/core/inference/contexts/dynamic_context.py
@@ -889,6 +889,14 @@ def initialize_all_tensors(self) -> None:
         self.active_request_metadata = {
             label: torch.empty_like(tensor) for label, tensor in self.request_metadata.items()
         }
+        # Static tensor addresses to make `last_token_logits` graphable with speculative decoding.
+        max_logit_idxs = self.max_requests * (self.num_speculative_tokens + 1)
+        self.active_logit_idxs = torch.zeros(
+            max_logit_idxs, dtype=torch.int32, device=torch.cuda.current_device()
+        )
+        self._decode_logit_idxs = torch.arange(
+            max_logit_idxs, dtype=torch.int32, device=torch.cuda.current_device()
+        )
 
         # NOTE: Need to build this outside the UVM / TMS context to avoid IMA.
         if self.is_hybrid_model:
@@ -1107,7 +1115,30 @@ def build_active_slices(self, batch_size: int):
 
     def pad_active_slices(self):
         """Pad the active slices of specific tensors."""
-        pass
+        active_request_count = self.total_request_count - self.paused_request_count
+        active_decode_count = self.num_decode_requests
+        active_prefill_count = active_request_count - active_decode_count
+        active_decode_token_count = active_decode_count * (self.num_speculative_tokens + 1)
+
+        # Decode prefix: positions [0, 1, ..., active_decode_token_count - 1].
+        self.active_logit_idxs[:active_decode_token_count].copy_(
+            self._decode_logit_idxs[:active_decode_token_count]
+        )
+
+        # Prefill last-token positions: cumsum the prefill query lengths in place,
+        # then shift by (active_decode_token_count - 1) to get absolute positions.
+        prefill_dst = self.active_logit_idxs[
+            active_decode_token_count : active_decode_token_count + active_prefill_count
+        ]
+        prefill_idxs = self.paused_request_count + active_decode_count
+        torch.cumsum(
+            self.request_query_lengths[prefill_idxs : self.total_request_count],
+            dim=0,
+            out=prefill_dst,
+        )
+        prefill_dst.add_(active_decode_token_count - 1)
+
+        self.active_logit_idxs[active_decode_token_count + active_prefill_count :].zero_()
 
     def append_key_value_cache(self, layer_number: int, key: Tensor, value: Tensor) -> None:
         """Append to KV cache.
@@ -1731,7 +1762,9 @@ def initialize_attention_state(
         self.padded_active_request_count = self.padded_batch_dimensions.req_count
         self.padding_slice = slice(self.active_token_count, self.padded_active_token_count)
 
-        self.build_active_slices(self.padded_active_request_count)
+        self.build_active_slices(
+            min(self.padded_active_request_count, self.max_requests - self.paused_request_count)
+        )
         self.pad_active_slices()
 
         # Update token position indexes.
@@ -1923,31 +1956,18 @@ def current_input_and_position_ids(
             self.token_to_pos_ids[:num_tokens].unsqueeze(0),
         )
 
-    def speculative_required_logit_indices(self, device: torch.device) -> Tensor:
+    def speculative_required_logit_indices(self) -> Tensor:
         """Token-level indices needed for speculative decode verification.
 
         Returns all decode token positions (base + speculative) concatenated with
        the last token position of each prefill request.
 
-        Args:
-            device (torch.device): Device on which to create the index tensor.
-
         Return:
             (Tensor) 1-D indices into the packed token sequence, length
-            ``num_decode_requests * (num_speculative_tokens + 1) + num_prefill_requests``.
+            ``num_decode_requests * (num_speculative_tokens + 1) + num_prefill_requests``
+            in eager mode, or the equivalent padded count when running under CUDA graphs.
         """
-        paused = self.paused_request_count
-        total = self.total_request_count
-        query_lengths = self.request_query_lengths[paused:total]
-        num_decode = self.num_decode_requests
-
-        decode_token_count = num_decode * (self.num_speculative_tokens + 1)
-        decode_indices = torch.arange(decode_token_count, device=device)
-
-        cumsum = torch.cumsum(query_lengths, dim=0)
-        prefill_last_indices = cumsum[num_decode:] - 1
-
-        return torch.cat([decode_indices, prefill_last_indices])
+        return self.active_logit_idxs[: self.num_last_token_logits]
 
     @property
     def num_last_token_logits(self) -> int:
@@ -1957,11 +1977,22 @@ def num_last_token_logits(self) -> int:
         `(num_speculative_tokens + 1)` rows per decode request when MTP is active.
         """
         if self.num_speculative_tokens > 0:
-            return (
-                self.num_decode_requests * (self.num_speculative_tokens + 1)
-                + self.num_prefill_requests
-            )
-        return self.total_request_count - self.paused_request_count
+            if self._using_cuda_graph_this_step:
+                return (
+                    self.padded_batch_dimensions.decode_req_count
+                    * (self.num_speculative_tokens + 1)
+                    + self.padded_batch_dimensions.prefill_req_count
+                )
+            else:
+                return (
+                    self.num_decode_requests * (self.num_speculative_tokens + 1)
+                    + self.num_prefill_requests
+                )
+        else:
+            if self._using_cuda_graph_this_step:
+                return self.padded_active_request_count
+            else:
+                return self.total_request_count - self.paused_request_count
 
     def last_token_logits(self, logits: Tensor) -> Tensor:
         """Select the logit positions needed for token generation.
@@ -1984,19 +2015,7 @@ def last_token_logits(self, logits: Tensor) -> Tensor:
                 f"logits.size(1) ({tuple(logits.shape)}) != "
                 f"padded_active_token_count ({self.padded_active_token_count})."
             )
-        logits_2d = logits.squeeze(0)
-
-        if self.num_speculative_tokens > 0:
-            selected = self.speculative_required_logit_indices(logits.device)
-            assert selected.numel() == self.num_last_token_logits
-            return logits_2d[selected, :]
-
-        paused = self.paused_request_count
-        total = self.total_request_count
-        query_lengths = self.request_query_lengths[paused:total]
-        last_token_idxs = torch.cumsum(query_lengths, dim=0) - 1
-        assert last_token_idxs.numel() == self.num_last_token_logits
-        return logits_2d[last_token_idxs, :]
+        return logits.squeeze(0)[self.active_logit_idxs[: self.num_last_token_logits], :]
 
     def _compute_prefix_match(
         self, req: DynamicInferenceRequest, prefill_chunk_length: int
diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
index 0bdc5853aaf..d665e17dc1a 100644
--- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py
+++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
@@ -1063,7 +1063,7 @@ def _dynamic_step_sample_logits_and_verify_tokens(self, input_ids: Tensor):
         nvtx_range_push("mtp-spec-decoding/verify/logit-indices")
         # Use pre-allocated buffer for CUDA graph compatibility.
         logits = self._all_logits_cuda
-        required_logit_indices = context.speculative_required_logit_indices(logits.device)
+        required_logit_indices = context.speculative_required_logit_indices()
 
         if context.config.materialize_only_last_token_logits:
             # last_token_logits already selected exactly the required positions.
diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py
index 721e69212e3..b20686685cd 100644
--- a/tests/unit_tests/inference/contexts/test_dynamic_context.py
+++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py
@@ -2970,3 +2970,243 @@ def test_chunked_prefill_meets_prefix_caching(self):
         # Verify block references updated appropriately
         assert ctx.kv_block_allocator.block_ref_counts[req1_blocks[2]].item() == 2
         assert ctx.kv_block_allocator.block_ref_counts[req1_blocks[3]].item() == 2
+
+    # ------------------------------------------------------------------ #
+    # Tests for active_logit_idxs / last_token_logits / pad_active_slices
+    # ------------------------------------------------------------------ #
+
+    def _build_speculative_ctx(self, num_speculative_tokens=2, block_size=256):
+        """Build a context configured for speculative decoding."""
+        model_config = TransformerConfig(
+            params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2
+        )
+        inference_config = InferenceConfig(
+            max_sequence_length=512,
+            buffer_size_gb=0.05,
+            block_size_tokens=block_size,
+            num_speculative_tokens=num_speculative_tokens,
+            unified_memory_level=0,
+        )
+        return DynamicInferenceContext(model_config=model_config, inference_config=inference_config)
+
+    def _add_and_step_decode_requests(self, ctx, num_requests, prompt_length=10):
+        """Add prefill requests, then step them into decode state with speculative tokens.
+
+        Returns the context in a state with ``num_requests`` decode requests whose
+        query_lengths equal ``num_speculative_tokens + 1``.
+        """
+        for i in range(num_requests):
+            req = DynamicInferenceRequest(
+                request_id=i,
+                prompt_tokens=torch.arange(0, prompt_length, device='cuda'),
+                sampling_params=SamplingParams(num_tokens_to_generate=100),
+            )
+            ctx.add_request(req)
+
+        ctx.initialize_attention_state()
+
+        active_mask = torch.ones(num_requests, device='cuda', dtype=torch.int32)
+        new_tokens = torch.arange(num_requests, device='cuda')
+        num_spec = ctx.num_speculative_tokens
+        new_spec = torch.arange(num_spec * num_requests, device='cuda').reshape(
+            num_spec, num_requests
+        )
+        ctx.update_requests(
+            active_requests_mask=active_mask, new_tokens=new_tokens, new_speculative_tokens=new_spec
+        )
+        return ctx
+
+    @pytest.mark.internal
+    @rounder_override(64)
+    def test_pad_active_slices_speculative_decode_only(self):
+        """Verify active_logit_idxs for a decode-only batch with speculative tokens."""
+        num_decode = 3
+        num_spec = 2
+        ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec)
+        self._add_and_step_decode_requests(ctx, num_decode)
+
+        assert ctx.num_prefill_requests == 0
+        assert ctx.num_decode_requests == num_decode
+        tokens_per_decode = num_spec + 1
+
+        ctx.initialize_attention_state()
+
+        decode_token_count = num_decode * tokens_per_decode
+        expected_decode = torch.arange(decode_token_count, dtype=torch.int32, device='cuda')
+        actual = ctx.active_logit_idxs[:decode_token_count]
+        assert torch.equal(
+            actual, expected_decode
+        ), f"decode indices mismatch: {actual.tolist()} vs {expected_decode.tolist()}"
+
+        assert ctx.num_last_token_logits == decode_token_count
+        assert ctx.active_logit_idxs[decode_token_count:].sum().item() == 0
+
+    @pytest.mark.internal
+    @rounder_override(64)
+    def test_pad_active_slices_speculative_mixed_batch(self):
+        """Verify active_logit_idxs for a mixed decode+prefill batch with speculative tokens."""
+        num_decode = 2
+        num_spec = 2
+        ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec)
+        self._add_and_step_decode_requests(ctx, num_decode)
+
+        prefill_lengths = [15, 20]
+        for i, pl in enumerate(prefill_lengths):
+            req = DynamicInferenceRequest(
+                request_id=100 + i,
+                prompt_tokens=torch.arange(0, pl, device='cuda'),
+                sampling_params=SamplingParams(num_tokens_to_generate=50),
+            )
+            ctx.add_request(req)
+
+        assert ctx.num_decode_requests == num_decode
+        assert ctx.num_prefill_requests == len(prefill_lengths)
+        tokens_per_decode = num_spec + 1
+
+        ctx.initialize_attention_state()
+
+        decode_token_count = num_decode * tokens_per_decode
+        expected_decode = torch.arange(decode_token_count, dtype=torch.int32, device='cuda')
+        actual_decode = ctx.active_logit_idxs[:decode_token_count]
+        assert torch.equal(actual_decode, expected_decode)
+
+        cumulative = 0
+        for i, pl in enumerate(prefill_lengths):
+            cumulative += pl
+            expected_prefill_idx = decode_token_count + cumulative - 1
+            actual_prefill_idx = ctx.active_logit_idxs[decode_token_count + i].item()
+            assert (
+                actual_prefill_idx == expected_prefill_idx
+            ), f"prefill request {i}: expected idx {expected_prefill_idx}, got {actual_prefill_idx}"
+
+        expected_num_logits = decode_token_count + len(prefill_lengths)
+        assert ctx.num_last_token_logits == expected_num_logits
+
+    @pytest.mark.internal
+    @rounder_override(64)
+    def test_pad_active_slices_speculative_all_prefill(self):
+        """Verify active_logit_idxs with only prefill requests (no decode) and speculative tokens."""
+        num_spec = 2
+        ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec)
+
+        prefill_lengths = [12, 8, 25]
+        for i, pl in enumerate(prefill_lengths):
+            req = DynamicInferenceRequest(
+                request_id=i,
+                prompt_tokens=torch.arange(0, pl, device='cuda'),
+                sampling_params=SamplingParams(num_tokens_to_generate=50),
+            )
+            ctx.add_request(req)
+
+        assert ctx.num_decode_requests == 0
+        assert ctx.num_prefill_requests == len(prefill_lengths)
+
+        ctx.initialize_attention_state()
+
+        cumulative = 0
+        for i, pl in enumerate(prefill_lengths):
+            cumulative += pl
+            expected_idx = cumulative - 1
+            actual_idx = ctx.active_logit_idxs[i].item()
+            assert (
+                actual_idx == expected_idx
+            ), f"prefill request {i}: expected idx {expected_idx}, got {actual_idx}"
+
+        expected_num_logits = len(prefill_lengths)
+        assert ctx.num_last_token_logits == expected_num_logits
+
+    @pytest.mark.internal
+    @rounder_override(64)
+    def test_pad_active_slices_no_speculative_tokens(self):
+        """Verify active_logit_idxs without speculative tokens matches cumsum - 1."""
+        ctx = self._build_speculative_ctx(num_speculative_tokens=0)
+
+        req0 = DynamicInferenceRequest(
+            request_id=0,
+            prompt_tokens=torch.arange(0, 10, device='cuda'),
+            sampling_params=SamplingParams(num_tokens_to_generate=50),
+        )
+        ctx.add_request(req0)
+        ctx.initialize_attention_state()
+        active_mask = torch.ones(1, device='cuda', dtype=torch.int32)
+        new_tokens = torch.tensor([42], device='cuda')
+        ctx.update_requests(active_requests_mask=active_mask, new_tokens=new_tokens)
+
+        prefill_lengths = [20, 30]
+        for i, pl in enumerate(prefill_lengths):
+            req = DynamicInferenceRequest(
+                request_id=10 + i,
+                prompt_tokens=torch.arange(0, pl, device='cuda'),
+                sampling_params=SamplingParams(num_tokens_to_generate=50),
+            )
+            ctx.add_request(req)
+
+        assert ctx.num_decode_requests == 1
+        assert ctx.num_prefill_requests == 2
+
+        ctx.initialize_attention_state()
+
+        all_query_lengths = ctx.request_query_lengths[
+            ctx.paused_request_count : ctx.total_request_count
+        ]
+        expected_idxs = torch.cumsum(all_query_lengths, dim=0) - 1
+        num_logits = ctx.num_last_token_logits
+        actual_idxs = ctx.active_logit_idxs[:num_logits]
+        assert torch.equal(
+            actual_idxs, expected_idxs.to(torch.int32)
+        ), f"non-speculative mismatch: {actual_idxs.tolist()} vs {expected_idxs.tolist()}"
+
+    @pytest.mark.internal
+    @rounder_override(64)
+    def test_last_token_logits_selects_correct_values_speculative(self):
+        """Verify last_token_logits returns logits at the correct token positions."""
+        num_decode = 2
+        num_spec = 2
+        ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec)
+        self._add_and_step_decode_requests(ctx, num_decode)
+
+        prefill_lengths = [10, 15]
+        for i, pl in enumerate(prefill_lengths):
+            req = DynamicInferenceRequest(
+                request_id=100 + i,
+                prompt_tokens=torch.arange(0, pl, device='cuda'),
+                sampling_params=SamplingParams(num_tokens_to_generate=50),
+            )
+            ctx.add_request(req)
+
+        ctx.initialize_attention_state()
+
+        vocab_size = 32
+        logits = torch.arange(
+            ctx.padded_active_token_count * vocab_size, dtype=torch.float32, device='cuda'
+        ).reshape(1, ctx.padded_active_token_count, vocab_size)
+
+        result = ctx.last_token_logits(logits)
+        expected_num_logits = ctx.num_last_token_logits
+        assert result.shape == (expected_num_logits, vocab_size)
+
+        idxs = ctx.active_logit_idxs[:expected_num_logits].long()
+        expected = logits.squeeze(0)[idxs, :]
+        assert torch.equal(result, expected)
+
+    @pytest.mark.internal
+    @rounder_override(64)
+    def test_speculative_required_logit_indices_matches_active_logit_idxs(self):
+        """speculative_required_logit_indices returns a slice of active_logit_idxs."""
+        num_decode = 2
+        num_spec = 2
+        ctx = self._build_speculative_ctx(num_speculative_tokens=num_spec)
+        self._add_and_step_decode_requests(ctx, num_decode)
+
+        req = DynamicInferenceRequest(
+            request_id=100,
+            prompt_tokens=torch.arange(0, 20, device='cuda'),
+            sampling_params=SamplingParams(num_tokens_to_generate=50),
+        )
+        ctx.add_request(req)
+        ctx.initialize_attention_state()
+
+        indices = ctx.speculative_required_logit_indices()
+        expected_len = ctx.num_last_token_logits
+        assert indices.numel() == expected_len
+        assert indices.data_ptr() == ctx.active_logit_idxs.data_ptr()
diff --git a/tests/unit_tests/inference/test_mtp_cuda_graph_inference.py b/tests/unit_tests/inference/test_mtp_cuda_graph_inference.py
index ee2e08e5ccb..60ab7f29a9e 100644
--- a/tests/unit_tests/inference/test_mtp_cuda_graph_inference.py
+++ b/tests/unit_tests/inference/test_mtp_cuda_graph_inference.py
@@ -25,9 +25,11 @@
 from megatron.core.inference.config import InferenceConfig
 from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext
 from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine
+from megatron.core.inference.inference_request import DynamicInferenceRequest
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
 )
+from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.inference.text_generation_controllers.text_generation_controller import (
     TextGenerationController,
 )
@@ -667,6 +669,112 @@ def test_delete_cuda_graphs_resets_mtp_runners(self):
         assert all(not r.fwd_graph_recorded for r in manager.cudagraph_runners)
         assert all(r.fwd_graph is None for r in manager.cudagraph_runners)
 
+    # ---- Test 8: last_token_logits under CUDA graph padding ---------------- #
+
+    @torch.inference_mode()
+    def test_last_token_logits_cuda_graph_padding(self):
"""num_last_token_logits returns padded count and last_token_logits + produces the correct shape under CUDA graph padding. + + Uses add_request + update_requests to build real decode batches, then + verifies that under CUDA graph matching: + 1. num_last_token_logits uses the padded decode count from the matched graph + 2. last_token_logits returns the padded number of rows + 3. The real (unpadded) index positions are sequential 0..N-1 + """ + num_spec = 2 + max_requests = 16 + engine = self._build_engine(num_speculative_tokens=num_spec, max_requests=max_requests) + context = engine.context + tokens_per_decode = num_spec + 1 + + # Collect decode-only graph sizes to pick active counts that will match. + decode_graph_sizes = sorted( + { + dim.decode_req_count + for dim in context.cuda_graph_batch_dimensions_list + if dim.prefill_req_count == 0 and dim.decode_req_count > 1 + } + ) + assert len(decode_graph_sizes) > 0, "No decode-only graph dims found" + + # Use active counts 1 less than some graph sizes to guarantee padding. + active_counts = [s - 1 for s in decode_graph_sizes if s >= 2][:3] + assert len(active_counts) > 0, "No sub-capacity decode graph dims found" + + for active_decode_count in active_counts: + context.reset() + + # Add prefill requests, then step them into decode state. + prompt_length = 10 + for i in range(active_decode_count): + req = DynamicInferenceRequest( + request_id=i, + prompt_tokens=torch.arange(prompt_length, device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=100), + ) + context.add_request(req) + + context.initialize_attention_state() + + active_mask = torch.ones(active_decode_count, device='cuda', dtype=torch.int32) + new_tokens = torch.arange(active_decode_count, device='cuda') + new_spec = torch.arange(num_spec * active_decode_count, device='cuda').reshape( + num_spec, active_decode_count + ) + context.update_requests( + active_requests_mask=active_mask, + new_tokens=new_tokens, + new_speculative_tokens=new_spec, + ) + + # Now all requests are decode. initialize_attention_state should match a graph. + context.initialize_attention_state() + + assert ( + context.using_cuda_graph_this_step() + ), f"Expected CUDA graph for active={active_decode_count}" + + # Read the actually matched graph dimensions. + matched = context.padded_batch_dimensions + padded_decode = matched.decode_req_count + padded_token_count = matched.token_count + assert padded_decode >= active_decode_count + + expected_padded_logits = padded_decode * tokens_per_decode + assert context.num_last_token_logits == expected_padded_logits, ( + f"active={active_decode_count}, padded={padded_decode}: " + f"num_last_token_logits expected {expected_padded_logits}, " + f"got {context.num_last_token_logits}" + ) + + # Verify the real decode indices are [0, 1, ..., real_token_count - 1]. + real_token_count = active_decode_count * tokens_per_decode + real_slice = context.active_logit_idxs[:real_token_count] + expected_real = torch.arange(real_token_count, dtype=torch.int32, device='cuda') + assert torch.equal( + real_slice, expected_real + ), f"real decode indices: {real_slice.tolist()} vs {expected_real.tolist()}" + + # Padding indices should be zero (indexing into logits[0]). 
+ padding_count = expected_padded_logits - real_token_count + if padding_count > 0: + padding_slice = context.active_logit_idxs[real_token_count:expected_padded_logits] + assert ( + padding_slice.sum().item() == 0 + ), f"padding indices should be zero, got {padding_slice.tolist()}" + + # Verify last_token_logits produces a tensor with the padded row count. + vocab_size = 64 + fake_logits = torch.randn( + 1, padded_token_count, vocab_size, device='cuda', dtype=torch.float32 + ) + result = context.last_token_logits(fake_logits) + assert result.shape == (expected_padded_logits, vocab_size), ( + f"last_token_logits shape: expected ({expected_padded_logits}, {vocab_size}), " + f"got {result.shape}" + ) + # --------------------------------------------------------------------------- # # TestMTPCudaGraphExpertParallel (EP = 2) diff --git a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py index 8ce21761f66..5846fbd95a0 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py @@ -323,9 +323,11 @@ def test_sample_from_dynamic_logits( self.text_generation_controller._sampling_backend = backend context.padded_active_token_count = batch_size - context.request_query_lengths = torch.ones(batch_size, dtype=torch.int32) + context.request_query_lengths = torch.ones(batch_size, dtype=torch.int32, device='cuda') context.paused_request_count = 0 context.total_request_count = batch_size + context.num_prefill_requests = 0 + context.pad_active_slices() # Bookkeeping. self.text_generation_controller._dynamic_step_sample_bookkeeping() @@ -973,6 +975,8 @@ def test_speculative_verify_tokens(self): ctx.request_query_lengths = torch.tensor( [3, 3], dtype=torch.int32, device='cuda' ) # 1 sampled + 2 spec + ctx.num_prefill_requests = 0 + ctx.pad_active_slices() # Init accepted tokens tensors self.text_generation_controller._init_mtp_sampling_tensors() @@ -1218,6 +1222,8 @@ def test_speculative_multinomial_sampling(self): ) # Decode requests # query lengths for decode with spec tokens is (1 + num_spec) = 4 ctx.request_query_lengths = torch.tensor([4, 4], dtype=torch.int32, device='cuda') + ctx.num_prefill_requests = 0 + ctx.pad_active_slices() # Setup inputs input_ids = torch.randint(0, self.vocab_size, (1, 8), device='cuda')