diff --git a/examples/multimodal/layer_specs.py b/examples/multimodal/layer_specs.py
index acced15eeb6..8b7384f6259 100644
--- a/examples/multimodal/layer_specs.py
+++ b/examples/multimodal/layer_specs.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved.
 import torch
 
+from megatron.core.extensions.transformer_engine import HAVE_TE
 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
 from megatron.core.models.hybrid.hybrid_block import HybridStack, HybridStackSubmodules
 from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
@@ -15,7 +16,6 @@
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
 from megatron.core.typed_torch import not_none
-from megatron.core.extensions.transformer_engine import HAVE_TE
 
 if HAVE_TE:
     from megatron.core.extensions.transformer_engine import (
@@ -112,7 +112,7 @@ def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec:
             submodules=SelfAttentionSubmodules(
                 linear_qkv=not_none(TELayerNormColumnParallelLinear),
                 core_attention=not_none(TEDotProductAttention),
-                linear_proj=TERowParallelLinear,
+                linear_proj=not_none(TERowParallelLinear),
                 q_layernorm=IdentityOp,
                 k_layernorm=IdentityOp,
             ),
@@ -158,7 +158,7 @@ def get_hybrid_layer_spec_te(padding=False) -> ModuleSpec:
                 submodules=SelfAttentionSubmodules(
                     linear_qkv=not_none(TELayerNormColumnParallelLinear),
                     core_attention=not_none(TEDotProductAttention),
-                    linear_proj=TERowParallelLinear,
+                    linear_proj=not_none(TERowParallelLinear),
                 ),
             ),
             self_attn_bda=get_bias_dropout_add,
diff --git a/examples/multimodal/radio/radio_g.py b/examples/multimodal/radio/radio_g.py
index 9883d58db61..8dc18a25999 100644
--- a/examples/multimodal/radio/radio_g.py
+++ b/examples/multimodal/radio/radio_g.py
@@ -1,12 +1,9 @@
 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
-from functools import partial
-
-import torch
-
 from examples.multimodal.layer_scaling import (
     LayerScalingTransformerLayer,
     get_bias_dropout_add_layer_scaling,
 )
+from megatron.core.extensions.transformer_engine import HAVE_TE
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.dot_product_attention import DotProductAttention
@@ -14,9 +11,8 @@
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
 from megatron.core.transformer.spec_utils import ModuleSpec
-from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
+from megatron.core.transformer.transformer_layer import TransformerLayerSubmodules
 from megatron.core.typed_torch import not_none
-from megatron.core.extensions.transformer_engine import HAVE_TE
 
 if HAVE_TE:
     from megatron.core.extensions.transformer_engine import (
@@ -125,7 +121,7 @@ def get_radio_g_layer_spec_te() -> ModuleSpec:
             submodules=SelfAttentionSubmodules(
                 linear_qkv=not_none(TELayerNormColumnParallelLinear),
                 core_attention=not_none(TEDotProductAttention),
-                linear_proj=TERowParallelLinear,
+                linear_proj=not_none(TERowParallelLinear),
                 q_layernorm=IdentityOp,
                 k_layernorm=IdentityOp,
             ),
diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py
index 869e0099fce..7edf26dbd35 100644
--- a/megatron/core/extensions/transformer_engine.py
+++ b/megatron/core/extensions/transformer_engine.py
@@ -1,4 +1,5 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+from __future__ import annotations
 
 import dataclasses
 import enum
@@ -865,7 +866,7 @@ def will_execute_quantized(self, is_context_quantized: bool) -> bool:
             self.te_quant_params, self.training, is_context_quantized
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
         """Forward."""
         _is_first_microbatch = (
             None if self.disable_parameter_transpose_cache else self.is_first_microbatch
diff --git a/megatron/core/extensions/transformer_engine_spec_provider.py b/megatron/core/extensions/transformer_engine_spec_provider.py
index 04228e02e88..352f3b15a8a 100644
--- a/megatron/core/extensions/transformer_engine_spec_provider.py
+++ b/megatron/core/extensions/transformer_engine_spec_provider.py
@@ -44,7 +44,7 @@ def column_parallel_linear(self) -> type:
         """Which column parallel linear module TE backend uses"""
         return TEColumnParallelLinear
 
-    def row_parallel_linear(self) -> type:
+    def row_parallel_linear(self) -> type[TERowParallelLinear]:
         """Which row parallel linear module TE backend uses"""
         return TERowParallelLinear
diff --git a/megatron/core/models/T5/t5_spec.py b/megatron/core/models/T5/t5_spec.py
index 9f465df5c21..0e755f6559e 100644
--- a/megatron/core/models/T5/t5_spec.py
+++ b/megatron/core/models/T5/t5_spec.py
@@ -63,7 +63,7 @@ def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
             submodules=SelfAttentionSubmodules(
                 linear_qkv=not_none(TELayerNormColumnParallelLinear),
                 core_attention=not_none(TEDotProductAttention),
-                linear_proj=TERowParallelLinear,
+                linear_proj=not_none(TERowParallelLinear),
                 q_layernorm=IdentityOp,
                 k_layernorm=IdentityOp,
             ),
@@ -93,7 +93,7 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
             submodules=SelfAttentionSubmodules(
                 linear_qkv=not_none(TELayerNormColumnParallelLinear),
                 core_attention=not_none(TEDotProductAttention),
-                linear_proj=TERowParallelLinear,
+                linear_proj=not_none(TERowParallelLinear),
                 q_layernorm=IdentityOp,
                 k_layernorm=IdentityOp,
             ),
@@ -107,7 +107,7 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
                 linear_q=not_none(TEColumnParallelLinear),
                 linear_kv=not_none(TEColumnParallelLinear),
                 core_attention=not_none(TEDotProductAttention),
-                linear_proj=TERowParallelLinear,
+                linear_proj=not_none(TERowParallelLinear),
             ),
         ),
         cross_attn_bda=get_bias_dropout_add,
diff --git a/megatron/core/models/backends.py b/megatron/core/models/backends.py
index b019d527342..a270161ddd6 100644
--- a/megatron/core/models/backends.py
+++ b/megatron/core/models/backends.py
@@ -103,7 +103,7 @@ def column_parallel_linear(self) -> type:
         """Which column parallel linear module the backend uses"""
         return ColumnParallelLinear
 
-    def row_parallel_linear(self) -> type:
+    def row_parallel_linear(self) -> type[RowParallelLinear]:
         """Which row parallel linear module the backend uses"""
         return RowParallelLinear
@@ -157,8 +157,8 @@ def column_parallel_linear(self) -> type:
         """Which column parallel linear module TE backend uses"""
         return InferenceColumnParallelLinear
 
-    def row_parallel_linear(self) -> type:
-        """Which row parallel linear module TE backend uses"""
+    def row_parallel_linear(self) -> type[InferenceRowParallelLinear]:
+        """Which row parallel linear module Inference backend uses"""
         return InferenceRowParallelLinear
 
     def fuse_layernorm_and_linear(self) -> bool:
diff --git a/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py b/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py
index f4385429422..3bbf21afae4 100644
--- a/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py
+++ b/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py
@@ -118,7 +118,7 @@ def _get_heterogenous_attention_spec(
                 not_none(TELayerNormColumnParallelLinear) if use_te else ColumnParallelLinear
             ),
             core_attention=not_none(TEDotProductAttention) if use_te else DotProductAttention,
-            linear_proj=TERowParallelLinear if use_te else RowParallelLinear,
+            linear_proj=not_none(TERowParallelLinear) if use_te else RowParallelLinear,
             q_layernorm=ln,
             k_layernorm=ln,
         ),
diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py
index f89259be442..9d7f6ddd877 100644
--- a/megatron/core/transformer/attention.py
+++ b/megatron/core/transformer/attention.py
@@ -123,8 +123,8 @@
 HAVE_FUSED_QKV_ROPE = False
 
 
-class LinearQkv(Protocol):
-    """Protocol for linear_qkv modules."""
+class LinearQkvInterface(Protocol):
+    """Interface for linear_qkv modules."""
 
     def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
         """Applies linear_qkv."""
@@ -152,13 +152,13 @@ def __call__(
         is_expert: bool,
         tp_comm_buffer_name: str,
         tp_group: torch.distributed.ProcessGroup | None = None,
-    ) -> LinearQkv: ...
+    ) -> LinearQkvInterface: ...
 
 
-class LinearLayer(Protocol):
-    """Protocol for linear_q and linear_kv modules."""
+class LinearInterface(Protocol):
+    """Interface for linear_q and linear_kv modules."""
 
-    def forward(self, input: Tensor, /) -> Tuple[Tensor, object]:
+    def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
         """Applies linear_q/linear_kv."""
         ...
@@ -178,23 +178,23 @@ def __call__(
         bias: bool,
         skip_bias_add: bool,
         is_expert: bool,
-    ) -> LinearLayer: ...
+    ) -> LinearInterface: ...
 
 
-class CoreAttention(Protocol):
-    """Protocol for core_attention modules."""
+class CoreAttentionInterface(Protocol):
+    """Interface for core_attention modules."""
 
     def forward(
         self,
         query: Tensor,
         key: Tensor,
         value: Tensor,
-        attention_mask: Optional[Tensor],
+        attention_mask: Tensor | None,
         /,
         *,
         attn_mask_type: AttnMaskType,
-        attention_bias: Optional[Tensor],
-        packed_seq_params: Optional[PackedSeqParams],
+        attention_bias: Tensor | None,
+        packed_seq_params: PackedSeqParams | None,
     ) -> Tensor:
         """Applies dot product attention."""
         ...
@@ -210,10 +210,42 @@ def __call__(
         layer_number: int,
         attn_mask_type: AttnMaskType,
         attention_type: str,
-        cp_comm_type: Optional[str],
-        softmax_scale: Optional[float],
-        pg_collection: Optional[ProcessGroupCollection],
-    ) -> CoreAttention: ...
+        cp_comm_type: str | None,
+        softmax_scale: float | None,
+        pg_collection: ProcessGroupCollection | None,
+    ) -> CoreAttentionInterface: ...
+
+
+class LinearProjInterface(Protocol):
+    """Interface for linear_proj modules."""
+
+    def forward(self, hidden_states: Tensor, /) -> tuple[Tensor, Tensor | None]:
+        """Applies the linear projection to the output of the core attention."""
+        ...
+
+    def backward_dw(self) -> None:
+        """Computes weight gradients of output projection layer."""
+        ...
+
+
+class LinearProjBuilder(Protocol):
+    """Protocol for building linear_proj layers."""
+
+    def __call__(
+        self,
+        query_projection_size: int,
+        hidden_size: int,
+        /,
+        *,
+        config: TransformerConfig,
+        init_method: Callable[[torch.Tensor], None],
+        bias: bool,
+        input_is_parallel: bool,
+        skip_bias_add: bool,
+        is_expert: bool,
+        tp_comm_buffer_name: str,
+        tp_group: torch.distributed.ProcessGroup | None,
+    ) -> LinearProjInterface: ...
 
 
 @dataclass
@@ -224,7 +256,7 @@ class SelfAttentionSubmodules:
 
     linear_qkv: LinearQkvBuilder
     core_attention: CoreAttentionBuilder
-    linear_proj: Union[ModuleSpec, type] = None
+    linear_proj: LinearProjBuilder
     q_layernorm: LayerNormBuilder | None = None
     k_layernorm: LayerNormBuilder | None = None
@@ -238,7 +270,7 @@ class CrossAttentionSubmodules:
     linear_q: LinearLayerBuilder
     linear_kv: LinearLayerBuilder
     core_attention: CoreAttentionBuilder
-    linear_proj: Union[ModuleSpec, type] = None
+    linear_proj: LinearProjBuilder
 
 
 class Attention(MegatronModule, ABC):
@@ -354,12 +386,11 @@ def __init__(
         )
 
         # Output.
-        self.linear_proj = build_module(
-            submodules.linear_proj,
+        self.linear_proj = submodules.linear_proj(
             self.query_projection_size,
             self.config.hidden_size,
             config=self.config,
-            init_method=self.config.output_layer_init_method,
+            init_method=not_none(self.config.output_layer_init_method),
             bias=self.config.add_bias_linear,
             input_is_parallel=True,
             skip_bias_add=True,
@@ -983,7 +1014,7 @@ def forward(
         sequence_len_offset: Optional[int] = None,
         *,
         inference_params: Optional[BaseInferenceContext] = None,
-    ) -> tuple[Tensor, Tensor]:
+    ) -> tuple[Tensor, Tensor | None]:
        """
        Perform a forward pass through the attention module.
@@ -1133,7 +1164,7 @@ def forward(
             )
             out = output.transpose(0, 1).contiguous()
             context_layer = out.view(out.size(0), out.size(1), -1)
-            output, bias = self.linear_proj(context_layer)
+            output, bias = apply_module(self.linear_proj)(context_layer)
             return output, bias
 
         if (
@@ -1301,7 +1332,7 @@ def forward(
         # =================
         nvtx_range_push(suffix="linear_proj")
         with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out:
-            output, bias = self.linear_proj(core_attn_out)
+            output, bias = apply_module(self.linear_proj)(core_attn_out)
             if self.offload_attn_proj:
                 output = off_interface.group_commit(
                     output, name="attn_proj", forced_released_tensors=[core_attn_out]
diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py
index 601ae89fae1..4e862e868fb 100644
--- a/megatron/core/transformer/multi_latent_attention.py
+++ b/megatron/core/transformer/multi_latent_attention.py
@@ -1,9 +1,9 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-
+from __future__ import annotations
 import math
 from dataclasses import dataclass
-from typing import NoReturn, Optional, Union
+from typing import TYPE_CHECKING, NoReturn, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -35,12 +35,12 @@
     gather_from_tensor_model_parallel_region,
     scatter_to_sequence_parallel_region,
 )
-from megatron.core.transformer.attention import Attention
+from megatron.core.transformer.attention import Attention, LinearProjBuilder
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.spec_utils import ModuleSpec, build_module
 from megatron.core.transformer.torch_norm import LayerNormBuilder
 from megatron.core.transformer.transformer_config import MLATransformerConfig
-from megatron.core.typed_torch import apply_module
+from megatron.core.typed_torch import apply_module, not_none
 from megatron.core.utils import (
     deprecate_inference_params,
     get_pg_size,
@@ -77,6 +77,10 @@
     split_te_layernorm_column_parallel_linear,
 ) = (None, None, None, None, None, None)
 
+if TYPE_CHECKING:
+    from megatron.core.inference.contexts import BaseInferenceContext
+    from megatron.core.packed_seq_params import PackedSeqParams
+
 
 def _prepare_mla_core_attention_value(parallel_attention, query, value, packed_seq_params):
     """Prepare value tensor for MLA core attention THD execution."""
@@ -108,6 +112,8 @@ def _trim_mla_core_attention_output(core_attn_out, need_v_pad, orig_v_dim, padde
 class MLASelfAttentionSubmodules:
     """Submodules for the MLA self-attention layer."""
 
+    linear_proj: LinearProjBuilder
+
     # TODO(nschank): Move layernorms back to the bottom once all other layers have defaults removed.
     q_layernorm: LayerNormBuilder
     kv_layernorm: LayerNormBuilder
@@ -119,7 +125,6 @@ class MLASelfAttentionSubmodules:
     linear_kv_up_proj: Union[ModuleSpec, type] = None
     linear_qkv_down_proj: Union[ModuleSpec, type] = None
    core_attention: Union[ModuleSpec, type] = None
-    linear_proj: Union[ModuleSpec, type] = None
 
 
 class MultiLatentAttention(Attention):
@@ -140,7 +145,8 @@ def __init__(
         pg_collection: Optional[ProcessGroupCollection] = None,
         pp_layer_offset: Optional[int] = None,
     ) -> None:
-
+        # TODO(nschank): Restructure so that the Attention initializer knows which specific
+        # submodules it will construct, so that MLASelfAttentionSubmodules honors that interface.
         super().__init__(
             config=config,
             submodules=submodules,
@@ -210,12 +216,11 @@ def __init__(
         )
 
         # Output.
-        self.linear_proj = build_module(
-            submodules.linear_proj,
+        self.linear_proj = submodules.linear_proj(
             self.query_projection_size,
             self.config.hidden_size,
             config=self.config,
-            init_method=self.config.output_layer_init_method,
+            init_method=not_none(self.config.output_layer_init_method),
             bias=self.config.add_bias_linear,
             input_is_parallel=True,
             skip_bias_add=True,
@@ -289,20 +294,20 @@ def _run_core_attention(
 
     def forward(
         self,
-        hidden_states,
-        attention_mask,
-        key_value_states=None,
-        inference_context=None,
-        rotary_pos_emb=None,
-        rotary_pos_cos=None,
-        rotary_pos_sin=None,
-        rotary_pos_cos_sin=None,
-        attention_bias=None,
-        packed_seq_params=None,
-        position_ids=None,
-        sequence_len_offset=None,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        key_value_states: torch.Tensor | None = None,
+        inference_context: BaseInferenceContext | None = None,
+        rotary_pos_emb: torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None = None,
+        rotary_pos_cos: torch.Tensor | None = None,
+        rotary_pos_sin: torch.Tensor | None = None,
+        rotary_pos_cos_sin: torch.Tensor | None = None,
+        attention_bias: torch.Tensor | None = None,
+        packed_seq_params: PackedSeqParams | None = None,
+        position_ids: torch.Tensor | None = None,
+        sequence_len_offset: int | None = None,
         *,
-        inference_params=None,
+        inference_params: BaseInferenceContext | None = None,
     ):
         """Forward pass for multi-latent attention"""
         assert rotary_pos_emb is None, "Rotary position embeddings should not be passed into MLA."
@@ -451,7 +456,7 @@ def forward(
         # Output. [sq, b, h]
         # =================
         with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out:
-            output, bias = self.linear_proj(core_attn_out)
+            output, bias = apply_module(self.linear_proj)(core_attn_out)
             if self.offload_attn_proj:
                 output = off_interface.group_commit(
                     output, name="attn_proj", forced_released_tensors=[core_attn_out]
diff --git a/tests/unit_tests/transformer/test_multi_latent_attention.py b/tests/unit_tests/transformer/test_multi_latent_attention.py
index 1028bbd2371..b41a66f4ee7 100644
--- a/tests/unit_tests/transformer/test_multi_latent_attention.py
+++ b/tests/unit_tests/transformer/test_multi_latent_attention.py
@@ -28,6 +28,7 @@
     MultiLatentAttention,
 )
 from megatron.core.transformer.transformer_config import MLATransformerConfig
+from megatron.core.typed_torch import apply_module
 from megatron.core.utils import is_te_min_version, is_torch_min_version
 from megatron.training.arguments import parse_args
 from megatron.training.checkpointing import load_checkpoint, save_checkpoint
@@ -804,8 +805,12 @@ def test_gpu_forward_thd_precision(self):
         )
         assert torch.equal(_core_attn_out_sbhd, core_attn_out_thd)
 
-        output_sbhd, bias_sbhd = self.parallel_attention.linear_proj(core_attn_out_sbhd)
-        output_thd, bias_thd = self.parallel_attention.linear_proj(core_attn_out_thd)
+        output_sbhd, bias_sbhd = apply_module(self.parallel_attention.linear_proj)(
+            core_attn_out_sbhd
+        )
+        output_thd, bias_thd = apply_module(self.parallel_attention.linear_proj)(
+            core_attn_out_thd
+        )
 
         _output_sbhd = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape)
         assert torch.equal(_output_sbhd, output_thd)
@@ -963,8 +968,12 @@ def test_gpu_forward_thd_precision(self):
         )
         torch.testing.assert_close(_core_attn_out_sbhd, core_attn_out_thd, atol=atol, rtol=rtol)
 
-        output_sbhd, bias_sbhd = self.parallel_attention.linear_proj(core_attn_out_sbhd)
-        output_thd, bias_thd = self.parallel_attention.linear_proj(core_attn_out_thd)
+        output_sbhd, bias_sbhd = apply_module(self.parallel_attention.linear_proj)(
+            core_attn_out_sbhd
+        )
+        output_thd, bias_thd = apply_module(self.parallel_attention.linear_proj)(
+            core_attn_out_thd
+        )
 
         _output_sbhd = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape)
         torch.testing.assert_close(_output_sbhd, output_thd, atol=atol, rtol=rtol)
@@ -1108,8 +1117,12 @@ def test_gpu_forward_thd_precision(self):
         )
         assert torch.equal(_core_attn_out_sbhd, core_attn_out_thd)
 
-        output_sbhd, bias_sbhd = self.parallel_attention.linear_proj(core_attn_out_sbhd)
-        output_thd, bias_thd = self.parallel_attention.linear_proj(core_attn_out_thd)
+        output_sbhd, bias_sbhd = apply_module(self.parallel_attention.linear_proj)(
+            core_attn_out_sbhd
+        )
+        output_thd, bias_thd = apply_module(self.parallel_attention.linear_proj)(
+            core_attn_out_thd
+        )
 
         _output_sbhd = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape)
         assert torch.equal(_output_sbhd, output_thd)
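Note (illustrative, not part of the patch): the diff swaps `build_module(submodules.linear_proj, ...)` for a direct call to the `linear_proj` builder, narrows optional TE classes with `not_none(...)`, and routes module calls through `apply_module(...)` so a type checker can see the callable signatures. The sketch below shows the general Protocol-plus-narrowing pattern under toy assumptions; `not_none`, the builder Protocol, and `PlainLinearProj` here are simplified stand-ins, not the Megatron-Core implementations.

```python
# Illustrative sketch only -- toy stand-ins, not the Megatron-Core helpers.
from typing import Optional, Protocol, TypeVar

T = TypeVar("T")


def not_none(value: Optional[T]) -> T:
    """Assumed behaviour of the real helper: narrow Optional[T] to T."""
    assert value is not None
    return value


class LinearProjLike(Protocol):
    """Structural interface: anything with this forward() qualifies."""

    def forward(self, x: list[float]) -> tuple[list[float], None]: ...


class LinearProjBuilderLike(Protocol):
    """Structural interface for the builder (a class or factory callable)."""

    def __call__(self, input_size: int, output_size: int, *, bias: bool) -> LinearProjLike: ...


class PlainLinearProj:
    """Satisfies LinearProjLike structurally, without inheriting from it."""

    def __init__(self, input_size: int, output_size: int, *, bias: bool) -> None:
        self.shape = (input_size, output_size)
        self.bias = bias

    def forward(self, x: list[float]) -> tuple[list[float], None]:
        # Identity projection stands in for the real parallel linear layer.
        return x, None


def build_proj(builder: LinearProjBuilderLike) -> LinearProjLike:
    # Calling the builder directly (instead of a generic build_module helper) lets a
    # type checker validate the constructor arguments against the Protocol signature.
    return builder(4, 8, bias=False)


if __name__ == "__main__":
    maybe_cls: Optional[type[PlainLinearProj]] = PlainLinearProj  # None if a backend were absent
    proj = build_proj(not_none(maybe_cls))
    print(proj.forward([1.0, 2.0]))  # ([1.0, 2.0], None)
```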