6 changes: 3 additions & 3 deletions examples/multimodal/layer_specs.py
@@ -1,6 +1,7 @@
# Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved.
import torch

from megatron.core.extensions.transformer_engine import HAVE_TE
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.hybrid.hybrid_block import HybridStack, HybridStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
@@ -15,7 +16,6 @@
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.typed_torch import not_none
from megatron.core.extensions.transformer_engine import HAVE_TE

if HAVE_TE:
from megatron.core.extensions.transformer_engine import (
@@ -112,7 +112,7 @@ def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec:
submodules=SelfAttentionSubmodules(
linear_qkv=not_none(TELayerNormColumnParallelLinear),
core_attention=not_none(TEDotProductAttention),
linear_proj=TERowParallelLinear,
linear_proj=not_none(TERowParallelLinear),
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
@@ -158,7 +158,7 @@ def get_hybrid_layer_spec_te(padding=False) -> ModuleSpec:
submodules=SelfAttentionSubmodules(
linear_qkv=not_none(TELayerNormColumnParallelLinear),
core_attention=not_none(TEDotProductAttention),
linear_proj=TERowParallelLinear,
linear_proj=not_none(TERowParallelLinear),
),
),
self_attn_bda=get_bias_dropout_add,
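The recurring change in this file (and in the spec files below) is wrapping the Transformer Engine classes in not_none(...). Those classes are typed as optional because they are only importable when TE is installed, so the wrapper narrows the type and fails fast when TE is missing. A rough sketch of what not_none from megatron.core.typed_torch is assumed to do (an illustration, not the repository's implementation):

from typing import Optional, TypeVar

T = TypeVar("T")

def not_none(value: Optional[T]) -> T:
    # Assumed behavior: raise if the optional class is unavailable, otherwise
    # return it with a non-Optional static type.
    if value is None:
        raise AssertionError("expected a non-None value; is Transformer Engine installed?")
    return value

With that shape, linear_proj=not_none(TERowParallelLinear) documents the runtime requirement and satisfies the stricter LinearProjBuilder annotation introduced in attention.py further down.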
10 changes: 3 additions & 7 deletions examples/multimodal/radio/radio_g.py
@@ -1,22 +1,18 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from functools import partial

import torch

from examples.multimodal.layer_scaling import (
LayerScalingTransformerLayer,
get_bias_dropout_add_layer_scaling,
)
from megatron.core.extensions.transformer_engine import HAVE_TE
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.transformer.transformer_layer import TransformerLayerSubmodules
from megatron.core.typed_torch import not_none
from megatron.core.extensions.transformer_engine import HAVE_TE

if HAVE_TE:
from megatron.core.extensions.transformer_engine import (
@@ -125,7 +121,7 @@ def get_radio_g_layer_spec_te() -> ModuleSpec:
submodules=SelfAttentionSubmodules(
linear_qkv=not_none(TELayerNormColumnParallelLinear),
core_attention=not_none(TEDotProductAttention),
linear_proj=TERowParallelLinear,
linear_proj=not_none(TERowParallelLinear),
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
3 changes: 2 additions & 1 deletion megatron/core/extensions/transformer_engine.py
@@ -1,4 +1,5 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
from __future__ import annotations

import dataclasses
import enum
@@ -865,7 +866,7 @@ def will_execute_quantized(self, is_context_quantized: bool) -> bool:
self.te_quant_params, self.training, is_context_quantized
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Forward."""
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
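The from __future__ import annotations line added at the top of this file goes together with the new forward return annotation: under PEP 563 annotations are stored as strings and never evaluated at runtime, so torch.Tensor | None and the builtin-generic tuple[...] form parse even on Python versions older than 3.10. A small standalone illustration (hypothetical function, not repository code):

from __future__ import annotations  # PEP 563: string-valued, lazily evaluated annotations

import torch

def project(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
    # Without the __future__ import, evaluating torch.Tensor | None in an
    # annotation raises TypeError at import time on Python < 3.10.
    return x, None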
Changes to an additional file (path not captured in this view):
@@ -44,7 +44,7 @@ def column_parallel_linear(self) -> type:
"""Which column parallel linear module TE backend uses"""
return TEColumnParallelLinear

def row_parallel_linear(self) -> type:
def row_parallel_linear(self) -> type[TERowParallelLinear]:
"""Which row parallel linear module TE backend uses"""
return TERowParallelLinear

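The return-annotation change in this backend provider (and the matching ones in backends.py below) narrows the bare type to type[TERowParallelLinear]. With the bare form a type checker only knows that some class is returned; with the parameterized form it can check constructor arguments and attribute access at the call site. A self-contained sketch using hypothetical classes:

from __future__ import annotations

class RowLinear:
    def __init__(self, in_features: int, out_features: int) -> None:
        self.in_features = in_features
        self.out_features = out_features

def provider_loose() -> type:
    return RowLinear

def provider_narrow() -> type[RowLinear]:
    return RowLinear

# A checker can validate this call through the narrow provider because it knows
# the returned class is RowLinear; through the loose one it cannot.
layer = provider_narrow()(in_features=1024, out_features=1024)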
6 changes: 3 additions & 3 deletions megatron/core/models/T5/t5_spec.py
@@ -63,7 +63,7 @@ def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
submodules=SelfAttentionSubmodules(
linear_qkv=not_none(TELayerNormColumnParallelLinear),
core_attention=not_none(TEDotProductAttention),
linear_proj=TERowParallelLinear,
linear_proj=not_none(TERowParallelLinear),
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
@@ -93,7 +93,7 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
submodules=SelfAttentionSubmodules(
linear_qkv=not_none(TELayerNormColumnParallelLinear),
core_attention=not_none(TEDotProductAttention),
linear_proj=TERowParallelLinear,
linear_proj=not_none(TERowParallelLinear),
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
@@ -107,7 +107,7 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
linear_q=not_none(TEColumnParallelLinear),
linear_kv=not_none(TEColumnParallelLinear),
core_attention=not_none(TEDotProductAttention),
linear_proj=TERowParallelLinear,
linear_proj=not_none(TERowParallelLinear),
),
),
cross_attn_bda=get_bias_dropout_add,
6 changes: 3 additions & 3 deletions megatron/core/models/backends.py
@@ -103,7 +103,7 @@ def column_parallel_linear(self) -> type:
"""Which column parallel linear module the backend uses"""
return ColumnParallelLinear

def row_parallel_linear(self) -> type:
def row_parallel_linear(self) -> type[RowParallelLinear]:
"""Which row parallel linear module the backend uses"""
return RowParallelLinear

@@ -157,8 +157,8 @@ def column_parallel_linear(self) -> type:
"""Which column parallel linear module TE backend uses"""
return InferenceColumnParallelLinear

def row_parallel_linear(self) -> type:
"""Which row parallel linear module TE backend uses"""
def row_parallel_linear(self) -> type[InferenceRowParallelLinear]:
"""Which row parallel linear module Inference backend uses"""
return InferenceRowParallelLinear

def fuse_layernorm_and_linear(self) -> bool:
Changes to an additional file (path not captured in this view):
@@ -118,7 +118,7 @@ def _get_heterogenous_attention_spec(
not_none(TELayerNormColumnParallelLinear) if use_te else ColumnParallelLinear
),
core_attention=not_none(TEDotProductAttention) if use_te else DotProductAttention,
linear_proj=TERowParallelLinear if use_te else RowParallelLinear,
linear_proj=not_none(TERowParallelLinear) if use_te else RowParallelLinear,
q_layernorm=ln,
k_layernorm=ln,
),
79 changes: 55 additions & 24 deletions megatron/core/transformer/attention.py
@@ -123,8 +123,8 @@
HAVE_FUSED_QKV_ROPE = False


class LinearQkv(Protocol):
"""Protocol for linear_qkv modules."""
class LinearQkvInterface(Protocol):
"""Interface for linear_qkv modules."""

def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
"""Applies linear_qkv."""
@@ -152,13 +152,13 @@ def __call__(
is_expert: bool,
tp_comm_buffer_name: str,
tp_group: torch.distributed.ProcessGroup | None = None,
) -> LinearQkv: ...
) -> LinearQkvInterface: ...


class LinearLayer(Protocol):
"""Protocol for linear_q and linear_kv modules."""
class LinearInterface(Protocol):
"""Interface for linear_q and linear_kv modules."""

def forward(self, input: Tensor, /) -> Tuple[Tensor, object]:
def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
"""Applies linear_q/linear_kv."""
...

@@ -178,23 +178,23 @@ def __call__(
bias: bool,
skip_bias_add: bool,
is_expert: bool,
) -> LinearLayer: ...
) -> LinearInterface: ...


class CoreAttention(Protocol):
"""Protocol for core_attention modules."""
class CoreAttentionInterface(Protocol):
"""Interface for core_attention modules."""

def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Optional[Tensor],
attention_mask: Tensor | None,
/,
*,
attn_mask_type: AttnMaskType,
attention_bias: Optional[Tensor],
packed_seq_params: Optional[PackedSeqParams],
attention_bias: Tensor | None,
packed_seq_params: PackedSeqParams | None,
) -> Tensor:
"""Applies dot product attention."""
...
@@ -210,10 +210,42 @@ def __call__(
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
cp_comm_type: Optional[str],
softmax_scale: Optional[float],
pg_collection: Optional[ProcessGroupCollection],
) -> CoreAttention: ...
cp_comm_type: str | None,
softmax_scale: float | None,
pg_collection: ProcessGroupCollection | None,
) -> CoreAttentionInterface: ...


class LinearProjInterface(Protocol):
"""Interface for linear_proj modules."""

def forward(self, hidden_states: Tensor, /) -> tuple[Tensor, Tensor | None]:
"""Applies the linear projection to the output of the core attention."""
...

def backward_dw(self) -> None:
"""Computes weight gradients of output projection layer."""
...


class LinearProjBuilder(Protocol):
"""Protocol for building linear_proj layers."""

def __call__(
self,
query_projection_size: int,
hidden_size: int,
/,
*,
config: TransformerConfig,
init_method: Callable[[torch.Tensor], None],
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str,
tp_group: torch.distributed.ProcessGroup | None,
) -> LinearProjInterface: ...


@dataclass
@@ -224,7 +256,7 @@ class SelfAttentionSubmodules:

linear_qkv: LinearQkvBuilder
core_attention: CoreAttentionBuilder
linear_proj: Union[ModuleSpec, type] = None
linear_proj: LinearProjBuilder
q_layernorm: LayerNormBuilder | None = None
k_layernorm: LayerNormBuilder | None = None

@@ -238,7 +270,7 @@ class CrossAttentionSubmodules:
linear_q: LinearLayerBuilder
linear_kv: LinearLayerBuilder
core_attention: CoreAttentionBuilder
linear_proj: Union[ModuleSpec, type] = None
linear_proj: LinearProjBuilder


class Attention(MegatronModule, ABC):
@@ -354,12 +386,11 @@ def __init__(
)

# Output.
self.linear_proj = build_module(
submodules.linear_proj,
self.linear_proj = submodules.linear_proj(
self.query_projection_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
init_method=not_none(self.config.output_layer_init_method),
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
@@ -983,7 +1014,7 @@ def forward(
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
) -> tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor | None]:
"""
Perform a forward pass through the attention module.

@@ -1133,7 +1164,7 @@ def forward(
)
out = output.transpose(0, 1).contiguous()
context_layer = out.view(out.size(0), out.size(1), -1)
output, bias = self.linear_proj(context_layer)
output, bias = apply_module(self.linear_proj)(context_layer)
return output, bias

if (
@@ -1301,7 +1332,7 @@ def forward(
# =================
nvtx_range_push(suffix="linear_proj")
with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out:
output, bias = self.linear_proj(core_attn_out)
output, bias = apply_module(self.linear_proj)(core_attn_out)
if self.offload_attn_proj:
output = off_interface.group_commit(
output, name="attn_proj", forced_released_tensors=[core_attn_out]
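The broader pattern in attention.py: SelfAttentionSubmodules.linear_proj is now typed as a LinearProjBuilder Protocol instead of Union[ModuleSpec, type] = None, and the projection is constructed by calling submodules.linear_proj(...) directly rather than going through build_module. Because the Protocol is structural, any callable with a compatible signature that returns an object exposing forward and backward_dw satisfies it, whether that is TERowParallelLinear itself or a small stand-in. A hedged illustration with a hypothetical test double (not repository code):

from __future__ import annotations

import torch
from torch import Tensor

class IdentityProj(torch.nn.Module):
    """Minimal stand-in that structurally satisfies LinearProjInterface."""

    def __init__(self, query_projection_size: int, hidden_size: int, **kwargs) -> None:
        super().__init__()

    def forward(self, hidden_states: Tensor) -> tuple[Tensor, Tensor | None]:
        # Mirror the (output, bias) contract of the real output projection.
        return hidden_states, None

    def backward_dw(self) -> None:
        # Nothing to update: this stub has no weights.
        pass

Wiring such a class in as linear_proj (for example in a unit test) would satisfy the new builder Protocol at runtime, which is also why the spec files above switch to not_none(TERowParallelLinear): the field no longer accepts an Optional class.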