From e2873fbcddb12f40d6b0a8417e5e00cc7a5d26a6 Mon Sep 17 00:00:00 2001
From: Siddharth Singh
Date: Fri, 1 May 2026 16:59:00 -0700
Subject: [PATCH 1/2] do cuda-graph matching for the legacy a2a training dispatcher

---
 .../inference/contexts/dynamic_context.py | 28 +++++++++++++------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py
index c6f69df2b20..144382668ed 100644
--- a/megatron/core/inference/contexts/dynamic_context.py
+++ b/megatron/core/inference/contexts/dynamic_context.py
@@ -593,18 +593,30 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
             ), "Router recording/replay requested but no MoE experts specified!"
             self.moe_routing_metadata = RoutingMetadata(self, model_config.moe_router_topk)
 
-        # CUDA graph config list.
+        # Are we using the inference-optimized NCCL EP dispatcher for MoEs?
         self._nccl_ep_dispatcher = (
             get_pg_size(self.expert_model_parallel_group) > 1
-            and getattr(model_config, 'inference_moe_token_dispatcher_type', 'nccl') == 'nccl'
+            and model_config.inference_moe_token_dispatcher_type == 'nccl'
         )
-        # We disable non-decode cuda graphs for the nccl dispatcher.
-        # The NCCL dispatcher uses allgathers. Thus there is a need to
-        # run the same sized cuda-graph on every EP rank. This is difficult to
-        # generalize for non-decode steps.
+
+        # Are we using the training a2a dispatcher for MoEs?
+        # Note that this is not optimal for speed.
+        self._training_ep_dispatcher = (
+            get_pg_size(self.expert_model_parallel_group) > 1
+            and model_config.transformer_impl == "transformer_engine"
+        )
+
+        # We only allow non-decode CUDA graphs for the NVLS dispatcher.
+        force_disable_non_decode_cuda_graphs = (
+            self._nccl_ep_dispatcher or self._training_ep_dispatcher
+        )
+
         self.use_cuda_graphs_for_non_decode_steps = (
-            inference_config.use_cuda_graphs_for_non_decode_steps and not self._nccl_ep_dispatcher
+            inference_config.use_cuda_graphs_for_non_decode_steps
+            and not (force_disable_non_decode_cuda_graphs)
         )
+
+        # CUDA graph config list.
         self.cuda_graph_batch_dimensions_list, self.cuda_graph_token_counts = (
             CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list(
                 tp_size=tp_size,
@@ -1690,7 +1702,7 @@ def initialize_attention_state(
                 self.cuda_graph_batch_dimensions_list,
                 strict=self.is_hybrid_model,
                 ep_group=self.expert_model_parallel_group,
-                match_ep_token_counts=self._nccl_ep_dispatcher,
+                match_ep_token_counts=self._nccl_ep_dispatcher or self._training_ep_dispatcher,
             )
             self._using_cuda_graph_this_step = best_graph is not None
 

From dd23aa02b800a2047194441fc66b9ef412963fd4 Mon Sep 17 00:00:00 2001
From: Siddharth Singh
Date: Mon, 4 May 2026 09:31:31 -0400
Subject: [PATCH 2/2] Remove obsolete parameter

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
---
 megatron/core/inference/batch_dimensions_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py
index 6ddf0f887ad..b9f62e59547 100644
--- a/megatron/core/inference/batch_dimensions_utils.py
+++ b/megatron/core/inference/batch_dimensions_utils.py
@@ -154,7 +154,6 @@ def adjust_batch_dims_for_expert_parallelism(
 
     Args:
         local_batch_dims: The local batch dimensions to adjust.
-        strict: Whether to use strict matching for batch dimensions.
         ep_group: Optional expert parallel process group. If None, uses global
             parallel state. When using different EP sizes for inference vs
             training, pass the inference EP group explicitly.
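
Why match_ep_token_counts now covers both dispatchers: the NCCL dispatcher and
the legacy a2a training dispatcher both issue collectives (allgathers and
all-to-alls) whose shapes must agree on every EP rank, so every rank has to
replay a CUDA graph captured for the same token count. Below is a minimal
sketch of how such a shared selection could work; the function name and the
max-reduction strategy are illustrative assumptions, not the actual logic in
CUDAGraphBatchDimensionBuilder or adjust_batch_dims_for_expert_parallelism.

    # Illustrative sketch only: reconcile per-rank token counts across the EP
    # group so every rank replays the same-sized CUDA graph. The helper name
    # and reduction strategy are assumptions, not Megatron-LM internals.
    import torch
    import torch.distributed as dist

    def pick_shared_graph_token_count(local_token_count, graph_token_counts, ep_group):
        counts = torch.tensor([local_token_count], device="cuda")
        # Take the max across EP ranks so no rank's tokens overflow the graph.
        dist.all_reduce(counts, op=dist.ReduceOp.MAX, group=ep_group)
        global_max = int(counts.item())
        # Pick the smallest captured graph that fits the global maximum; padding
        # each rank's batch up to this size keeps collective shapes identical.
        for n in sorted(graph_token_counts):
            if n >= global_max:
                return n
        return None  # no graph fits; fall back to eager (non-graphed) execution

Padding every rank up to one agreed graph size trades a few wasted FLOPs for
the ability to replay a single captured graph across the whole EP group.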