From 75ff49f5db9375c593cb2ad1ddd7bb5f8f957d9a Mon Sep 17 00:00:00 2001
From: jinliangl
Date: Mon, 9 Mar 2026 06:06:26 -0700
Subject: [PATCH 1/3] fix no_shard training convergence

add unittest for no_shard and add empty cache to avoid OOM
add meta_device_check for no_shard following fully_shard.py:326

Signed-off-by: jinliangl
---
 .../fsdp/src/megatron_fsdp/megatron_fsdp.py   |  3 +++
 .../megatron_fsdp/param_and_grad_buffer.py    |  2 +-
 megatron/core/optimizer/__init__.py           | 13 ++++++++++++-
 megatron/training/arguments.py                | 10 +++++++++-
 .../test_mcore_fully_sharded_data_parallel.py | 19 ++++++++++++++++++-
 5 files changed, 43 insertions(+), 4 deletions(-)

diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
index f8640446814..e40886ac823 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
@@ -1195,6 +1195,9 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo
         """
         self._replace_param_with_raw_if_needed()
 
+        if self.data_parallel_sharding_strategy == "no_shard":
+            return
+
         if not force_sync and self.ddp_config.overlap_param_gather:
             # All-gather the first bucket before the forward pass.
             if self.ddp_config.fsdp_all_gather_in_start_param_sync:
diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
index 684cd7a99eb..bbdf3e9b04b 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
@@ -3157,7 +3157,7 @@ def all_reduce_gradients(self, async_op: bool = False):
         all_reduce_ops = []
         for g in self.parameter_groups:
             gbuf = g.main_grad_buffer
-            if gbuf is not None:
+            if gbuf is None:
                 continue
             scaling_factor = gbuf.gradient_scaling_factor
             if self.ddp_config.check_for_nan_in_grad:
diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py
index 8cfb22620bb..0319c3901c9 100644
--- a/megatron/core/optimizer/__init__.py
+++ b/megatron/core/optimizer/__init__.py
@@ -627,7 +627,18 @@ def init_state_fn(opt, config=None):
             # This is needed for case where num_distributed_optimizer_instances > 1. In this case,
             # weight gradients are all-reduced across optimizer instances, so each instance has
             # the duplicated weight gradients, need to reduce gradient stats inside each instance.
-            setattr(optimizer, 'grad_stats_parallel_group', intra_dist_opt_group)
+            # Besides, for Megatron-FSDP with no_shard, gradients are replicated across DP ranks (after
+            # all-reduce), so grad stats should only be reduced across model-parallel ranks
+            # (TP*PP) to avoid inflating the grad norm by sqrt(DP_world_size).
+            ddp_config = getattr(model_chunks[0], 'ddp_config', None)
+            if (
+                ddp_config is not None
+                and getattr(ddp_config, 'use_megatron_fsdp', False)
+                and ddp_config.data_parallel_sharding_strategy == 'no_shard'
+            ):
+                setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
+            else:
+                setattr(optimizer, 'grad_stats_parallel_group', intra_dist_opt_group)
         else:
             optimizer = Float16OptimizerWithFloat16Params(*optimizer_args)
             setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index a4d938f2e7b..a359ef48372 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -965,7 +965,15 @@ def validate_args(args, defaults={}):
     assert args.ckpt_format == "fsdp_dtensor", \
         "Megatron FSDP only supports fsdp_dtensor checkpoint format"
-
+
+    args.reuse_grad_buf_for_mxfp8_param_ag = False
+
+    if args.init_model_with_meta_device and args.data_parallel_sharding_strategy == "no_shard":
+        raise ValueError(
+            "Meta device initialization (init_model_with_meta_device=True) is not "
+            "supported or necessary for the 'no_shard' / 0 sharding strategy."
+        )
+
     if args.fsdp_manual_registration:
         assert args.use_megatron_fsdp, "FSDP manual registration is only supported with Megatron FSDP"
         assert args.nccl_ub, "FSDP manual registration is only supported with nccl-ub option"
diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py
index 0271da1fed9..15910af81a6 100644
--- a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py
+++ b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py
@@ -1,4 +1,5 @@
 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+import gc
 import random
 
 import numpy as np
@@ -321,6 +322,10 @@ def train_step(model, optimizer, inputs):
                 msg=f"Parameters for {name1} don't match",
             )
 
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
     # Testing fsdp_double_buffer with and without nccl_ub
     @pytest.mark.parametrize(
         ("dp_size", "nccl_ub", "fsdp_double_buffer", "fsdp_manual_registration"),
@@ -486,6 +491,10 @@ def train_step(model, optimizer, inputs):
                 msg=f"Parameters for {name1} don't match",
            )
 
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
     @classmethod
     def hsdp_one_step_test(cls, num_fsdp_group):
         if not is_torch_min_version("2.4.0"):
@@ -738,6 +747,7 @@ def _training_loop(seed=42, **kwargs):
             ("optim_grads_params", True),
             ("optim_grads", False),
             ("optim", True),
+            ("no_shard", False),
         ],
     )
     def test_compatible_with_nd_parallel(
@@ -749,10 +759,13 @@ def test_compatible_with_nd_parallel(
             use_distributed_optimizer=True
         )
 
+        # no_shard is incompatible with meta device initialization. See fully_shard.py:326.
+        init_model_with_meta_device = fsdp_sharding_strategy != "no_shard"
+
         outputs = TestMegatronFSDPE2E._training_loop(
             use_megatron_fsdp=True,
             data_parallel_sharding_strategy=fsdp_sharding_strategy,
-            init_model_with_meta_device=True,
+            init_model_with_meta_device=init_model_with_meta_device,
             ckpt_format="fsdp_dtensor",
             gradient_accumulation_fusion=False,
             fsdp_double_buffer=use_double_buffer,
@@ -775,6 +788,10 @@
             ),
         )
 
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
 
 def compare_losses(loss_a: float, loss_b: float, reference: str = "b"):
     """

From f7da7064d7e3c2f9361a2735f30629cabc5cbdd5 Mon Sep 17 00:00:00 2001
From: jinliangl
Date: Thu, 19 Mar 2026 08:33:52 -0700
Subject: [PATCH 2/3] use a more elegant way to indicate the no_shard strategy

---
 megatron/core/optimizer/__init__.py | 22 +++++++++-------------
 1 file changed, 9 insertions(+), 13 deletions(-)

diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py
index 0319c3901c9..992bef0808d 100644
--- a/megatron/core/optimizer/__init__.py
+++ b/megatron/core/optimizer/__init__.py
@@ -627,18 +627,7 @@ def init_state_fn(opt, config=None):
             # This is needed for case where num_distributed_optimizer_instances > 1. In this case,
             # weight gradients are all-reduced across optimizer instances, so each instance has
             # the duplicated weight gradients, need to reduce gradient stats inside each instance.
-            # Besides, for Megatron-FSDP with no_shard, gradients are replicated across DP ranks (after
-            # all-reduce), so grad stats should only be reduced across model-parallel ranks
-            # (TP*PP) to avoid inflating the grad norm by sqrt(DP_world_size).
-            ddp_config = getattr(model_chunks[0], 'ddp_config', None)
-            if (
-                ddp_config is not None
-                and getattr(ddp_config, 'use_megatron_fsdp', False)
-                and ddp_config.data_parallel_sharding_strategy == 'no_shard'
-            ):
-                setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
-            else:
-                setattr(optimizer, 'grad_stats_parallel_group', intra_dist_opt_group)
+            setattr(optimizer, 'grad_stats_parallel_group', intra_dist_opt_group)
         else:
             optimizer = Float16OptimizerWithFloat16Params(*optimizer_args)
             setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
@@ -749,6 +738,13 @@ def get_megatron_optimizer(
         model_chunk_offset = 0
         ddp_config = model_chunks[0].ddp_config  # Use the first model chunk's DDP config
         if ddp_config.use_megatron_fsdp:
+            # For no_shard, gradients are replicated across DP ranks after all-reduce, so grad stats
+            # should only be reduced over TP/PP (model_parallel_group) to avoid inflating the norm.
+ effective_intra_dist_opt_group = ( + mp_group + if ddp_config.data_parallel_sharding_strategy == 'no_shard' + else intra_dist_opt_group + ) for model_chunk, overlap_param_gather_with_optimizer_step in zip( all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags ): @@ -771,7 +767,7 @@ def get_megatron_optimizer( data_parallel_group=dp_cp_group, data_parallel_group_gloo=intra_dp_cp_group_gloo, data_parallel_group_idx=model_parallel_rank, - intra_dist_opt_group=intra_dist_opt_group, + intra_dist_opt_group=effective_intra_dist_opt_group, distributed_optimizer_instance_id=distributed_optimizer_instance_id, pg_collection=pg_collection, ) From 803c6a034a3536bc7bad115e6567b0c78bbcf506 Mon Sep 17 00:00:00 2001 From: jinliangl Date: Sun, 22 Mar 2026 07:37:07 -0700 Subject: [PATCH 3/3] minor change --- megatron/training/arguments.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index a359ef48372..3f82ca8c92d 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -966,8 +966,6 @@ def validate_args(args, defaults={}): assert args.ckpt_format == "fsdp_dtensor", \ "Megatron FSDP only supports fsdp_dtensor checkpoint format" - args.reuse_grad_buf_for_mxfp8_param_ag = False - if args.init_model_with_meta_device and args.data_parallel_sharding_strategy == "no_shard": raise ValueError( "Meta device initialization (init_model_with_meta_device=True) is not "
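
Note on the grad-norm motivation behind patches 1 and 2, as an illustration only: under no_shard every DP rank holds the same all-reduced gradient, so reducing per-rank squared norms over a group that still spans the DP ranks sums DP_world_size identical contributions before the square root, inflating the norm by exactly sqrt(DP_world_size). A minimal single-process sketch of that arithmetic, assuming plain PyTorch; dp_world_size and the toy tensor are illustrative and not Megatron APIs:

    import math
    import torch

    # Toy replicated gradient: under no_shard, every DP rank holds this same
    # tensor after the gradient all-reduce.
    grad = torch.randn(1024)
    dp_world_size = 8

    # Correct norm: the gradient is replicated, so one copy suffices.
    correct_norm = grad.norm(2).item()

    # What a reduction over a grad-stats group that still includes the DP ranks
    # computes: each rank contributes the identical ||g||^2, and the all-reduce
    # sums dp_world_size copies before the square root is taken.
    inflated_norm = math.sqrt(dp_world_size * correct_norm ** 2)

    # The inflation factor is exactly sqrt(dp_world_size), which is why the
    # series points grad_stats_parallel_group at the model-parallel (TP*PP)
    # group instead.
    assert abs(inflated_norm - math.sqrt(dp_world_size) * correct_norm) < 1e-6
    print(f"ratio = {inflated_norm / correct_norm:.4f}")  # ~2.8284 for dp_world_size=8

With the TP*PP group, each replicated gradient is counted once, so the reported norm matches what single-rank training would see.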