diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py
index 6ccbccf0f33..b9f62e59547 100644
--- a/megatron/core/inference/batch_dimensions_utils.py
+++ b/megatron/core/inference/batch_dimensions_utils.py
@@ -142,11 +142,7 @@ def req_count(self) -> int:
     @staticmethod
     def adjust_batch_dims_for_expert_parallelism(
         local_batch_dims,
-        strict: bool = False,
-        decode_only_cuda_graphs: bool = True,
-        smallest_non_decode_cuda_graph_size: int = 0,
         ep_group: Optional[torch.distributed.ProcessGroup] = None,
-        num_speculative_tokens: int = 0,
         ep_zmq_communicator=None,
     ) -> Optional["InferenceBatchDimensions"]:
         """Adjust CUDA graph batch dimensions for expert parallelism.
@@ -158,8 +154,6 @@ def adjust_batch_dims_for_expert_parallelism(

         Args:
             local_batch_dims: The local batch dimensions to adjust.
-            strict: Whether to use strict matching for batch dimensions.
-            decode_only_cuda_graphs: Whether CUDA graphs are only used for decode steps.
             ep_group: Optional expert parallel process group. If None, uses global
                 parallel state. When using different EP sizes for inference vs training,
                 pass the inference EP group explicitly.
@@ -181,22 +175,12 @@ def adjust_batch_dims_for_expert_parallelism(
         if ep_zmq_communicator is not None:
             # CPU-only sync via ZMQ: avoids a NCCL AllReduce kernel on the
             # compute stream plus the H2D/D2H pair that sandwiches it.
-            (max_token_count, max_is_non_decode, max_prefill_count, max_decode_count) = (
-                ep_zmq_communicator.sync_all_reduce_max(
-                    local_batch_dims.token_count,
-                    int(is_non_decode),
-                    local_batch_dims.prefill_req_count,
-                    local_batch_dims.decode_req_count,
-                )
+            (max_token_count, max_is_non_decode) = ep_zmq_communicator.sync_all_reduce_max(
+                local_batch_dims.token_count, int(is_non_decode)
             )
         else:
             sync_tensor = torch.tensor(
-                [
-                    local_batch_dims.token_count,
-                    int(is_non_decode),
-                    local_batch_dims.prefill_req_count,
-                    local_batch_dims.decode_req_count,
-                ],
+                [local_batch_dims.token_count, int(is_non_decode)],
                 dtype=torch.int32,
                 device=torch.cuda.current_device(),
             )
@@ -206,44 +190,16 @@ def adjust_batch_dims_for_expert_parallelism(
             sync_tensor = sync_tensor.cpu()
             max_token_count = int(sync_tensor[0].item())
             max_is_non_decode = int(sync_tensor[1].item())
-            max_prefill_count = int(sync_tensor[2].item())
-            max_decode_count = int(sync_tensor[3].item())

         is_any_ep_rank_in_non_decode = max_is_non_decode == 1
-        if is_any_ep_rank_in_non_decode and decode_only_cuda_graphs:
+        if is_any_ep_rank_in_non_decode:
             return None  # any rank has prefill → eager mode

-        adjusted_token_count = max_token_count
-
-        # Sync request counts across EP ranks when strict matching is enabled
-        # or when speculative tokens are used. With speculative tokens,
-        # decode-only graphs have token counts of decode_req_count * (spec+1)
-        # which creates a different granularity than mixed graphs (raw sizes).
-        # Without syncing, decode-only ranks and prefill ranks search different
-        # graph pools and may pick graphs with different token counts.
-        sync_request_counts = strict or (
-            is_any_ep_rank_in_non_decode and num_speculative_tokens > 0
-        )
-        adjusted_prefill_req_count = (
-            max_prefill_count if sync_request_counts else local_batch_dims.prefill_req_count
-        )
-        adjusted_decode_req_count = (
-            max_decode_count if sync_request_counts else local_batch_dims.decode_req_count
-        )
-
-        # When any EP rank has prefill requests (non-strict mode), elevate
-        # the token count to be >= the smallest prefill/mixed cuda graph.
-        # This ensures decode-only ranks don't match a fine-grained decode
-        # graph while prefill ranks match a coarser mixed graph, which would
-        # produce inconsistent token counts across EP ranks.
-        if is_any_ep_rank_in_non_decode and not strict:
-            adjusted_token_count = max(adjusted_token_count, smallest_non_decode_cuda_graph_size)
-
         adjusted_batch_dim = InferenceBatchDimensions(
-            token_count=adjusted_token_count,
-            prefill_req_count=adjusted_prefill_req_count,
-            decode_req_count=adjusted_decode_req_count,
+            token_count=max_token_count,
+            prefill_req_count=local_batch_dims.prefill_req_count,
+            decode_req_count=local_batch_dims.decode_req_count,
         )

         return adjusted_batch_dim

@@ -529,11 +485,8 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int
     def match_graph_config(
         real_batch_dim: InferenceBatchDimensions,
         cuda_graph_batch_dimensions_list: List[InferenceBatchDimensions],
-        smallest_non_decode_cuda_graph_size: int = 0,
         strict: bool = False,
-        decode_only_cuda_graphs: bool = True,
         ep_group: Optional[torch.distributed.ProcessGroup] = None,
-        num_speculative_tokens: int = 0,
         ep_zmq_communicator=None,
         match_ep_token_counts: bool = True,
     ) -> Optional[InferenceBatchDimensions]:
@@ -571,13 +524,7 @@ def match_graph_config(
         # NCCL dispatcher: all EP ranks must select the same CUDA graph. Sync batch dims
         # across the EP group so graph selection is consistent.
         adjusted_batch_dim = InferenceBatchDimensions.adjust_batch_dims_for_expert_parallelism(
-            real_batch_dim,
-            strict=strict,
-            decode_only_cuda_graphs=decode_only_cuda_graphs,
-            ep_group=ep_group,
-            smallest_non_decode_cuda_graph_size=smallest_non_decode_cuda_graph_size,
-            num_speculative_tokens=num_speculative_tokens,
-            ep_zmq_communicator=ep_zmq_communicator,
+            real_batch_dim, ep_group=ep_group, ep_zmq_communicator=ep_zmq_communicator
         )

         if adjusted_batch_dim is None:
diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py
index 159e1f90b34..842c234d6bc 100644
--- a/megatron/core/inference/contexts/dynamic_context.py
+++ b/megatron/core/inference/contexts/dynamic_context.py
@@ -600,18 +600,30 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
             ), "Router recording/replay requested but no MoE experts specified!"
             self.moe_routing_metadata = RoutingMetadata(self, model_config.moe_router_topk)

-        # CUDA graph config list.
+        # are we using the inference_optimized nccl ep dispatcher for MoEs?
         self._nccl_ep_dispatcher = (
             get_pg_size(self.expert_model_parallel_group) > 1
-            and getattr(model_config, 'inference_moe_token_dispatcher_type', 'nccl') == 'nccl'
+            and model_config.inference_moe_token_dispatcher_type == 'nccl'
+        )
+
+        # are we using the training a2a dispatcher for MoEs?
+        # Note that this is not optimal for speed.
+        self._training_ep_dispatcher = (
+            get_pg_size(self.expert_model_parallel_group) > 1
+            and model_config.transformer_impl == "transformer_engine"
+        )
+
+        # We only allow non-decode cuda graphs for the nvls dispatcher
+        force_disable_non_decode_cuda_graphs = (
+            self._nccl_ep_dispatcher or self._training_ep_dispatcher
         )
-        # We disable non-decode cuda graphs for the nccl dispatcher.
-        # The NCCL dispatcher uses allgathers. Thus there is a need to
-        # run the same sized cuda-graph on every EP rank. This is difficult to
-        # generalize for non-decode steps.
+
         self.use_cuda_graphs_for_non_decode_steps = (
-            inference_config.use_cuda_graphs_for_non_decode_steps and not self._nccl_ep_dispatcher
+            inference_config.use_cuda_graphs_for_non_decode_steps
+            and not (force_disable_non_decode_cuda_graphs)
         )
+
+        # CUDA graph config list.
         self.cuda_graph_batch_dimensions_list, self.cuda_graph_token_counts = (
             CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list(
                 tp_size=tp_size,
@@ -642,10 +654,6 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
                 ep_group=self.expert_model_parallel_group,
             )

-        self.smallest_non_decode_cuda_graph_size = min(
-            inference_config.cuda_graph_mixed_prefill_count, self.max_requests
-        )
-
         # Deal with chunked prefill
         self.enable_chunked_prefill = inference_config.enable_chunked_prefill

@@ -1982,13 +1990,10 @@ def initialize_attention_state(
             best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config(
                 batch_dimensions,
                 self.cuda_graph_batch_dimensions_list,
-                smallest_non_decode_cuda_graph_size=self.smallest_non_decode_cuda_graph_size,
                 strict=self.is_hybrid_model,
-                decode_only_cuda_graphs=(not self.use_cuda_graphs_for_non_decode_steps),
                 ep_group=self.expert_model_parallel_group,
-                num_speculative_tokens=self.num_speculative_tokens,
+                match_ep_token_counts=self._nccl_ep_dispatcher or self._training_ep_dispatcher,
                 ep_zmq_communicator=self._ep_zmq_communicator,
-                match_ep_token_counts=self._nccl_ep_dispatcher,
             )

         self._using_cuda_graph_this_step = best_graph is not None
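Reviewer note: a minimal standalone sketch of the simplified EP sync behavior above. It is not the Megatron-Core implementation; BatchDims and adjust_for_ep are hypothetical stand-ins for InferenceBatchDimensions and adjust_batch_dims_for_expert_parallelism, and a plain max() over a Python list stands in for the ZMQ/NCCL all-reduce-max across EP ranks.

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class BatchDims:
        """Hypothetical stand-in for InferenceBatchDimensions (illustration only)."""

        token_count: int
        prefill_req_count: int
        decode_req_count: int


    def adjust_for_ep(local: BatchDims, all_ranks: List[BatchDims]) -> Optional[BatchDims]:
        # max() over per-rank values stands in for the ZMQ/NCCL all-reduce-max.
        max_token_count = max(d.token_count for d in all_ranks)
        any_non_decode = any(d.prefill_req_count > 0 for d in all_ranks)
        if any_non_decode:
            return None  # any rank has prefill -> every EP rank runs eager this step
        # Decode-only everywhere: pad the token count up to the EP-wide max so all
        # ranks select the same CUDA graph; request counts stay local.
        return BatchDims(max_token_count, local.prefill_req_count, local.decode_req_count)


    # Example: three EP ranks, all decode-only, with different local token counts.
    ranks = [BatchDims(4, 0, 4), BatchDims(7, 0, 7), BatchDims(2, 0, 2)]
    print([adjust_for_ep(r, ranks) for r in ranks])  # every rank reports token_count=7

With the request-count sync and the smallest_non_decode_cuda_graph_size floor removed, the only cross-rank agreement the sync has to produce is the max token count; any prefill anywhere forces every EP rank back to eager mode for that step.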