222 changes: 212 additions & 10 deletions megatron/core/extensions/transformer_engine.py
@@ -430,19 +430,204 @@ def __new__(cls, config: TransformerConfig):
TEActivationOp = None


if HAVE_TE and is_te_min_version("1.13.0"):

class TEFusedResidualRMSNorm(te.pytorch.RMSNorm):
"""
RMSNorm with fused residual output for Megatron Core.

Inherits from te.pytorch.RMSNorm to maintain all parameter management,
checkpoint compatibility, and Megatron-specific features. Creates a fused
implementation using TE's ops API that shares the base class parameters.

The fused implementation uses:
- MakeExtraOutput: Forks the residual connection
- RMSNorm: Normalizes the main path

Forward pass returns: (normalized_output, residual)
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Fused implementation (stored in tuple to avoid submodule registration)
self._fused_impl: Optional[Tuple[te.pytorch.ops.Sequential]] = None

def _make_fused_impl(self) -> te.pytorch.ops.Sequential:
"""
Construct fused ops pipeline that shares parameters with base RMSNorm.

Creates MakeExtraOutput + RMSNorm ops, where the RMSNorm op shares
the weight parameter with self.weight from the base class.
"""

fused_impl = te.pytorch.ops.Sequential()

# Op 1: MakeExtraOutput - forks the residual
fused_impl.append(te.pytorch.ops.MakeExtraOutput())

# Op 2: RMSNorm - shares weight parameter with self
kwargs = {
"eps": self.eps,
"device": "meta", # Already initialized
"dtype": self.weight.dtype,
"zero_centered_gamma": self.zero_centered_gamma,
}

# Add sm_margin if available (TE 2.5+)
if hasattr(self, '_sm_margins'):
kwargs["sm_margin"] = self._sm_margins

rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs)

rmsnorm_op.weight = self.weight

fused_impl.append(rmsnorm_op)

self._register_hooks_on_fused_impl(fused_impl)

return fused_impl

def _register_hooks_on_fused_impl(self, fused_impl: torch.nn.Module) -> None:

forward_pre_hooks = []
forward_post_hooks = []
backward_pre_hooks = []
backward_post_hooks = []

for submodule in self.modules():
for hook in submodule._forward_pre_hooks.values():
forward_pre_hooks.append((submodule, hook))
for hook in submodule._forward_hooks.values():
forward_post_hooks.append((submodule, hook))
for hook in submodule._backward_pre_hooks.values():
backward_pre_hooks.append((submodule, hook))
for hook in submodule._backward_hooks.values():
backward_post_hooks.append((submodule, hook))

# Pre-forward hooks
# Note: DDP pre-forward hooks are safe since they do not
# interact with input tensor.
if forward_pre_hooks:
from megatron.core.distributed import distributed_data_parallel

if any(
inspect.getmodule(hook) != distributed_data_parallel
for _, hook in forward_pre_hooks
):
warnings.warn(
"TEFusedResidualRMSNorm module has a submodule with a pre-forward hook. "
"TEFusedResidualRMSNorm module does not expose intermediate tensors, "
"so the hook may have incorrect behavior if it attempts to "
"access the input tensor."
)

def forward_pre_hook(module, *_) -> None:
for submodule, hook in forward_pre_hooks:
# Assume that hook does not interact with input
ret = hook(submodule, None)
if ret is not None:
raise RuntimeError(
"TEFusedResidualRMSNorm module does not expose "
"intermediate tensors, but submodule has "
"pre-forward hook that modifies input tensor."
)

fused_impl.register_forward_pre_hook(forward_pre_hook)

# Post-forward hooks
if forward_post_hooks:
warnings.warn(
"TEFusedResidualRMSNorm module has a submodule with a post-forward hook. "
"TEFusedResidualRMSNorm module does not expose intermediate tensors, "
"so the hook may have incorrect behavior if it attempts to "
"access the input or output tensors."
)

def forward_post_hook(module, *_) -> None:
for submodule, hook in forward_post_hooks:
# Assume that hook does not interact with input or output
ret = hook(submodule, None, None)
if ret is not None:
raise RuntimeError(
"TEFusedResidualRMSNorm module does not expose "
"intermediate tensors, but submodule has "
"post-forward hook that modifies output tensor."
)

fused_impl.register_forward_hook(forward_post_hook)

# Backward hooks
if backward_pre_hooks:
raise RuntimeError(
"TEFusedResidualRMSNorm module does not support "
"submodules with pre-backward hooks"
)
if backward_post_hooks:
raise RuntimeError(
"TEFusedResidualRMSNorm module does not support "
"submodules with post-backward hooks"
)

def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass with fused residual output.

Args:
hidden_states: Input tensor [s, b, h]

Returns:
Tuple of (normalized_output, residual), both [s, b, h]

Note:
Sequential.forward() automatically returns (output, extra_outputs...)
when MakeExtraOutput is present, so we don't need manual unpacking.
"""

# Construct fused impl lazily on first forward
# (in case parameters are modified after __init__)
if self._fused_impl is None:
self._fused_impl = (self._make_fused_impl(),)

# Apply fused implementation
# Sequential returns (normalized_output, residual) automatically
return self._fused_impl[0](hidden_states)

else:
TEFusedResidualRMSNorm = None # type: ignore[assignment, misc]
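
For context, a minimal usage sketch of the fused module (not part of the diff): the hidden size, shapes, and the downstream `some_submodule` call are illustrative assumptions.

# Hypothetical usage sketch; sizes and the downstream submodule are assumptions, not from this PR.
import torch

norm = TEFusedResidualRMSNorm(normalized_shape=4096, eps=1e-5).cuda()  # needs TE >= 1.13.0
hidden_states = torch.randn(128, 2, 4096, device="cuda", dtype=torch.bfloat16)

# Forward returns (normalized_output, residual), both [s, b, h]; the residual is the
# unmodified input, forked by MakeExtraOutput inside the fused ops pipeline.
normed, residual = norm(hidden_states)
output = residual + some_submodule(normed)  # `some_submodule` is a placeholder (e.g. attention or MLP)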


class TENorm:
"""A conditional wrapper to initialize an instance of
Transformer-Engine's `LayerNorm` or `RMSNorm` based on input."""
Transformer-Engine's `LayerNorm` or `RMSNorm` based on input.

Residual Fusion Design:
----------------------
Residual fusion is a two-level opt-in mechanism:

1. Global capability: config.fused_residual_rmsnorm must be True (enables the feature)
2. Local intent: has_residual=True must be passed at build site (declares this specific
norm is followed by a residual connection)

Fusion only happens when BOTH conditions are met.

"""

# TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
def __new__(
cls,
config: TransformerConfig,
hidden_size: int,
eps: float = 1e-5,
has_residual: bool = False,
):
if not HAVE_TE:
raise ImportError(
"Transformer Engine is not installed. "
"Please install it with `pip install transformer-engine`."
)

if config.normalization == "LayerNorm":
if config.fused_residual_rmsnorm and has_residual:
raise ValueError("Fused residual RMSNorm is not supported for LayerNorm")
instance = te.pytorch.LayerNorm(
hidden_size=hidden_size,
eps=eps,
@@ -454,13 +639,30 @@ def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5)
assert hasattr(
te.pytorch, "RMSNorm"
), "Transformer-Engine >= v0.11 required to use this feature"

extra_te_kwargs = _get_extra_te_kwargs(config)

if config.fused_residual_rmsnorm and has_residual:
# Use fused residual variant
assert (
TEFusedResidualRMSNorm is not None
), "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0"
instance = TEFusedResidualRMSNorm(
normalized_shape=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**extra_te_kwargs,
)
else:
# Standard RMSNorm without fusion
instance = te.pytorch.RMSNorm(
normalized_shape=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**extra_te_kwargs,
)
else:
raise Exception("Only LayerNorm and RMSNorm are currently supported")

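To make the two-level opt-in concrete, here is a hedged sketch of how a build site might select the fused variant; the TransformerConfig field values are illustrative and other required fields are omitted for brevity.

# Sketch only: field values are illustrative, not taken from this PR.
config = TransformerConfig(
    num_layers=2,
    hidden_size=4096,
    num_attention_heads=32,
    normalization="RMSNorm",
    fused_residual_rmsnorm=True,   # 1. global capability
)

# Global capability + local intent -> TEFusedResidualRMSNorm; forward returns (out, residual)
fused_norm = TENorm(config, hidden_size=4096, eps=1e-5, has_residual=True)

# No local intent -> plain te.pytorch.RMSNorm; forward returns a single tensor
plain_norm = TENorm(config, hidden_size=4096, eps=1e-5)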
8 changes: 6 additions & 2 deletions megatron/core/transformer/attention.py
@@ -59,7 +59,9 @@
rearrange = None

try:
from flash_attn_3.flash_attn_interface import (
_flash_attn_forward,
)
from flash_attn_3.flash_attn_interface import (
flash_attn_with_kvcache as flash_attn3_with_kvcache,
)
@@ -70,7 +72,9 @@

if not HAVE_FA3:
try:
from flashattn_hopper.flash_attn_interface import (
_flash_attn_forward,
)
from flashattn_hopper.flash_attn_interface import (
flash_attn_with_kvcache as flash_attn3_with_kvcache,
)
9 changes: 9 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -382,6 +382,9 @@ class TransformerConfig(ModelParallelConfig):
fused_single_qkv_rope: bool = False
"""If set, avoid splitting QKV before ROPE forward and avoid concatenating ROPE dgrads."""

fused_residual_rmsnorm: bool = False
"""If True, uses fuses residual connection and RMSNorm when TE is used."""

####################
# activation recomputation
####################
@@ -1541,6 +1544,12 @@ def __post_init__(self):
"to True and use_te_activation_func to False."
)

if self.fused_residual_rmsnorm:
if self.normalization != "RMSNorm":
raise ValueError(
"fused_residual_rmsnorm is only supported when normalization is RMSNorm."
)

if self.use_te_activation_func:
if self.activation_func not in (F.gelu, F.silu, F.relu):
raise ValueError(
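As a sketch of what the new `__post_init__` check enforces (field values again illustrative, not from this PR): combining the flag with LayerNorm fails at config construction time.

# Illustrative only; other required fields omitted where defaults suffice.
TransformerConfig(
    num_layers=2, hidden_size=4096, num_attention_heads=32,
    normalization="LayerNorm", fused_residual_rmsnorm=True,
)  # raises ValueError: fused_residual_rmsnorm is only supported when normalization is RMSNorm.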