diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index 93b129a..87580fd 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -3,6 +3,7 @@
 import torch.nn.functional as F
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default
 from typing import List, Optional
 
 try:
@@ -312,3 +313,49 @@ def forward(
         nvtx_range_pop(suffix='out_proj')
 
         return out, out_bias
+
+    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None, tp_group=None):
+        """Provide a sharded state dictionary for distributed checkpointing."""
+        from megatron.core.transformer.utils import ensure_metadata_has_dp_cp_group
+
+        # Guard for the case where metadata is not provided.
+        metadata = ensure_metadata_has_dp_cp_group(metadata)
+
+        tp_group = tp_group if tp_group is not None else self.pg_collection.tp
+        sharded_state_dict = {}
+        # Direct parameters of this module.
+        self._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
+        sharded_state_dict = make_sharded_tensors_for_checkpoint(
+            sharded_state_dict,
+            prefix,
+            tensor_parallel_layers_axis_map={
+                'A_log': 0,
+                'dt_bias': 0,
+            },  # parameters sharded along dim 0 across TP ranks
+            sharded_offsets=sharded_offsets,
+            tp_group=tp_group,
+            dp_cp_group=metadata['dp_cp_group'],
+        )
+        # Submodules.
+        for name, module in self.named_children():
+            if name == 'conv1d':
+                # Conv1d weight (and bias, if present) is sharded along the channel dim (0) across TP.
+                module_sd = module.state_dict(prefix='', keep_vars=True)
+                tp_sharding_map = {'weight': 0}
+                if self.conv_bias:
+                    tp_sharding_map['bias'] = 0
+                module_sharded_sd = make_sharded_tensors_for_checkpoint(
+                    module_sd,
+                    f'{prefix}{name}.',
+                    tp_sharding_map,
+                    sharded_offsets,
+                    tp_group=tp_group,
+                    dp_cp_group=metadata['dp_cp_group'],
+                )
+            else:
+                module_sharded_sd = sharded_state_dict_default(
+                    module, f'{prefix}{name}.', sharded_offsets, metadata, tp_group=tp_group)
+
+            sharded_state_dict.update(module_sharded_sd)
+
+        return sharded_state_dict
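
For context, a minimal sketch of how the new method would be consumed, assuming this layer sits inside a standard `MegatronModule` tree. `model`, `ckpt_dir`, and `save_distributed_checkpoint` are hypothetical stand-ins, not part of this diff; `dist_checkpointing.save` is Megatron-Core's distributed-checkpoint entry point.

```python
# Hypothetical usage sketch, not part of this diff: `model` is assumed to be
# a module tree containing this GatedDeltaNet layer, and `ckpt_dir` is a
# placeholder path.
from megatron.core import dist_checkpointing


def save_distributed_checkpoint(model, ckpt_dir):
    # Each rank builds its view of the checkpoint. Tensors annotated above
    # with a TP axis ('A_log', 'dt_bias', and the conv1d weight/bias) are
    # saved as shards, so the checkpoint can later be reloaded under a
    # different tensor-parallel size.
    sharded_sd = model.sharded_state_dict(prefix='')
    dist_checkpointing.save(sharded_sd, ckpt_dir)
```

The dedicated `conv1d` branch is presumably needed because the convolution is a plain PyTorch module whose channel dimension follows the TP-sharded projections, so `sharded_state_dict_default` alone would not annotate its weight with a TP axis.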