@@ -1224,6 +1224,9 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo
         """
         self._replace_param_with_raw_if_needed()
 
+        if self.data_parallel_sharding_strategy == "no_shard":
+            return
+
         if not force_sync and self.ddp_config.overlap_param_gather:
             # All-gather the first bucket before the forward pass.
             if self.ddp_config.fsdp_all_gather_in_start_param_sync:
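Note on the hunk above: under "no_shard" the parameters are fully replicated on every data-parallel rank, so there are no parameter shards to all-gather and start_param_sync can return immediately. A minimal sketch of the guard, using a hypothetical stand-in class rather than the real Megatron-FSDP wrapper:

class ParamSyncSketch:
    """Toy stand-in for the FSDP wrapper; only the guard logic mirrors the diff."""

    def __init__(self, sharding_strategy: str):
        self.data_parallel_sharding_strategy = sharding_strategy

    def start_param_sync(self) -> str:
        # With "no_shard", every rank already holds the full parameters,
        # so there is nothing to all-gather before the forward pass.
        if self.data_parallel_sharding_strategy == "no_shard":
            return "skipped"
        return "all-gather issued"


assert ParamSyncSketch("no_shard").start_param_sync() == "skipped"
assert ParamSyncSketch("optim_grads_params").start_param_sync() == "all-gather issued"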
@@ -3246,7 +3246,7 @@ def all_reduce_gradients(self, async_op: bool = False):
         all_reduce_ops = []
         for g in self.parameter_groups:
             gbuf = g.main_grad_buffer
-            if gbuf is not None:
+            if gbuf is None:
                 continue
             scaling_factor = gbuf.gradient_scaling_factor
             if self.ddp_config.check_for_nan_in_grad:
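Note on the hunk above: the flipped condition is a genuine bug fix. The old guard `if gbuf is not None: continue` skipped exactly the parameter groups that had a gradient buffer, so their gradients were never all-reduced. A toy reproduction with made-up groups (not Megatron code):

groups = [
    {"name": "dense", "main_grad_buffer": [0.5, 1.5]},
    {"name": "no_grads", "main_grad_buffer": None},
]

def reduce_buggy(groups):
    out = []
    for g in groups:
        if g["main_grad_buffer"] is not None:  # old guard: skips groups WITH grads
            continue
        out.append(g["name"])
    return out

def reduce_fixed(groups):
    out = []
    for g in groups:
        if g["main_grad_buffer"] is None:  # new guard: skips only empty groups
            continue
        out.append(g["name"])
    return out

assert reduce_buggy(groups) == ["no_grads"]  # bug: the real gradients were never reduced
assert reduce_fixed(groups) == ["dense"]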
megatron/core/optimizer/__init__.py (8 additions, 1 deletion)

@@ -939,6 +939,13 @@ def get_megatron_optimizer(
     model_chunk_offset = 0
     ddp_config = model_chunks[0].ddp_config  # Use the first model chunk's DDP config
     if ddp_config.use_megatron_fsdp:
+        # For no_shard, gradients are replicated across DP ranks after all-reduce, so grad stats
+        # should only be reduced over TP/PP (model_parallel_group) to avoid inflating the norm.
+        effective_intra_dist_opt_group = (
+            mp_group
+            if ddp_config.data_parallel_sharding_strategy == 'no_shard'
+            else intra_dist_opt_group
+        )
         for model_chunk, overlap_param_gather_with_optimizer_step in zip(
             all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
         ):
@@ -960,7 +967,7 @@ def get_megatron_optimizer(
                 data_parallel_group=dp_cp_group,
                 data_parallel_group_gloo=intra_dp_cp_group_gloo,
                 data_parallel_group_idx=model_parallel_rank,
-                intra_dist_opt_group=intra_dist_opt_group,
+                intra_dist_opt_group=effective_intra_dist_opt_group,
                 distributed_optimizer_instance_id=distributed_optimizer_instance_id,
                 pg_collection=pg_collection,
            )
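Note on the new comment in this hunk: with "no_shard", every data-parallel rank holds the same full gradient after all-reduce, so summing per-rank squared norms over the DP group counts each element dp_world_size times; reducing only over the model-parallel group avoids that. A worked numeric sketch with invented values:

import math

dp_world_size = 4
local_sq_sum = 9.0  # sum of squared gradient elements on one rank

# Sharded strategies: ranks hold disjoint shards, so a SUM over DP ranks
# of per-shard squared norms reconstructs the true global squared norm.

# no_shard: the gradient is replicated, so the same 9.0 is contributed by
# every DP rank and a SUM over DP yields dp_world_size * 9.0.
true_norm = math.sqrt(local_sq_sum)                  # 3.0, the correct norm
inflated = math.sqrt(local_sq_sum * dp_world_size)   # 6.0, off by sqrt(dp_world_size)

assert inflated == true_norm * math.sqrt(dp_world_size)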
megatron/training/arguments.py (7 additions, 1 deletion)

@@ -1036,7 +1036,13 @@ def validate_args(args, defaults={}):
 
         assert args.ckpt_format == "fsdp_dtensor", \
             "Megatron-FSDP requires the `fsdp_dtensor` checkpointing format."
-
+
+        if args.init_model_with_meta_device and args.data_parallel_sharding_strategy == "no_shard":
+            raise ValueError(
+                "Meta device initialization (init_model_with_meta_device=True) is not "
+                "supported or necessary for the 'no_shard' / 0 sharding strategy."
+            )
+
     if args.fsdp_manual_registration:
         assert args.use_megatron_fsdp, "FSDP manual registration is only supported with Megatron FSDP."
         assert args.nccl_ub, "FSDP manual registration is only supported with --nccl-ub argument."
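Note on the new check: meta-device tensors carry shape and dtype but no storage, and Megatron-FSDP normally materializes them during its sharding pass, which "no_shard" never runs. A standalone PyTorch illustration (not Megatron code) of why unmaterialized meta tensors are unusable:

import torch

w = torch.empty(4, 4, device="meta")  # shape/dtype only, no allocated memory
print(w.device)  # meta

try:
    w.cpu()  # copying out of a meta tensor fails: there is no data to copy
except Exception as exc:
    print(type(exc).__name__)  # NotImplementedError (exact type may vary by torch version)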
@@ -1,5 +1,6 @@
 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 import copy
+import gc
 import random
 
 import numpy as np
@@ -323,51 +324,9 @@ def train_step(model, optimizer, inputs):
                 msg=f"Parameters for {name1} don't match",
             )
 
-    def test_fsdp_expt_device_mesh(self):
-        """Test that expt_device_mesh is None for dense models and not None for MoE models."""
-        if not is_torch_min_version("2.4.0"):
-            pytest.skip("Megatron FSDP requires torch >= 2.4.0")
-
-        fsdp_config = DistributedDataParallelConfig(
-            data_parallel_sharding_strategy="optim_grads_params",
-            overlap_grad_reduce=True,
-            overlap_param_gather=True,
-            bucket_size=10000,
-            use_megatron_fsdp=True,
-        )
-        input_dim, output_dim = 13, 17
-
-        # Dense model: expt_device_mesh should not be built without MoE config
-        dense_config = TransformerConfig(
-            num_attention_heads=1, num_layers=1, context_parallel_size=1
-        )
-        dense_model = TestModel(input_dim=input_dim, output_dim=output_dim).cuda()
-        fsdp_dense = FullyShardedDataParallel(
-            config=dense_config,
-            ddp_config=fsdp_config,
-            module=dense_model,
-            fsdp_unit_modules=[torch.nn.Linear],
-        )
-        assert (
-            fsdp_dense.megatron_fsdp_dist_index.expt_device_mesh is None
-        ), "Dense model: expt_device_mesh should be None"
-        fsdp_dense.stop_communication()
-
-        # MoE model: expt_device_mesh should be built when num_moe_experts is set
-        moe_config = TransformerConfig(
-            num_attention_heads=1, num_layers=1, context_parallel_size=1, num_moe_experts=4
-        )
-        moe_model = TestModel(input_dim=input_dim, output_dim=output_dim).cuda()
-        fsdp_moe = FullyShardedDataParallel(
-            config=moe_config,
-            ddp_config=fsdp_config,
-            module=moe_model,
-            fsdp_unit_modules=[torch.nn.Linear],
-        )
-        assert (
-            fsdp_moe.megatron_fsdp_dist_index.expt_device_mesh is not None
-        ), "MoE model: expt_device_mesh should not be None"
-        fsdp_moe.stop_communication()
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
 
     # Testing fsdp_double_buffer with and without nccl_ub
     @pytest.mark.parametrize(
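Note: the gc.collect() / torch.cuda.empty_cache() / torch.cuda.synchronize() triple added here also appears at the end of two later tests in this file. If the pattern keeps spreading, an autouse fixture could centralize it; a possible sketch, not part of this PR:

import gc

import pytest
import torch

@pytest.fixture(autouse=True)
def cuda_cleanup():
    yield  # run the test body first
    gc.collect()  # drop Python references so CUDA tensors become freeable
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached blocks to the allocator pool
        torch.cuda.synchronize()  # let pending kernels finish before the next test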
@@ -534,6 +493,10 @@ def train_step(model, optimizer, inputs):
                 msg=f"Parameters for {name1} don't match",
             )
 
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
     @classmethod
     def hsdp_one_step_test(cls, num_fsdp_group):
         if not is_torch_min_version("2.4.0"):
@@ -836,6 +799,10 @@ def _training_loop(seed=42, **kwargs):
             dict(data_parallel_sharding_strategy="optim", fsdp_double_buffer=False),
             id="optim_double_buffer",
         ),
+        pytest.param(
+            dict(data_parallel_sharding_strategy="no_shard", fsdp_double_buffer=False),
+            id="no_shard_no_double_buffer",
+        ),
     ],
 )
 def test_compatible_with_nd_parallel(self, ref_cache, nd_topology, spec_configs):
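Note: the new pytest.param extends the ND-parallel test matrix with a no_shard case. For reference, a minimal self-contained version of this parametrization pattern (hypothetical test, not from this file):

import pytest

@pytest.mark.parametrize(
    "spec_configs",
    [
        pytest.param(dict(data_parallel_sharding_strategy="optim"), id="optim"),
        pytest.param(dict(data_parallel_sharding_strategy="no_shard"), id="no_shard"),
    ],
)
def test_sharding_strategy_is_known(spec_configs):
    assert spec_configs["data_parallel_sharding_strategy"] in {"optim", "no_shard"}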
@@ -852,9 +819,12 @@ def test_compatible_with_nd_parallel(self, ref_cache, nd_topology, spec_configs)
             use_distributed_optimizer=True, **distopt_spec_configs
         )
 
+        # no_shard is incompatible with meta device initialization. See fully_shard.py:326.
+        init_model_with_meta_device = spec_configs['data_parallel_sharding_strategy'] != "no_shard"
+
         outputs = TestMegatronFSDPE2E._training_loop(
             use_megatron_fsdp=True,
-            init_model_with_meta_device=True,
+            init_model_with_meta_device=init_model_with_meta_device,
             ckpt_format="fsdp_dtensor",
             gradient_accumulation_fusion=False,
             **spec_configs,
@@ -877,6 +847,10 @@ def test_compatible_with_nd_parallel(self, ref_cache, nd_topology, spec_configs)
                 ),
             )
 
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
 
 def compare_losses(loss_a: float, loss_b: float, reference: str = "b"):
     """