47 changes: 47 additions & 0 deletions src/mcore_bridge/model/modules/gated_delta_net.py
@@ -3,6 +3,7 @@
import torch.nn.functional as F
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default
from typing import List, Optional

try:
@@ -312,3 +313,49 @@ def forward(
        nvtx_range_pop(suffix='out_proj')

        return out, out_bias

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None, tp_group=None):
        """Provide a sharded state dictionary for distributed checkpointing."""
        from megatron.core.transformer.utils import ensure_metadata_has_dp_cp_group

        # Guard for the case where metadata is not provided
        metadata = ensure_metadata_has_dp_cp_group(metadata)

        # Resolve the TP group once; it is reused for parameters and submodules.
        tp_group = tp_group if tp_group is not None else self.pg_collection.tp

        sharded_state_dict = {}
        # Parameters registered directly on this module
        self._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
        sharded_state_dict = make_sharded_tensors_for_checkpoint(
            sharded_state_dict,
            prefix,
            tensor_parallel_layers_axis_map={
                'A_log': 0,
                'dt_bias': 0,
            },  # parameters sharded along dim 0 across TP
            sharded_offsets=sharded_offsets,
            tp_group=tp_group,
            dp_cp_group=metadata['dp_cp_group'],
        )
        # Submodules
        for name, module in self.named_children():
            if name == 'conv1d':
                # Add TP sharding for Conv1d
                module_sd = module.state_dict(prefix='', keep_vars=True)
                tp_sharding_map = {'weight': 0}
                if self.conv_bias:
                    tp_sharding_map['bias'] = 0
                module_sharded_sd = make_sharded_tensors_for_checkpoint(
                    module_sd,
                    f'{prefix}{name}.',
                    tp_sharding_map,
                    sharded_offsets,
                    tp_group=tp_group,
                    dp_cp_group=metadata['dp_cp_group'],
                )
            else:
                module_sharded_sd = sharded_state_dict_default(
                    module, f'{prefix}{name}.', sharded_offsets, metadata, tp_group=tp_group
                )

            sharded_state_dict.update(module_sharded_sd)

        return sharded_state_dict
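
For context, a minimal usage sketch of how Megatron's distributed checkpointing would consume this method. Everything here is an illustrative assumption rather than part of this change: the constructed `layer` instance, the checkpoint path, and the final load step.

# Hedged usage sketch, not part of this PR. Assumes torch.distributed and
# Megatron's model-parallel groups are already initialized, and that `layer`
# is a constructed instance of the module patched above.
from megatron.core import dist_checkpointing

# Build ShardedTensor descriptors; in a full model the parent module would
# pass a prefix such as 'decoder.layers.0.mixer.' instead of ''.
sharded_sd = layer.sharded_state_dict(prefix='')

# Each tensor-parallel rank saves only its shard of A_log, dt_bias, and the
# conv1d weight/bias (all sharded along dim 0 per the axis maps above).
dist_checkpointing.save(sharded_sd, '/tmp/gdn_checkpoint')

# Loading resolves the same descriptors against the checkpoint and returns
# plain tensors, which can then populate the module.
loaded_sd = dist_checkpointing.load(layer.sharded_state_dict(prefix=''), '/tmp/gdn_checkpoint')
layer.load_state_dict(loaded_sd)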