26 changes: 26 additions & 0 deletions megatron/core/models/mamba/mamba_model.py
@@ -11,6 +11,9 @@
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
FineGrainedActivationOffloadingInterface as off_interface,
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.quantization.utils import get_quant_config_or_none
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
@@ -201,6 +204,8 @@ def __init__(
quant_config = get_quant_config_or_none(name, self.config.quant_recipe)
module.finish_init(quant_config)

self.disable_param_offloading = True

def set_input_tensor(self, input_tensor: Tensor) -> None:
"""Sets input tensor to the model.

@@ -217,6 +222,24 @@ def set_input_tensor(self, input_tensor: Tensor) -> None:
assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
self.decoder.set_input_tensor(input_tensor[0])

def preprocess_for_fine_grained_offloading(self):
"""Preprocess for fine-grained activation offloading."""
off_interface.init_chunk_handler(
vp_size=self.config.virtual_pipeline_model_parallel_size,
vp_stage=self.vp_stage,
min_offloaded_tensor_size=self.config.min_offloaded_tensor_size,
)
if self.disable_param_offloading:
for param in self.decoder.parameters():
off_interface.mark_not_offloadable(param)
if self.pre_process:
for param in self.embedding.parameters():
off_interface.mark_not_offloadable(param)
if self.post_process:
for param in self.output_layer.parameters():
off_interface.mark_not_offloadable(param)
self.disable_param_offloading = False
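Note: `mark_not_offloadable` comes from `fine_grained_activation_offload.py` and its body is not shown in this diff. Judging from the `tensor_need_offloading_checker` change further down, which skips any tensor whose `offloading_activation` attribute is falsy, a minimal sketch could look like the following; the attribute-based tagging is an assumption, not the confirmed implementation.

```python
# Hypothetical sketch of mark_not_offloadable (the real helper lives in
# fine_grained_activation_offload.py and is not part of this hunk).
def mark_not_offloadable(tensor):
    # tensor_need_offloading_checker() returns False for tensors whose
    # offloading_activation attribute is False, so tagging a parameter
    # this way keeps it resident on the GPU.
    tensor.offloading_activation = False
```

Because `disable_param_offloading` is flipped to `False` after the first call, the parameters are tagged exactly once even though `preprocess_for_fine_grained_offloading` runs on every forward pass.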

def forward(
self,
input_ids: Tensor,
@@ -241,6 +264,9 @@ def forward(
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

if self.config.fine_grained_activation_offloading:
self.preprocess_for_fine_grained_offloading()

inference_context = deprecate_inference_params(inference_context, inference_params)

in_inference_mode = inference_context is not None and not self.training
megatron/core/pipeline_parallel/fine_grained_activation_offload.py
@@ -1,4 +1,5 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.


from collections import deque
from contextlib import nullcontext
@@ -10,7 +11,10 @@
DEBUG = False
DEBUG_RANK = 0

from megatron.core.transformer.cuda_graphs import is_graph_capturing
from megatron.core.transformer.cuda_graphs import (
is_graph_capturing,
set_external_join_stream_for_graph_capture,
)


def debug_rank(message):
@@ -341,6 +345,11 @@ def __init__(self, name):
self.offload = True
self.total_offload_bytes = 0
self.total_tensor_count = 0
# Events must be created with `external=True` under graph capture, since an event may be
# recorded in one graph and waited on in another. Create the events lazily so that older
# PyTorch versions without `external=True` support keep working.
self._offload_event_cudagraph = None
self._reload_event_cudagraph = None
# Using memory pool is for the compatibility with cuda graph.
# Shapes of tensors for expert_fc1 and moe_act are not known in advance,
# so we do not use CPU pool for them.
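The memory-pool comment above is the crux of CUDA-graph compatibility: a replayed graph re-executes fixed device pointers, so offload targets must be stable, pre-allocated pinned host buffers rather than freshly allocated ones. A minimal sketch of the idea (illustrative, not this file's pool implementation):

```python
import torch

# Pinned (page-locked) host memory allows truly asynchronous D2H copies
# and keeps the destination pointer stable across CUDA-graph replays.
gpu_tensor = torch.randn(1024, 1024, device="cuda")
cpu_buf = torch.empty(gpu_tensor.shape, dtype=gpu_tensor.dtype, pin_memory=True)
cpu_buf.copy_(gpu_tensor, non_blocking=True)
```

Tensors whose shapes are not known ahead of time (`expert_fc1`, `moe_act`) cannot be served from such a pool, which is why they are excluded.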
@@ -359,19 +368,35 @@ def pop_tensor(self, tag):

def record_offload_event(self, stream):
"""Record the offload event."""
self._offload_event.record(stream)
if is_graph_capturing():
if self._offload_event_cudagraph is None:
self._offload_event_cudagraph = torch.cuda.Event(external=True)
self._offload_event_cudagraph.record(stream)
else:
self._offload_event.record(stream)

def wait_offload_event(self, stream):
"""Wait for the offload event."""
stream.wait_event(self._offload_event)
if is_graph_capturing():
stream.wait_event(self._offload_event_cudagraph)
else:
stream.wait_event(self._offload_event)

def record_reload_event(self, stream):
"""Record the reload event."""
self._reload_event.record(stream)
if is_graph_capturing():
if self._reload_event_cudagraph is None:
self._reload_event_cudagraph = torch.cuda.Event(external=True)
self._reload_event_cudagraph.record(stream)
else:
self._reload_event.record(stream)

def wait_reload_event(self, stream):
"""Wait for the reload event."""
stream.wait_event(self._reload_event)
if is_graph_capturing():
stream.wait_event(self._reload_event_cudagraph)
else:
stream.wait_event(self._reload_event)
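The four methods above share one dispatch: in eager mode they use the pre-existing `torch.cuda.Event`; under graph capture they lazily create and use a `torch.cuda.Event(external=True)`, since an ordinary event recorded inside one captured graph cannot be waited on from another. A consolidated sketch of that pattern (a possible refactor, not what this PR ships):

```python
import torch

from megatron.core.transformer.cuda_graphs import is_graph_capturing


class LazyDualEvent:
    """Sketch of the dispatch shared by the four record/wait methods above."""

    def __init__(self):
        self._event = torch.cuda.Event()
        # Lazy: `external=True` only exists in newer PyTorch versions.
        self._event_cudagraph = None

    def record(self, stream):
        if is_graph_capturing():
            if self._event_cudagraph is None:
                self._event_cudagraph = torch.cuda.Event(external=True)
            self._event_cudagraph.record(stream)
        else:
            self._event.record(stream)

    def wait(self, stream):
        # Mirrors the PR: assumes record() already ran under capture before
        # wait() is reached on the capture path.
        event = self._event_cudagraph if is_graph_capturing() else self._event
        stream.wait_event(event)
```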

def update_offload_info(self, tensor):
"""Update the offload information."""
@@ -867,6 +892,9 @@ def tensor_need_offloading_checker(self, tensor):
# Respect tensor's offload preference if specified
if hasattr(tensor, "offloading_activation") and not tensor.offloading_activation:
return False
if hasattr(tensor, "_TE_do_not_offload") and tensor._TE_do_not_offload:
return False

return True
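With this change there are two per-tensor opt-out tags: `offloading_activation = False` (used by Megatron itself) and `_TE_do_not_offload` (presumably set by Transformer Engine, going by the prefix). Any producer can exempt a tensor before it reaches the offload handler, for example:

```python
import torch

t = torch.randn(4096, 4096, device="cuda")
# Illustrative: either tag exempts the tensor from bulk offloading.
t.offloading_activation = False  # Megatron-side opt-out
t._TE_do_not_offload = True      # opt-out honored for externally tagged tensors
```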

def bulk_offload_group(self):
@@ -903,8 +931,7 @@ def bulk_reload_group(self):
torch.cuda.nvtx.range_push("activation reloading " + group_to_reload._name)
with torch.cuda.stream(self.h2d_stream):
# Wait for offload to complete before reloading
if not is_graph_capturing():
group_to_reload.wait_offload_event(self.h2d_stream)
group_to_reload.wait_offload_event(self.h2d_stream)
for tensor_tag, state in group_to_reload._tensors.items():
# Only reload if tensor was offloaded (stored as tuple)
if isinstance(state, tuple):
@@ -969,6 +996,11 @@ def on_group_commit_forward(self, forced_released_tensors):
if not self.do_offload:
return
debug_rank("--on_group_commit_forward")

if is_graph_capturing():
# Mark that d2h_stream is used so it gets joined before capture ends
set_external_join_stream_for_graph_capture(self.d2h_stream)

# Wait for compute to finish before starting offload
self.d2h_stream.wait_stream(torch.cuda.current_stream())
self.bulk_offload(forced_released_tensors)
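Two orderings are established here: the compute-to-offload dependency via `wait_stream`, and, during capture, registering `d2h_stream` with `set_external_join_stream_for_graph_capture` so the capture machinery joins it back before the graph ends (stream capture requires every side stream to be joined before capture can finish). The `wait_stream` half is the standard producer/consumer pattern:

```python
import torch

compute = torch.cuda.current_stream()
d2h = torch.cuda.Stream()

# Illustrative: the copy stream must observe all compute work that
# produced the activations before the D2H copies may start.
d2h.wait_stream(compute)
with torch.cuda.stream(d2h):
    pass  # enqueue non_blocking D2H copies here
```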
@@ -1005,7 +1037,7 @@ def on_group_commit_backward(self, name):
cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk()
assert cur_backward_chunk is self, f"Chunk mismatch {cur_backward_chunk} {self}"
# Wait for reload to complete before using tensors
if not is_graph_capturing() and len(self._reloading_group) > 0:
if len(self._reloading_group) > 0:
for reloading_group in self._reloading_group:
if reloading_group._name == name:
reloading_group.wait_reload_event(torch.cuda.current_stream())
@@ -1042,6 +1074,10 @@ def on_group_start_backward(self):
if not self.do_offload:
return
debug_rank(f"--on_group_start_backward {self}")

if is_graph_capturing():
set_external_join_stream_for_graph_capture(self.h2d_stream)

# Wait for compute to finish before starting reload
self.h2d_stream.wait_stream(torch.cuda.current_stream())
self.bulk_reload()
73 changes: 53 additions & 20 deletions megatron/core/ssm/mamba_mixer.py
@@ -25,6 +25,9 @@
tensor_merge,
)
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
FineGrainedActivationOffloadingInterface as off_interface,
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.transformer import TransformerConfig
@@ -397,6 +400,21 @@ def __init__(
)
self.tp_group = pg_collection.tp

self.offload_in_proj = (
self.config.fine_grained_activation_offloading
and "mamba_in_proj" in self.config.offload_modules
)

self.offload_out_proj = (
self.config.fine_grained_activation_offloading
and "mamba_out_proj" in self.config.offload_modules
)

self.offload_ssm = (
self.config.fine_grained_activation_offloading
and "mamba_ssm" in self.config.offload_modules
)
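Each flag gates one wrap site in `forward`/`_ssm_training` below. Assuming the config fields used above (`fine_grained_activation_offloading`, `offload_modules`), enabling all three Mamba offload points would look roughly like this; an illustrative fragment, not a complete config:

```python
# Hypothetical: only the two offloading-related fields are shown; all other
# TransformerConfig arguments are elided.
config_overrides = dict(
    fine_grained_activation_offloading=True,
    offload_modules=["mamba_in_proj", "mamba_ssm", "mamba_out_proj"],
)
```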

def forward(
self,
hidden_states,
@@ -429,7 +447,13 @@
out, out_bias = self._decode(hidden_states, conv_state, ssm_state)
return out, out_bias

zxBCdt, _ = self.in_proj(hidden_states)
with off_interface(self.offload_in_proj, hidden_states, "mamba_in_proj") as hidden_states:
zxBCdt, _ = self.in_proj(hidden_states)

if self.offload_in_proj:
zxBCdt = off_interface.group_commit(
zxBCdt, name="mamba_in_proj", forced_released_tensors=[]
)

zxBCdt = self.cp.pre_conv_ssm(zxBCdt, packed_seq_params)
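The wrap-then-commit pattern above (repeated at the out_proj and SSM sites below) is: enter `off_interface` to tag the block's input as offloadable, run the op, then call `group_commit` to close the offload group and hand its tensors to the D2H machinery. A pass-through sketch of the context-manager half, assuming the disabled path simply yields the tensor unchanged:

```python
from contextlib import contextmanager

@contextmanager
def maybe_offload(enabled, tensor, name):
    """Hypothetical stand-in for off_interface(...); illustrative only."""
    if not enabled:
        yield tensor  # disabled: no tagging, no bookkeeping
        return
    tensor.offloading_activation = True  # assumed tagging mechanism
    yield tensor
```

Note that `group_commit` is only called when the corresponding flag is set, so the disabled path adds no synchronization.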

@@ -444,7 +468,11 @@
assert ssm_state is None
y = self._ssm_training(zxBCdt, packed_seq_params)

out, out_bias = self.out_proj(y)
with off_interface(self.offload_out_proj, y, "mamba_out_proj") as y:
out, out_bias = self.out_proj(y)

if self.offload_out_proj:
out = off_interface.group_commit(out, name="mamba_out_proj", forced_released_tensors=[])

return out, out_bias

@@ -656,24 +684,29 @@ def _ssm_training(
assert sequence_packing_available, reason_for_no_sequence_packing
seq_idx = self._create_packed_seq_idx(packed_seq_params, zxBCdt.shape[1])

y = mamba_split_conv1d_scan_combined(
zxBCdt,
rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w"),
self.cp.get_conv1d_bias(),
self.cp.get_dt_bias().float(),
A,
D=(
rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim)
if self.D_has_hdim
else self.cp.get_D()
),
chunk_size=self.chunk_size,
activation=self.activation,
headdim=None if self.D_has_hdim else self.headdim,
ngroups=self.cp.ngroups_local_tpcp,
norm_before_gate=self.norm_before_gate,
seq_idx=seq_idx,
)
with off_interface(self.offload_ssm, zxBCdt, "mamba_ssm") as zxBCdt:

y = mamba_split_conv1d_scan_combined(
zxBCdt,
rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w"),
self.cp.get_conv1d_bias(),
self.cp.get_dt_bias().float(),
A,
D=(
rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim)
if self.D_has_hdim
else self.cp.get_D()
),
chunk_size=self.chunk_size,
activation=self.activation,
headdim=None if self.D_has_hdim else self.headdim,
ngroups=self.cp.ngroups_local_tpcp,
norm_before_gate=self.norm_before_gate,
seq_idx=seq_idx,
)

if self.offload_ssm:
y = off_interface.group_commit(y, name="mamba_ssm", forced_released_tensors=[])

y = rearrange(y, "b l d -> l b d").contiguous()
y = self.cp.post_conv_ssm(y, packed_seq_params)