From 153003a26020cbafaeedb526bc438673247bd2f9 Mon Sep 17 00:00:00 2001 From: Peter Dykas Date: Fri, 1 May 2026 07:16:59 -0700 Subject: [PATCH 1/3] fix hang --- megatron/core/parallel_state.py | 11 +- ..._hybrid_model_expert_parallel_inference.py | 101 +++++++++++------- 2 files changed, 75 insertions(+), 37 deletions(-) diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 337485b4d12..63857495f53 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -2092,6 +2092,16 @@ def get_all_ranks(): def destroy_model_parallel(): """Set the groups to none.""" + # Destroy NCCL subgroups before nulling references — without this their + # NVLS multicast bindings leak across re-inits and eventually surface as + # a spurious OOM in transport/nvls.cc. Entry 0 is the default group. + global _global_process_group_list + if _global_process_group_list is not None: + pg_map = torch.distributed.distributed_c10d._world.pg_map + for group in _global_process_group_list[1:]: + if group is not None and pg_map.get(group, None) is not None: + torch.distributed.destroy_process_group(group) + global _MODEL_PARALLEL_GROUP _MODEL_PARALLEL_GROUP = None @@ -2232,7 +2242,6 @@ def destroy_model_parallel(): global _INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP _INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP = None - global _global_process_group_list _global_process_group_list = None SymmetricMemoryManager.destroy() diff --git a/tests/unit_tests/models/test_hybrid_model_expert_parallel_inference.py b/tests/unit_tests/models/test_hybrid_model_expert_parallel_inference.py index 6fb3df43ad5..c21f375b641 100644 --- a/tests/unit_tests/models/test_hybrid_model_expert_parallel_inference.py +++ b/tests/unit_tests/models/test_hybrid_model_expert_parallel_inference.py @@ -41,8 +41,13 @@ DECODE = "decode" # >0 decode, 0 prefill PREFILL = "prefill" # 0 decode, >0 prefill MIXED = "mixed" # >0 decode, >0 prefill +# Non-cuda-graphable states (token_count exceeds the cuda graph capacity +# of _SWEEP_MAX_REQUESTS, forcing eager fallback for all EP ranks). +PREFILL_AT_MAX_TOKENS = "prefill_max" # prefill-only, token_count > graph capacity +MIXED_GIANT_PREFILL = "mixed_giant" # mixed with a giant prefill, token_count > graph capacity -ALL_STATES = [NONE, DECODE, PREFILL, MIXED] +_NO_CUDA_GRAPH_STATES = {PREFILL_AT_MAX_TOKENS, MIXED_GIANT_PREFILL} +ALL_STATES = [NONE, DECODE, PREFILL, MIXED, PREFILL_AT_MAX_TOKENS, MIXED_GIANT_PREFILL] # Fixed expert-parallel size. When world_size > _EP_SIZE the remaining # ranks form data-parallel replicas, each running the same EP combo @@ -53,9 +58,14 @@ # across the EP ranks. Since rank assignment is symmetric (shuffling ranks # with the same multiset of states is not a distinct configuration), we use # combinations_with_replacement rather than the full Cartesian product. -# For _EP_SIZE=4 this gives C(4+4-1, 4) = 35 test cases. _STATE_COMBOS = list(itertools.combinations_with_replacement(ALL_STATES, _EP_SIZE)) +# Cap max_requests for the cross-product sweep so that the cuda graph +# capacity (max_requests * (num_speculative_tokens + 1) = 64) is small +# enough that PREFILL_AT_MAX_TOKENS / MIXED_GIANT_PREFILL token counts +# overflow it (forcing eager fallback), while the other states still fit. +_SWEEP_MAX_REQUESTS = 64 + # Batch dimensions used to set up each non-dummy state via # add_dummy_requests_for_cudagraph_capture. These are intentionally small # to keep the tests fast while still exercising the EP padding logic. 
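
[Reviewer note, not part of the patch] The destroy_model_parallel() change above works because torch.distributed only tears down a NCCL communicator (and its NVLS multicast registration) when destroy_process_group is called explicitly; dropping the Python reference alone leaves it registered in the internal pg_map. A minimal standalone sketch of the same pattern, assuming an already-initialized default group (the helper name and rank lists are illustrative, not from this patch):

    import torch.distributed as dist

    def recreate_subgroups(old_groups, new_rank_lists):
        # Tear down stale subgroups first; skipping this keeps their NCCL
        # communicators registered in _world.pg_map and leaks NVLS multicast
        # memory across re-init cycles.
        pg_map = dist.distributed_c10d._world.pg_map
        for group in old_groups:
            if group is not None and pg_map.get(group, None) is not None:
                dist.destroy_process_group(group)
        # new_group is collective: every rank must call it in the same order.
        return [dist.new_group(ranks=ranks) for ranks in new_rank_lists]
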
@@ -66,6 +76,14 @@ PREFILL: InferenceBatchDimensions(token_count=32, prefill_req_count=2, decode_req_count=0), # 4 decode (4 tokens) + 2 prefill (60 tokens) = 64 tokens MIXED: InferenceBatchDimensions(token_count=64, prefill_req_count=2, decode_req_count=4), + # 1 prefill of 128 tokens (> _SWEEP_MAX_REQUESTS=64, overflows every graph) + PREFILL_AT_MAX_TOKENS: InferenceBatchDimensions( + token_count=128, prefill_req_count=1, decode_req_count=0 + ), + # 2 decode (2 tokens) + 1 prefill (126 tokens) = 128 tokens (overflows graphs) + MIXED_GIANT_PREFILL: InferenceBatchDimensions( + token_count=128, prefill_req_count=1, decode_req_count=2 + ), } @@ -220,60 +238,71 @@ def _assert_dummy_forward_shape(self, model, rank): @torch.inference_mode() def test_ep_state_cross_product(self, rank_states): """Test all combinatorial (unordered, with repetition) assignments of - the four request states across EP ranks. + the request states across EP ranks. The context is built with use_cuda_graphs_for_non_decode_steps=True, so the CUDA graph list contains decode-only, mixed, and prefill-only - graphs. After the EP all-reduce in match_graph_config, every rank - (including dummy ranks) should always find a matching graph. - - State setup uses add_dummy_requests_for_cudagraph_capture to populate - the context directly with the desired request configuration. + graphs. For combos whose states all fit within the cuda graph capacity + (max_requests * (num_speculative_tokens + 1) = _SWEEP_MAX_REQUESTS), + every rank — including dummy ranks — must find a matching graph. + + For combos that include any non-cuda-graphable state + (PREFILL_AT_MAX_TOKENS / MIXED_GIANT_PREFILL), match_graph_config + returns None on every rank (eager fallback): dummy ranks bail out and + run dummy_forward; non-dummy ranks run the eager + padded_batch_dimensions path. """ ep_rank = parallel_state.get_expert_model_parallel_rank() my_state = rank_states[ep_rank] is_dummy = my_state == NONE + expect_graph = not any(s in _NO_CUDA_GRAPH_STATES for s in rank_states) model = self._build_model() - ctx = self._build_context(model) + ctx = self._build_context(model, max_requests=_SWEEP_MAX_REQUESTS) # Phase 1: Set up each rank's request state directly. if not is_dummy: ctx.add_dummy_requests_for_cudagraph_capture(_STATE_DIMS[my_state]) # Phase 2: Initialize attention state (EP collective). - if is_dummy: - ctx.initialize_attention_state(is_expert_parallel_dummy_cuda_graph_step=True) - else: - ctx.initialize_attention_state() + ctx.initialize_attention_state(is_expert_parallel_dummy_cuda_graph_step=is_dummy) - # Phase 3: Verify. - # With mixed CUDA graphs available, every rank — including dummy - # ranks whose EP-adjusted dimensions inherit prefill/decode counts - # from peers — must find a matching graph. - assert ctx.using_cuda_graph_this_step(), ( - f"EP rank {ep_rank} (state={my_state}): expected a CUDA graph match " - f"with use_cuda_graphs_for_non_decode_steps=True " - f"(rank_states={rank_states})" + # Phase 3: Verify CUDA graph match status agrees with the combo. + used_graph = ctx.using_cuda_graph_this_step() + assert used_graph == expect_graph, ( + f"EP rank {ep_rank} (state={my_state}, rank_states={rank_states}): " + f"expected using_cuda_graph_this_step={expect_graph}, got {used_graph}" ) - # All EP ranks must agree on padded token count. 
- padded = ctx.padded_batch_dimensions - ep_group = parallel_state.get_expert_model_parallel_group() - tc = torch.tensor([padded.token_count], dtype=torch.int32, device="cuda") - tc_max = tc.clone() - tc_min = tc.clone() - dist.all_reduce(tc_max, op=dist.ReduceOp.MAX, group=ep_group) - dist.all_reduce(tc_min, op=dist.ReduceOp.MIN, group=ep_group) - assert tc_max.item() == tc_min.item(), ( - f"Padded token count mismatch across EP ranks: " - f"min={tc_min.item()}, max={tc_max.item()} " - f"(rank_states={rank_states})" - ) + if used_graph: + # All EP ranks must agree on padded token count when a graph + # was matched. (In eager mode dummy ranks bail out without + # setting padded_batch_dimensions, so this only applies here.) + padded = ctx.padded_batch_dimensions + ep_group = parallel_state.get_expert_model_parallel_group() + tc = torch.tensor([padded.token_count], dtype=torch.int32, device="cuda") + tc_max = tc.clone() + tc_min = tc.clone() + dist.all_reduce(tc_max, op=dist.ReduceOp.MAX, group=ep_group) + dist.all_reduce(tc_min, op=dist.ReduceOp.MIN, group=ep_group) + assert tc_max.item() == tc_min.item(), ( + f"Padded token count mismatch across EP ranks: " + f"min={tc_min.item()}, max={tc_max.item()} " + f"(rank_states={rank_states})" + ) + self._assert_dynamic_inference_shape(model, ctx, ep_rank, my_state) + elif is_dummy: + # Eager fallback, dummy rank: bailed out of + # initialize_attention_state, run the dummy_forward path. + self._assert_dummy_forward_shape(model, ep_rank) + else: + # Eager fallback, non-dummy rank: padded_batch_dimensions + # set via the eager fallback path inside + # initialize_attention_state. + self._assert_dynamic_inference_shape(model, ctx, ep_rank, my_state) - self._assert_dynamic_inference_shape(model, ctx, ep_rank, my_state) self._assert_cuda_graphs_were_replayed( - True, ep_rank, f"state={my_state}, rank_states={rank_states}" + used_graph, ep_rank, f"state={my_state}, rank_states={rank_states}" ) # ------------------------------------------------------------------ From 4130fc403c0a98ffffce754701305d1235604625 Mon Sep 17 00:00:00 2001 From: Peter Dykas Date: Mon, 4 May 2026 14:16:58 -0700 Subject: [PATCH 2/3] fix test --- tests/unit_tests/inference/test_hybrid_moe.py | 65 +++++++++++++------ 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/tests/unit_tests/inference/test_hybrid_moe.py b/tests/unit_tests/inference/test_hybrid_moe.py index c0d4bd204c2..c716db6ebcc 100644 --- a/tests/unit_tests/inference/test_hybrid_moe.py +++ b/tests/unit_tests/inference/test_hybrid_moe.py @@ -49,9 +49,8 @@ MIXED_GIANT_PREFILL = ( "mixed_giant_prefill" # (max_requests-1) decode + 1 prefill with tokens > max_requests ) -ALL_STATES = [NONE, DECODE, PREFILL, MIXED] - _NO_CUDA_GRAPH_STATES = {PREFILL_AT_MAX_TOKENS, MIXED_GIANT_PREFILL} +ALL_STATES = [NONE, DECODE, PREFILL, MIXED, PREFILL_AT_MAX_TOKENS, MIXED_GIANT_PREFILL] # Fixed expert-parallel size. When world_size > _EP_SIZE the remaining # ranks form data-parallel replicas, each running the same EP combo @@ -62,14 +61,7 @@ # across the EP ranks. Since rank assignment is symmetric (shuffling ranks # with the same multiset of states is not a distinct configuration), we use # combinations_with_replacement rather than the full Cartesian product. -# For _EP_SIZE=4 this gives C(4+4-1, 4) = 35 test cases. 
-# -# Edge states (PREFILL_AT_MAX_TOKENS, DECODE_AT_MAX_REQUESTS, MIXED_GIANT_PREFILL) -# are NOT swept combinatorially: the NVLS dispatcher matches cuda-graph dims -# independently per rank, so a combo mixing edge and non-edge ranks deadlocks -# (some ranks capture cuda graphs while others run eager — the EP collective -# inside the graph never gets its peer). One rank in the edge state against a -# fixed peer is sufficient and is covered by the dedicated NCCL tests below. +# For _EP_SIZE=4 this gives C(6+4-1, 4) = 126 test cases. _STATE_COMBOS = list(itertools.combinations_with_replacement(ALL_STATES, _EP_SIZE)) @@ -200,6 +192,33 @@ def _assert_dynamic_inference_shape(self, model, ctx, rank, state_label): f"got {tuple(out.shape)}" ) + @torch.inference_mode() + def _capture_all_cuda_graphs(self, model, ctx): + """Pre-capture all cuda graphs in lockstep across EP ranks. + + Mirrors DynamicInferenceEngine.create_cuda_graphs(): iterates every + shape in cuda_graph_batch_dimensions_list and runs a forward pass + with all EP ranks at the matching shape. After this, every rank's + model.forward goes through the replay path (no warmup loop), so + capture-mode and eager-mode ranks emit the same number of EP + collectives per call. Without this, the first forward triggers + cuda_graphs.create_fwd_graph's warmup loop on capture-mode ranks + only, deadlocking against eager-mode peers in mixed combos. + """ + for graph_dim in ctx.cuda_graph_batch_dimensions_list: + ctx.reset() + ctx.initialize_attention_state(construct_graph_dimensions=graph_dim) + padded = ctx.padded_batch_dimensions + input_ids = torch.randint(0, self.VOCAB_SIZE, (1, padded.token_count), device="cuda") + model( + input_ids=input_ids, + position_ids=None, + attention_mask=None, + inference_context=ctx, + runtime_gather_output=True, + ) + ctx.reset() + @staticmethod def _assert_cuda_graphs_were_replayed(expect_replayed, rank, label): """Assert that CUDA graphs were (or were not) recorded and replayed @@ -242,14 +261,14 @@ class TestDynamicInferenceNVLS(_TestDynamicInferenceBase): @torch.inference_mode() def test_nvls_ep_state_cross_product(self, rank_states): """Test all combinatorial (unordered, with repetition) assignments of - the four request states across EP ranks. + the request states across EP ranks. - The NVLS dispatcher is used (match_ep_token_counts=False), so each rank - matches its own batch dimensions independently — no EP all-reduce. - The context is built with use_cuda_graphs_for_non_decode_steps=True, - so the CUDA graph list contains decode-only, mixed, and prefill-only - graphs, and every rank should find a matching graph for its own state - unless its token count exceeds the cuda-graph range (PREFILL_EXCEED). + The NVLS dispatcher (match_ep_token_counts=False) does per-rank + independent graph matching. Each rank finds a matching graph for its + own state unless its token count exceeds the cuda-graph range + (PREFILL_AT_MAX_TOKENS / MIXED_GIANT_PREFILL), in which case that + rank falls back to eager. The AllGather-V dispatcher handles per-rank + size variation, so mixed graph/eager combos work. State setup uses add_dummy_requests_for_cudagraph_capture to populate the context directly with the desired request configuration. 
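
[Reviewer note, not part of the patch] The warmup loop that _capture_all_cuda_graphs's docstring mentions is inherent to CUDA graph capture: warmup kernels (and any collectives inside them) execute for real before capture starts, which is why a combo where only some ranks capture deadlocks. A single-GPU sketch of the standard PyTorch warmup/capture/replay idiom (illustrative only; the production path goes through cuda_graphs.create_fwd_graph):

    import torch

    static_x = torch.zeros(64, device="cuda")
    graph = torch.cuda.CUDAGraph()

    # Warmup on a side stream: these launches really execute, so in a
    # multi-rank setting any collective here needs matching calls on peers.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(3):
            y = static_x * 2
    torch.cuda.current_stream().wait_stream(side)

    # Capture: kernels are recorded into the graph, not executed.
    with torch.cuda.graph(graph):
        y = static_x * 2

    # Replay with new data written into the same static input buffer.
    static_x.fill_(1.0)
    graph.replay()
    assert torch.all(y == 2.0)
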
@@ -261,6 +280,13 @@ def test_nvls_ep_state_cross_product(self, rank_states):
         model = self._build_model()
         ctx = self._build_context(model, max_requests=64, max_tokens=512)
 
+        # Pre-capture every cuda graph in lockstep across EP ranks (mirrors
+        # DynamicInferenceEngine.create_cuda_graphs in production). Without
+        # this, the first per-rank forward triggers create_fwd_graph's
+        # warmup loop on capture-mode ranks only, which fires extra EP
+        # collectives that an eager peer cannot match — deadlock.
+        self._capture_all_cuda_graphs(model, ctx)
+
         # Phase 1: Set up each rank's request state directly.
         if not is_dummy:
             ctx.add_dummy_requests_for_cudagraph_capture(_STATE_DIMS[my_state])
@@ -271,10 +297,7 @@ def test_nvls_ep_state_cross_product(self, rank_states):
         else:
             ctx.initialize_attention_state()
 
-        # Phase 3: Verify.
-        # With NVLS dispatcher each rank matches independently, so every rank
-        # must find a graph for its own state — except PREFILL_EXCEED, whose
-        # token count exceeds the max cuda-graph size and falls back to eager.
+        # Phase 3: Verify per-rank graph match status.
         if my_state in _NO_CUDA_GRAPH_STATES:
             assert not ctx.using_cuda_graph_this_step(), (
                 f"EP rank {ep_rank} (state={my_state}): expected no CUDA graph match "

From 11702737cea1cef12dab575a2fcf7865d7262909 Mon Sep 17 00:00:00 2001
From: Peter Dykas
Date: Mon, 4 May 2026 14:19:03 -0700
Subject: [PATCH 3/3] fix comment

---
 tests/unit_tests/inference/test_hybrid_moe.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/unit_tests/inference/test_hybrid_moe.py b/tests/unit_tests/inference/test_hybrid_moe.py
index c716db6ebcc..8bd3ebe8687 100644
--- a/tests/unit_tests/inference/test_hybrid_moe.py
+++ b/tests/unit_tests/inference/test_hybrid_moe.py
@@ -297,7 +297,10 @@ def test_nvls_ep_state_cross_product(self, rank_states):
         else:
             ctx.initialize_attention_state()
 
-        # Phase 3: Verify per-rank graph match status.
+        # Phase 3: Verify.
+        # With the NVLS dispatcher each rank matches independently, so every
+        # rank must find a graph for its own state — except the states in
+        # _NO_CUDA_GRAPH_STATES, whose token counts overflow the graphs (eager).
         if my_state in _NO_CUDA_GRAPH_STATES:
             assert not ctx.using_cuda_graph_this_step(), (
                 f"EP rank {ep_rank} (state={my_state}): expected no CUDA graph match "
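
[Reviewer note, not part of the patches] The sweep sizes quoted in the comments (35 cases for the original four states, 126 after adding the two overflow states) follow from the multiset coefficient C(n+k-1, k). A standalone sanity check:

    import itertools
    import math

    EP_SIZE = 4

    def n_combos(n_states):
        # Unordered assignments with repetition across EP ranks, as in
        # _STATE_COMBOS in both test files.
        states = range(n_states)
        return sum(1 for _ in itertools.combinations_with_replacement(states, EP_SIZE))

    assert n_combos(4) == math.comb(4 + EP_SIZE - 1, EP_SIZE) == 35
    assert n_combos(6) == math.comb(6 + EP_SIZE - 1, EP_SIZE) == 126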