11 changes: 10 additions & 1 deletion megatron/core/parallel_state.py
@@ -2092,6 +2092,16 @@ def get_all_ranks():

def destroy_model_parallel():
    """Set the groups to none."""
    # Destroy NCCL subgroups before nulling references — without this their
    # NVLS multicast bindings leak across re-inits and eventually surface as
    # a spurious OOM in transport/nvls.cc. Entry 0 is the default group.
    global _global_process_group_list
    if _global_process_group_list is not None:
        pg_map = torch.distributed.distributed_c10d._world.pg_map
        for group in _global_process_group_list[1:]:
            if group is not None and pg_map.get(group, None) is not None:
                torch.distributed.destroy_process_group(group)

    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None

@@ -2232,7 +2242,6 @@ def destroy_model_parallel():
    global _INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP
    _INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP = None

    global _global_process_group_list
    _global_process_group_list = None

    SymmetricMemoryManager.destroy()
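To see why destroying the subgroups before nulling the references matters, here is a minimal sketch (not part of this PR) of the re-init cycle the fix targets. It assumes only the public `megatron.core.parallel_state` API; the iteration count and the workload are placeholders.

```python
# Repeated init/destroy cycles, e.g. across unit tests in one process.
# Before this fix, each cycle left the NCCL subgroups' NVLS multicast
# bindings alive, accumulating until NCCL failed in transport/nvls.cc.
import torch
from megatron.core import parallel_state

torch.distributed.init_process_group(backend="nccl")
for _ in range(8):  # placeholder iteration count
    parallel_state.initialize_model_parallel(tensor_model_parallel_size=2)
    # ... run a model-parallel workload ...
    parallel_state.destroy_model_parallel()  # now also tears down NCCL subgroups
torch.distributed.destroy_process_group()
```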
57 changes: 43 additions & 14 deletions tests/unit_tests/inference/test_hybrid_moe.py
@@ -49,9 +49,8 @@
MIXED_GIANT_PREFILL = (
    "mixed_giant_prefill"  # (max_requests-1) decode + 1 prefill with tokens > max_requests
)
ALL_STATES = [NONE, DECODE, PREFILL, MIXED]

_NO_CUDA_GRAPH_STATES = {PREFILL_AT_MAX_TOKENS, MIXED_GIANT_PREFILL}
ALL_STATES = [NONE, DECODE, PREFILL, MIXED, PREFILL_AT_MAX_TOKENS, MIXED_GIANT_PREFILL]

# Fixed expert-parallel size. When world_size > _EP_SIZE the remaining
# ranks form data-parallel replicas, each running the same EP combo
@@ -62,11 +61,7 @@
# across the EP ranks. Since rank assignment is symmetric (shuffling ranks
# with the same multiset of states is not a distinct configuration), we use
# combinations_with_replacement rather than the full Cartesian product.
# For _EP_SIZE=4 this gives C(4+4-1, 4) = 35 test cases.
#
# Edge states (PREFILL_AT_MAX_TOKENS, DECODE_AT_MAX_REQUESTS, MIXED_GIANT_PREFILL)
# are not swept combinatorially — one rank in the edge state against a fixed
# peer is sufficient.
# For _EP_SIZE=4 this gives C(6+4-1, 4) = 126 test cases.
_STATE_COMBOS = list(itertools.combinations_with_replacement(ALL_STATES, _EP_SIZE))
Contributor:
if this fixes things, can we add the non cuda graphable states to ALL_STATES now?
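As a sanity check on the sweep sizes quoted in the diff above (35 cases before, 126 after), here is a stand-alone computation; the integer stand-ins replace the test's state constants:

```python
import itertools
import math

EP_SIZE = 4

# Unordered assignment with repetition of n states to EP_SIZE ranks is a
# multiset coefficient: C(n + EP_SIZE - 1, EP_SIZE).
for n_states in (4, 6):
    states = list(range(n_states))  # stand-ins for NONE, DECODE, PREFILL, ...
    combos = list(itertools.combinations_with_replacement(states, EP_SIZE))
    assert len(combos) == math.comb(n_states + EP_SIZE - 1, EP_SIZE)

print(math.comb(4 + 4 - 1, 4), math.comb(6 + 4 - 1, 4))  # 35 126
```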
@@ -197,6 +192,33 @@ def _assert_dynamic_inference_shape(self, model, ctx, rank, state_label):
f"got {tuple(out.shape)}"
)

    @torch.inference_mode()
    def _capture_all_cuda_graphs(self, model, ctx):
        """Pre-capture all cuda graphs in lockstep across EP ranks.

        Mirrors DynamicInferenceEngine.create_cuda_graphs(): iterates every
        shape in cuda_graph_batch_dimensions_list and runs a forward pass
        with all EP ranks at the matching shape. After this, every rank's
        model.forward goes through the replay path (no warmup loop), so
        capture-mode and eager-mode ranks emit the same number of EP
        collectives per call. Without this, the first forward triggers
        cuda_graphs.create_fwd_graph's warmup loop on capture-mode ranks
        only, deadlocking against eager-mode peers in mixed combos.
        """
        for graph_dim in ctx.cuda_graph_batch_dimensions_list:
            ctx.reset()
            ctx.initialize_attention_state(construct_graph_dimensions=graph_dim)
            padded = ctx.padded_batch_dimensions
            input_ids = torch.randint(0, self.VOCAB_SIZE, (1, padded.token_count), device="cuda")
            model(
                input_ids=input_ids,
                position_ids=None,
                attention_mask=None,
                inference_context=ctx,
                runtime_gather_output=True,
            )
        ctx.reset()

    @staticmethod
    def _assert_cuda_graphs_were_replayed(expect_replayed, rank, label):
        """Assert that CUDA graphs were (or were not) recorded and replayed
@@ -239,14 +261,14 @@ class TestDynamicInferenceNVLS(_TestDynamicInferenceBase):
    @torch.inference_mode()
    def test_nvls_ep_state_cross_product(self, rank_states):
        """Test all combinatorial (unordered, with repetition) assignments of
        the four request states across EP ranks.
        the request states across EP ranks.

        The NVLS dispatcher is used (match_ep_token_counts=False), so each rank
        matches its own batch dimensions independently — no EP all-reduce.
        The context is built with use_cuda_graphs_for_non_decode_steps=True,
        so the CUDA graph list contains decode-only, mixed, and prefill-only
        graphs, and every rank should find a matching graph for its own state
        unless its token count exceeds the cuda-graph range (PREFILL_EXCEED).
        The NVLS dispatcher (match_ep_token_counts=False) does per-rank
        independent graph matching. Each rank finds a matching graph for its
        own state unless its token count exceeds the cuda-graph range
        (PREFILL_AT_MAX_TOKENS / MIXED_GIANT_PREFILL), in which case that
        rank falls back to eager. The AllGather-V dispatcher handles per-rank
        size variation, so mixed graph/eager combos work.

        State setup uses add_dummy_requests_for_cudagraph_capture to populate
        the context directly with the desired request configuration.
@@ -258,6 +280,13 @@ def test_nvls_ep_state_cross_product(self, rank_states):
        model = self._build_model()
        ctx = self._build_context(model, max_requests=64, max_tokens=512)

        # Pre-capture every cuda graph in lockstep across EP ranks (mirrors
        # DynamicInferenceEngine.create_cuda_graphs in production). Without
        # this, the first per-rank forward triggers create_fwd_graph's
        # warmup loop on capture-mode ranks only, which fires extra EP
        # collectives that an eager peer cannot match — deadlock.
        self._capture_all_cuda_graphs(model, ctx)

        # Phase 1: Set up each rank's request state directly.
        if not is_dummy:
            ctx.add_dummy_requests_for_cudagraph_capture(_STATE_DIMS[my_state])
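For readers unfamiliar with the deadlock the pre-capture avoids, here is a hypothetical two-rank sketch (illustrative names, not from this PR) of the underlying collective-count mismatch. It reproduces the hang by construction, so only run it if you mean to observe it.

```python
# Launch with: torchrun --nproc_per_node=2 repro.py
# Rank 0 behaves like a capture-mode rank (extra warmup collectives);
# rank 1 behaves like an eager rank (none).
import os

import torch
import torch.distributed as dist


def forward_with_warmup(x, warmup_iters):
    # A capture-mode rank runs warmup_iters extra forwards, each issuing
    # its own collective, before the "real" forward; an eager rank skips
    # straight to the real forward.
    for _ in range(warmup_iters):
        dist.all_reduce(x)
    dist.all_reduce(x)  # the collective of the real forward
    return x


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    x = torch.ones(1, device="cuda")
    # Rank 0's first warmup all_reduce pairs with rank 1's only all_reduce;
    # rank 0's remaining calls never find a partner, so both ranks hang
    # inside NCCL. Pre-capturing in lockstep makes the counts match.
    forward_with_warmup(x, 3 if dist.get_rank() == 0 else 0)
    dist.destroy_process_group()
```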