69 changes: 8 additions & 61 deletions megatron/core/inference/batch_dimensions_utils.py
@@ -142,11 +142,7 @@ def req_count(self) -> int:
@staticmethod
def adjust_batch_dims_for_expert_parallelism(
local_batch_dims,
strict: bool = False,
decode_only_cuda_graphs: bool = True,
smallest_non_decode_cuda_graph_size: int = 0,
ep_group: Optional[torch.distributed.ProcessGroup] = None,
num_speculative_tokens: int = 0,
ep_zmq_communicator=None,
) -> Optional["InferenceBatchDimensions"]:
"""Adjust CUDA graph batch dimensions for expert parallelism.
@@ -158,8 +154,6 @@ def adjust_batch_dims_for_expert_parallelism(

Args:
local_batch_dims: The local batch dimensions to adjust.
strict: Whether to use strict matching for batch dimensions.
decode_only_cuda_graphs: Whether CUDA graphs are only used for decode steps.
ep_group: Optional expert parallel process group. If None, uses global parallel state.
When using different EP sizes for inference vs training, pass the
inference EP group explicitly.
@@ -181,22 +175,12 @@ def adjust_batch_dims_for_expert_parallelism(
if ep_zmq_communicator is not None:
# CPU-only sync via ZMQ: avoids a NCCL AllReduce kernel on the
# compute stream plus the H2D/D2H pair that sandwiches it.
(max_token_count, max_is_non_decode, max_prefill_count, max_decode_count) = (
ep_zmq_communicator.sync_all_reduce_max(
local_batch_dims.token_count,
int(is_non_decode),
local_batch_dims.prefill_req_count,
local_batch_dims.decode_req_count,
)
(max_token_count, max_is_non_decode) = ep_zmq_communicator.sync_all_reduce_max(
local_batch_dims.token_count, int(is_non_decode)
)
else:
sync_tensor = torch.tensor(
[
local_batch_dims.token_count,
int(is_non_decode),
local_batch_dims.prefill_req_count,
local_batch_dims.decode_req_count,
],
[local_batch_dims.token_count, int(is_non_decode)],
dtype=torch.int32,
device=torch.cuda.current_device(),
)
@@ -206,44 +190,16 @@ def adjust_batch_dims_for_expert_parallelism(
sync_tensor = sync_tensor.cpu()
max_token_count = int(sync_tensor[0].item())
max_is_non_decode = int(sync_tensor[1].item())
max_prefill_count = int(sync_tensor[2].item())
max_decode_count = int(sync_tensor[3].item())

is_any_ep_rank_in_non_decode = max_is_non_decode == 1

if is_any_ep_rank_in_non_decode and decode_only_cuda_graphs:
if is_any_ep_rank_in_non_decode:
return None # any rank has prefill → eager mode

adjusted_token_count = max_token_count

# Sync request counts across EP ranks when strict matching is enabled
# or when speculative tokens are used. With speculative tokens,
# decode-only graphs have token counts of decode_req_count * (spec+1)
# which creates a different granularity than mixed graphs (raw sizes).
# Without syncing, decode-only ranks and prefill ranks search different
# graph pools and may pick graphs with different token counts.
sync_request_counts = strict or (
is_any_ep_rank_in_non_decode and num_speculative_tokens > 0
)
adjusted_prefill_req_count = (
max_prefill_count if sync_request_counts else local_batch_dims.prefill_req_count
)
adjusted_decode_req_count = (
max_decode_count if sync_request_counts else local_batch_dims.decode_req_count
)

# When any EP rank has prefill requests (non-strict mode), elevate
# the token count to be >= the smallest prefill/mixed cuda graph.
# This ensures decode-only ranks don't match a fine-grained decode
# graph while prefill ranks match a coarser mixed graph, which would
# produce inconsistent token counts across EP ranks.
if is_any_ep_rank_in_non_decode and not strict:
adjusted_token_count = max(adjusted_token_count, smallest_non_decode_cuda_graph_size)

adjusted_batch_dim = InferenceBatchDimensions(
token_count=adjusted_token_count,
prefill_req_count=adjusted_prefill_req_count,
decode_req_count=adjusted_decode_req_count,
token_count=max_token_count,
prefill_req_count=local_batch_dims.prefill_req_count,
decode_req_count=local_batch_dims.decode_req_count,
)

return adjusted_batch_dim
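
The simplified sync above reduces only the token count and a prefill flag across the EP group (via ZMQ on CPU, or the NCCL all-reduce MAX fallback); request counts stay local. A minimal single-process sketch of the resulting semantics, using plain dicts and a hypothetical helper rather than InferenceBatchDimensions:

# Minimal sketch of the post-change sync semantics (hypothetical helper,
# plain dicts): only token_count and an "any prefill?" flag are shared,
# request counts stay local to each rank.
def adjust_for_ep(local_dims_per_rank):
    max_token_count = max(d["token_count"] for d in local_dims_per_rank)
    any_non_decode = any(d["prefill_req_count"] > 0 for d in local_dims_per_rank)
    if any_non_decode:
        return None  # any rank with prefill work -> every rank runs eagerly
    return [dict(d, token_count=max_token_count) for d in local_dims_per_rank]

ranks = [
    {"token_count": 4, "prefill_req_count": 0, "decode_req_count": 4},
    {"token_count": 7, "prefill_req_count": 0, "decode_req_count": 7},
]
assert all(d["token_count"] == 7 for d in adjust_for_ep(ranks))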
@@ -529,11 +485,8 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int
def match_graph_config(
real_batch_dim: InferenceBatchDimensions,
cuda_graph_batch_dimensions_list: List[InferenceBatchDimensions],
smallest_non_decode_cuda_graph_size: int = 0,
strict: bool = False,
decode_only_cuda_graphs: bool = True,
ep_group: Optional[torch.distributed.ProcessGroup] = None,
num_speculative_tokens: int = 0,
ep_zmq_communicator=None,
match_ep_token_counts: bool = True,
) -> Optional[InferenceBatchDimensions]:
@@ -571,13 +524,7 @@ def match_graph_config(
# NCCL dispatcher: all EP ranks must select the same CUDA graph. Sync batch dims
# across the EP group so graph selection is consistent.
adjusted_batch_dim = InferenceBatchDimensions.adjust_batch_dims_for_expert_parallelism(
real_batch_dim,
strict=strict,
decode_only_cuda_graphs=decode_only_cuda_graphs,
ep_group=ep_group,
smallest_non_decode_cuda_graph_size=smallest_non_decode_cuda_graph_size,
num_speculative_tokens=num_speculative_tokens,
ep_zmq_communicator=ep_zmq_communicator,
real_batch_dim, ep_group=ep_group, ep_zmq_communicator=ep_zmq_communicator
)

if adjusted_batch_dim is None:
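
As the comment in this hunk notes, all EP ranks must select the same CUDA graph with the NCCL dispatcher. A toy illustration of why syncing the token count achieves that; the matching rule used here (smallest graph that fits) is assumed purely for illustration, not taken from the library:

# Toy illustration (hypothetical matcher): with identical synced dims and the
# same graph list on every rank, a deterministic search picks the same graph.
graph_token_counts = [4, 8, 16, 32]

def pick_graph(token_count):
    fits = [t for t in graph_token_counts if t >= token_count]
    return min(fits) if fits else None

# Unsynced local counts could diverge across ranks (8 vs 16)...
assert pick_graph(5) == 8 and pick_graph(9) == 16
# ...whereas the synced max(5, 9) gives both ranks the same graph.
assert pick_graph(max(5, 9)) == 16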
35 changes: 20 additions & 15 deletions megatron/core/inference/contexts/dynamic_context.py
@@ -600,18 +600,30 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
), "Router recording/replay requested but no MoE experts specified!"
self.moe_routing_metadata = RoutingMetadata(self, model_config.moe_router_topk)

# CUDA graph config list.
# are we using the inference_optimized nccl ep dispatcher for MoEs?
self._nccl_ep_dispatcher = (
get_pg_size(self.expert_model_parallel_group) > 1
and getattr(model_config, 'inference_moe_token_dispatcher_type', 'nccl') == 'nccl'
and model_config.inference_moe_token_dispatcher_type == 'nccl'
)

# are we using the training a2a dispatcher for MoEs?
# Note that this is not optimal for speed.
self._training_ep_dispatcher = (
get_pg_size(self.expert_model_parallel_group) > 1
and model_config.transformer_impl == "transformer_engine"
)

# We only allow non-decode cuda graphs for the nvls dispatcher
force_disable_non_decode_cuda_graphs = (
self._nccl_ep_dispatcher or self._training_ep_dispatcher
)
# We disable non-decode cuda graphs for the nccl dispatcher.
# The NCCL dispatcher uses allgathers. Thus there is a need to
# run the same sized cuda-graph on every EP rank. This is difficult to
# generalize for non-decode steps.

self.use_cuda_graphs_for_non_decode_steps = (
inference_config.use_cuda_graphs_for_non_decode_steps and not self._nccl_ep_dispatcher
inference_config.use_cuda_graphs_for_non_decode_steps
and not (force_disable_non_decode_cuda_graphs)
)

# CUDA graph config list.
self.cuda_graph_batch_dimensions_list, self.cuda_graph_token_counts = (
CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list(
tp_size=tp_size,
@@ -642,10 +654,6 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
ep_group=self.expert_model_parallel_group,
)

self.smallest_non_decode_cuda_graph_size = min(
inference_config.cuda_graph_mixed_prefill_count, self.max_requests
)

# Deal with chunked prefill
self.enable_chunked_prefill = inference_config.enable_chunked_prefill

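Taken together with the hunk above, non-decode CUDA graphs are now disabled whenever either EP dispatcher is active. A small sketch of that flag logic under assumed inputs (hypothetical function, not the PR's code):

# Sketch: any EP dispatcher that needs identically sized collectives on every
# rank (inference NCCL dispatcher or training a2a dispatcher) forces
# decode-only CUDA graphs, regardless of what the inference config requests.
def non_decode_graphs_enabled(ep_size, dispatcher_type, transformer_impl, requested):
    nccl_ep_dispatcher = ep_size > 1 and dispatcher_type == "nccl"
    training_ep_dispatcher = ep_size > 1 and transformer_impl == "transformer_engine"
    return requested and not (nccl_ep_dispatcher or training_ep_dispatcher)

assert non_decode_graphs_enabled(8, "nccl", "transformer_engine", True) is False
assert non_decode_graphs_enabled(1, "nccl", "local", True) is True
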
Expand Down Expand Up @@ -1982,13 +1990,10 @@ def initialize_attention_state(
best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config(
batch_dimensions,
self.cuda_graph_batch_dimensions_list,
smallest_non_decode_cuda_graph_size=self.smallest_non_decode_cuda_graph_size,
strict=self.is_hybrid_model,
decode_only_cuda_graphs=(not self.use_cuda_graphs_for_non_decode_steps),
ep_group=self.expert_model_parallel_group,
num_speculative_tokens=self.num_speculative_tokens,
match_ep_token_counts=self._nccl_ep_dispatcher or self._training_ep_dispatcher,
ep_zmq_communicator=self._ep_zmq_communicator,
match_ep_token_counts=self._nccl_ep_dispatcher,
)
self._using_cuda_graph_this_step = best_graph is not None
