diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index 337485b4d12..63857495f53 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -2092,6 +2092,16 @@ def get_all_ranks():
 
 def destroy_model_parallel():
     """Set the groups to none."""
+    # Destroy NCCL subgroups before nulling references — without this their
+    # NVLS multicast bindings leak across re-inits and eventually surface as
+    # a spurious OOM in transport/nvls.cc. Entry 0 is the default group.
+    global _global_process_group_list
+    if _global_process_group_list is not None:
+        pg_map = torch.distributed.distributed_c10d._world.pg_map
+        for group in _global_process_group_list[1:]:
+            if group is not None and pg_map.get(group, None) is not None:
+                torch.distributed.destroy_process_group(group)
+
     global _MODEL_PARALLEL_GROUP
     _MODEL_PARALLEL_GROUP = None
 
@@ -2232,7 +2242,6 @@ def destroy_model_parallel():
     global _INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP
     _INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP = None
 
-    global _global_process_group_list
     _global_process_group_list = None
 
     SymmetricMemoryManager.destroy()
diff --git a/tests/unit_tests/inference/test_hybrid_moe.py b/tests/unit_tests/inference/test_hybrid_moe.py
index dd28225077e..8bd3ebe8687 100644
--- a/tests/unit_tests/inference/test_hybrid_moe.py
+++ b/tests/unit_tests/inference/test_hybrid_moe.py
@@ -49,9 +49,8 @@
 MIXED_GIANT_PREFILL = (
     "mixed_giant_prefill"  # (max_requests-1) decode + 1 prefill with tokens > max_requests
 )
-ALL_STATES = [NONE, DECODE, PREFILL, MIXED]
-
 _NO_CUDA_GRAPH_STATES = {PREFILL_AT_MAX_TOKENS, MIXED_GIANT_PREFILL}
+ALL_STATES = [NONE, DECODE, PREFILL, MIXED, PREFILL_AT_MAX_TOKENS, MIXED_GIANT_PREFILL]
 
 # Fixed expert-parallel size. When world_size > _EP_SIZE the remaining
 # ranks form data-parallel replicas, each running the same EP combo
@@ -62,11 +61,7 @@
 # across the EP ranks. Since rank assignment is symmetric (shuffling ranks
 # with the same multiset of states is not a distinct configuration), we use
 # combinations_with_replacement rather than the full Cartesian product.
-# For _EP_SIZE=4 this gives C(4+4-1, 4) = 35 test cases.
-#
-# Edge states (PREFILL_AT_MAX_TOKENS, DECODE_AT_MAX_REQUESTS, MIXED_GIANT_PREFILL)
-# are not swept combinatorially — one rank in the edge state against a fixed
-# peer is sufficient.
+# For _EP_SIZE=4 this gives C(6+4-1, 4) = 126 test cases.
 _STATE_COMBOS = list(itertools.combinations_with_replacement(ALL_STATES, _EP_SIZE))
 
 
@@ -197,6 +192,33 @@ def _assert_dynamic_inference_shape(self, model, ctx, rank, state_label):
             f"got {tuple(out.shape)}"
         )
 
+    @torch.inference_mode()
+    def _capture_all_cuda_graphs(self, model, ctx):
+        """Pre-capture all cuda graphs in lockstep across EP ranks.
+
+        Mirrors DynamicInferenceEngine.create_cuda_graphs(): iterates every
+        shape in cuda_graph_batch_dimensions_list and runs a forward pass
+        with all EP ranks at the matching shape. After this, every rank's
+        model.forward goes through the replay path (no warmup loop), so
+        capture-mode and eager-mode ranks emit the same number of EP
+        collectives per call. Without this, the first forward triggers
+        cuda_graphs.create_fwd_graph's warmup loop on capture-mode ranks
+        only, deadlocking against eager-mode peers in mixed combos.
+ """ + for graph_dim in ctx.cuda_graph_batch_dimensions_list: + ctx.reset() + ctx.initialize_attention_state(construct_graph_dimensions=graph_dim) + padded = ctx.padded_batch_dimensions + input_ids = torch.randint(0, self.VOCAB_SIZE, (1, padded.token_count), device="cuda") + model( + input_ids=input_ids, + position_ids=None, + attention_mask=None, + inference_context=ctx, + runtime_gather_output=True, + ) + ctx.reset() + @staticmethod def _assert_cuda_graphs_were_replayed(expect_replayed, rank, label): """Assert that CUDA graphs were (or were not) recorded and replayed @@ -239,14 +261,14 @@ class TestDynamicInferenceNVLS(_TestDynamicInferenceBase): @torch.inference_mode() def test_nvls_ep_state_cross_product(self, rank_states): """Test all combinatorial (unordered, with repetition) assignments of - the four request states across EP ranks. + the request states across EP ranks. - The NVLS dispatcher is used (match_ep_token_counts=False), so each rank - matches its own batch dimensions independently — no EP all-reduce. - The context is built with use_cuda_graphs_for_non_decode_steps=True, - so the CUDA graph list contains decode-only, mixed, and prefill-only - graphs, and every rank should find a matching graph for its own state - unless its token count exceeds the cuda-graph range (PREFILL_EXCEED). + The NVLS dispatcher (match_ep_token_counts=False) does per-rank + independent graph matching. Each rank finds a matching graph for its + own state unless its token count exceeds the cuda-graph range + (PREFILL_AT_MAX_TOKENS / MIXED_GIANT_PREFILL), in which case that + rank falls back to eager. The AllGather-V dispatcher handles per-rank + size variation, so mixed graph/eager combos work. State setup uses add_dummy_requests_for_cudagraph_capture to populate the context directly with the desired request configuration. @@ -258,6 +280,13 @@ def test_nvls_ep_state_cross_product(self, rank_states): model = self._build_model() ctx = self._build_context(model, max_requests=64, max_tokens=512) + # Pre-capture every cuda graph in lockstep across EP ranks (mirrors + # DynamicInferenceEngine.create_cuda_graphs in production). Without + # this, the first per-rank forward triggers create_fwd_graph's + # warmup loop on capture-mode ranks only, which fires extra EP + # collectives that an eager peer cannot match — deadlock. + self._capture_all_cuda_graphs(model, ctx) + # Phase 1: Set up each rank's request state directly. if not is_dummy: ctx.add_dummy_requests_for_cudagraph_capture(_STATE_DIMS[my_state])