21 changes: 20 additions & 1 deletion megatron/core/optimizer/distrib_optimizer.py
@@ -415,7 +415,9 @@ def _build_model_and_main_param_groups(

# fp32 params.
elif model_param.type() == 'torch.cuda.FloatTensor':
shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
shard_model_param = model_param.detach().view(-1)[
param_range.start : param_range.end
]
model_fp32_params_this_group.append(model_param)
shard_fp32_params_this_group.append(shard_model_param)
tensor_parallel.copy_tensor_model_parallel_attributes(
@@ -604,6 +606,23 @@ def __init__(
self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
self.optimizer.load_state_dict(self.optimizer.state_dict())

# Rebuild model_param_group_index_map to reflect parameter reordering.
# The _build_model_and_main_param_groups method reorders parameters by dtype
# (FP32 first, then FP16/BF16), so we need to update the mapping to match
# the new positions in optimizer.param_groups.
for group_index, group_range in enumerate(self.opt_group_ranges):
param_order = 0
# First, add FP32 params (in the same order as they appear in group_range["params"])
for model_param in group_range["params"]:
if model_param.type() == 'torch.cuda.FloatTensor':
self.model_param_group_index_map[model_param] = (group_index, param_order)
param_order += 1
# Then, add FP16/BF16 params (in the same order as they appear in group_range["params"])
for model_param in group_range["params"]:
if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
self.model_param_group_index_map[model_param] = (group_index, param_order)
param_order += 1

def _get_model_param_range_map(self, param: torch.nn.Parameter):
"""
Given a model param, get the index sub-range of the param that this
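The first hunk above builds the FP32 shard from a detached view of the model parameter. As a hedged aside (the PR does not state the motivation), one reason a detached view is commonly used: slicing a leaf `Parameter` directly yields a non-leaf tensor that autograd tracks and that `torch.optim` refuses to own, while detaching first gives a plain tensor that still aliases the parameter's storage. A minimal sketch, independent of Megatron internals:

```python
# Toy illustration (hypothetical example, not Megatron code): detached vs.
# tracked shard views of a parameter.
import torch

p = torch.nn.Parameter(torch.randn(8))

tracked_shard = p.view(-1)[0:4]            # non-leaf, carries a grad_fn
assert not tracked_shard.is_leaf

detached_shard = p.detach().view(-1)[0:4]  # leaf, no autograd history
assert detached_shard.is_leaf
assert detached_shard.data_ptr() == p.data_ptr()  # same underlying storage

# torch.optim accepts the detached shard but rejects the tracked one.
torch.optim.SGD([detached_shard], lr=0.1)
try:
    torch.optim.SGD([tracked_shard], lr=0.1)
except ValueError as err:
    print(f"non-leaf shard rejected: {err}")
```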
99 changes: 99 additions & 0 deletions tests/unit_tests/test_optimizer.py
@@ -750,3 +750,102 @@ def test_get_megatron_optimizer_custom_process_groups_validation():
use_gloo_process_groups=True, # Should be False when using custom groups
pg_collection=pg_collection_complete,
)


@pytest.mark.parametrize("use_distributed_optimizer", [True])
def test_mixed_precision_param_index_map(use_distributed_optimizer: bool):
"""
Test that model_param_group_index_map stays synchronized after parameter reordering.

This test addresses issue #2777 where the index map becomes stale after
_build_model_and_main_param_groups reorders parameters by dtype (FP32 first,
then FP16/BF16). The test creates a model with mixed precision parameters
and verifies that checkpoint operations work correctly.
"""
world = int(os.getenv('WORLD_SIZE', '1'))
rank = int(os.getenv('RANK', '0'))

# Setup distributed environment
_init_distributed(world, rank)
Utils.initialize_model_parallel()

# Create a model with mixed precision parameters
# We'll manually set some parameters to FP32 and others to BF16
class MixedPrecisionModel(nn.Module):
def __init__(self):
super().__init__()
# First layer in BF16
self.fc1 = nn.Linear(100, 50, bias=False, dtype=torch.bfloat16, device='cuda')
# Second layer in FP32 (simulating manual precision promotion)
self.fc2 = nn.Linear(50, 30, bias=False, dtype=torch.float32, device='cuda')
# Third layer in BF16
self.fc3 = nn.Linear(30, 10, bias=False, dtype=torch.bfloat16, device='cuda')

def forward(self, x):
x = self.fc1(x)
x = self.fc2(x.float())
x = self.fc3(x.bfloat16())
return x

model = MixedPrecisionModel()
model.requires_grad_(True)

# Wrap with DDP
ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=use_distributed_optimizer)
model = DistributedDataParallel(
TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model
)

# Create optimizer with distributed optimizer enabled
optimizer_config = OptimizerConfig(
optimizer='adam', bf16=True, use_distributed_optimizer=use_distributed_optimizer
)
optim = get_megatron_optimizer(optimizer_config, [model])

# Access the underlying distributed optimizer
if use_distributed_optimizer:
dist_optim = optim.optimizer

# Verify that model_param_group_index_map is correctly synchronized
# After the fix, the map should reflect the reordered parameters
for model_param in dist_optim.model_param_group_index_map.keys():
group_index, group_order = dist_optim.model_param_group_index_map[model_param]

# Verify the index points to a valid parameter
assert group_index < len(
dist_optim.optimizer.param_groups
), f"group_index {group_index} out of range"
assert group_order < len(
dist_optim.optimizer.param_groups[group_index]["params"]
), f"group_order {group_order} out of range for group {group_index}"

# Get the corresponding optimizer parameter
opt_param = dist_optim.optimizer.param_groups[group_index]["params"][group_order]

# Verify the sizes match (this would fail before the fix)
model_param_range = dist_optim._get_model_param_range_map(model_param)
param_range = model_param_range["param"]
assert param_range.size == opt_param.numel(), (
f"Size mismatch: model param range size {param_range.size} "
f"!= optimizer param size {opt_param.numel()}"
)

# Run a forward/backward pass to populate optimizer state
input_data = torch.randn(8, 100, dtype=torch.bfloat16, device='cuda')
output = model(input_data)
loss = output.sum()
loss.backward()
optim.step()

# Test get_parameter_state_dp_zero (the function that was failing in issue #2777)
# This should work without size mismatch errors
try:
state_dict = dist_optim.get_parameter_state_dp_zero()
# Verify state_dict was created successfully
if rank == 0 or state_dict is not None:
assert state_dict is not None, "Failed to get parameter state"
assert 'buckets_coalesced' in state_dict, "Missing expected keys in state dict"
except RuntimeError as e:
pytest.fail(f"get_parameter_state_dp_zero failed with error: {e}")

_deinit_distributed()
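
For reference, the mapping rebuild added to distrib_optimizer.py above can be illustrated in isolation. A toy sketch with hypothetical names (`rebuild_index_map` and `group_params` are illustrative, not Megatron APIs), showing the FP32-first / FP16-BF16-second ordering that the new code mirrors:

```python
# Toy sketch of the dtype-first reordering and index-map rebuild.
import torch

def rebuild_index_map(group_params, group_index=0):
    """Return {param: (group_index, order)} matching FP32-first ordering."""
    index_map, order = {}, 0
    for p in group_params:                     # FP32 params first
        if p.dtype == torch.float32:
            index_map[p] = (group_index, order)
            order += 1
    for p in group_params:                     # then FP16/BF16 params
        if p.dtype in (torch.float16, torch.bfloat16):
            index_map[p] = (group_index, order)
            order += 1
    return index_map

params = [
    torch.nn.Parameter(torch.zeros(2, dtype=torch.bfloat16)),  # order 1
    torch.nn.Parameter(torch.zeros(2, dtype=torch.float32)),   # order 0
    torch.nn.Parameter(torch.zeros(2, dtype=torch.bfloat16)),  # order 2
]
print([rebuild_index_map(params)[p] for p in params])  # [(0, 1), (0, 0), (0, 2)]
```

A map keyed on the original parameter order goes stale as soon as the groups are rebuilt in this dtype-first order, which is the failure mode described in issue #2777 and exercised by the test above.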