From e756ddfebbfbb093f34a2d9e4263765ed6246b9b Mon Sep 17 00:00:00 2001 From: Daiyaan Date: Fri, 1 May 2026 15:04:30 -0700 Subject: [PATCH] Add opt-in nonuniform tensor parallelism --- megatron/core/distributed/nonuniform_tp.py | 914 ++++++++++++++++++ .../nonuniform_tp_transformer_engine.py | 157 +++ .../distributed/test_nonuniform_tp.py | 886 +++++++++++++++++ .../test_nonuniform_tp_transformer_engine.py | 78 ++ 4 files changed, 2035 insertions(+) create mode 100644 megatron/core/distributed/nonuniform_tp.py create mode 100644 megatron/core/extensions/nonuniform_tp_transformer_engine.py create mode 100644 tests/unit_tests/distributed/test_nonuniform_tp.py create mode 100644 tests/unit_tests/extension/test_nonuniform_tp_transformer_engine.py diff --git a/megatron/core/distributed/nonuniform_tp.py b/megatron/core/distributed/nonuniform_tp.py new file mode 100644 index 00000000000..c976cbc7180 --- /dev/null +++ b/megatron/core/distributed/nonuniform_tp.py @@ -0,0 +1,914 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +""" +Nonuniform Tensor Parallelism (NTP) - Non-intrusive implementation. + +This module provides fault tolerance for tensor-parallel training by allowing +a subset of TP ranks ("spares") to handle failures while "core" ranks continue computation. + +All NTP logic is contained in this module as subclasses of core components, +making it non-intrusive to the main codebase. + +Usage: + Instead of using the standard classes, use the NTP variants: + - NonuniformTPDistributedDataParallel instead of DistributedDataParallel + - NonuniformTPOptimizer to wrap your optimizer + - Call initialize_nonuniform_tp_process_groups() after initialize_model_parallel() +""" + +import functools +import logging +import sys +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist + +from .. import parallel_state +from ..optimizer.param_layout import ( + FullParamLayout, + PerBufferParamLayout, + pad_bucket_end, + pad_param_start, +) +from ..process_groups_config import ProcessGroupCollection +from ..transformer.cuda_graphs import is_graph_capturing +from ..transformer.transformer_config import TransformerConfig +from . 
import distributed_data_parallel as ddp_module +from .distributed_data_parallel import DistributedDataParallel +from .distributed_data_parallel_config import DistributedDataParallelConfig +from .param_and_grad_buffer import _ParamAndGradBucketGroup, _ParamAndGradBuffer + +logger = logging.getLogger(__name__) + + +def _ntp_get_non_active_ranks( + ntp_config: "NonuniformTPConfig", dp_rank: int, cp_rank: int = 0, pp_rank: int = 0 +) -> Optional[List[int]]: + """Return configured inactive local TP ranks, accepting both legacy and tuple keys.""" + if not ntp_config.non_active_ranks_per_dp: + return None + + rank_key = (dp_rank, cp_rank, pp_rank) + if rank_key in ntp_config.non_active_ranks_per_dp: + return ntp_config.non_active_ranks_per_dp[rank_key] + if dp_rank in ntp_config.non_active_ranks_per_dp: + return ntp_config.non_active_ranks_per_dp[dp_rank] + return None + + +def _ntp_current_rank_is_reduced_dp(ntp_config: "NonuniformTPConfig") -> bool: + """Return True if this rank belongs to a DP replica configured with reduced TP.""" + if ntp_config.tp_spares == 0: + return False + + dp_rank = parallel_state.get_data_parallel_rank() + if ntp_config.non_active_ranks_per_dp: + cp_rank = parallel_state.get_context_parallel_rank() + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + return _ntp_get_non_active_ranks(ntp_config, dp_rank, cp_rank, pp_rank) is not None + return dp_rank < ntp_config.num_reduced_tp_dp_ranks + + +def _ntp_current_rank_should_dp_sync(ntp_config: "NonuniformTPConfig") -> bool: + """Return True if this rank should participate in data-parallel grad sync.""" + if ntp_config.tp_spares == 0: + return True + + tp_size = parallel_state.get_tensor_model_parallel_world_size() + tp_rank = parallel_state.get_tensor_model_parallel_rank() + reduced_tp_size = ntp_config.tp_base - ntp_config.tp_spares + + # Reduced DP replicas only contain active TP ranks after NTP group reconfiguration. + if tp_size != ntp_config.tp_base: + return True + + # In healthy full-TP replicas, ranks beyond reduced_tp_size are folded into core ranks by + # NTP resharding and must not wait for a DP peer from the reduced replica. 
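+    # Illustration (assumed sizes, not part of the original logic): with tp_base=8 and
+    # tp_spares=2 in a healthy full-TP replica, TP ranks 0-5 take part in DP grad sync,
+    # while ranks 6-7 skip it and contribute only through the NTP all-to-all reshard
+    # performed in the backward hook.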
+ return tp_rank < reduced_tp_size + + +def _ntp_param_can_reshard(param: torch.nn.Parameter) -> bool: + """Return True for tensor-parallel params initialized with NTP split metadata.""" + return ( + hasattr(param, 'tensor_model_parallel') + and param.tensor_model_parallel + and hasattr(param, 'partition_dim') + and hasattr(param, 'send_splits') + and hasattr(param, 'recv_splits') + ) + + +def _ntp_should_expand_param_grad( + param: torch.nn.Parameter, ntp_config: "NonuniformTPConfig" +) -> bool: + """Return True if healthy core rank needs side_grad storage for this TP parameter.""" + if ntp_config.tp_spares == 0 or not _ntp_param_can_reshard(param): + return False + if _ntp_current_rank_is_reduced_dp(ntp_config): + return False + if parallel_state.get_tensor_model_parallel_world_size() != ntp_config.tp_base: + return False + return parallel_state.get_tensor_model_parallel_rank() < ( + ntp_config.tp_base - ntp_config.tp_spares + ) + + +def _ntp_extra_partition_dim(param: torch.nn.Parameter, ntp_config: "NonuniformTPConfig") -> int: + """Return side_grad extent along partition_dim for this healthy core rank.""" + tp_rank = parallel_state.get_tensor_model_parallel_rank() + return int(sum(param.recv_splits[tp_rank][-ntp_config.tp_spares :])) + + +def _ntp_param_numel(param: torch.nn.Parameter, ntp_config: "NonuniformTPConfig") -> int: + """Return main grad plus any NTP side grad storage needed for this param.""" + numel = param.data.nelement() + if _ntp_should_expand_param_grad(param, ntp_config): + side_shape = list(param.data.shape) + side_shape[param.partition_dim] = _ntp_extra_partition_dim(param, ntp_config) + numel += torch.Size(side_shape).numel() + return numel + + +def _compute_ntp_per_buffer_param_layout( + params: List[torch.nn.Parameter], + bucket_size: Optional[int], + data_parallel_world_size: int, + ddp_config: DistributedDataParallelConfig, + ntp_config: "NonuniformTPConfig", + param_indices: Optional[List[int]] = None, +) -> PerBufferParamLayout: + """Compute a buffer layout that includes side_grad storage for healthy core ranks.""" + + def _does_param_require_new_bucket(param): + return getattr(param, "shared_embedding", False) + + param_index_map = {} + bucket_indices = [] + per_bucket_numel_unpadded = [] + + param_start_index = 0 + bucket_start_index = 0 + bucket_params = set() + bucket_id = 0 + + def _finalize_bucket(param_end_index, bucket_start_index, bucket_id): + per_bucket_numel_unpadded.append(param_end_index - bucket_start_index) + if ddp_config.use_distributed_optimizer: + bucket_end_index = pad_bucket_end( + param_end_index, + data_parallel_world_size, + ddp_config.pad_buckets_for_high_nccl_busbw, + ) + else: + bucket_end_index = param_end_index + bucket_indices.append((bucket_start_index, bucket_end_index)) + return bucket_end_index, bucket_id + 1 + + for param in params[::-1]: + if ddp_config.use_distributed_optimizer: + param_start_index = pad_param_start(param_start_index) + + if _does_param_require_new_bucket(param) and len(bucket_params) > 0: + bucket_start_index, bucket_id = _finalize_bucket( + param_start_index, bucket_start_index, bucket_id + ) + bucket_params = set() + param_start_index = bucket_start_index + + param_end_index = param_start_index + _ntp_param_numel(param, ntp_config) + param_index_map[param] = (param_start_index, param_end_index, bucket_id) + bucket_params.add(param) + + if ( + bucket_size is not None and (param_end_index - bucket_start_index) >= bucket_size + ) or _does_param_require_new_bucket(param): + bucket_start_index, bucket_id = 
_finalize_bucket( + param_end_index, bucket_start_index, bucket_id + ) + bucket_params = set() + param_start_index = bucket_start_index + else: + param_start_index = param_end_index + + if len(bucket_params) > 0: + _finalize_bucket(param_end_index, bucket_start_index, bucket_id) + + return PerBufferParamLayout( + param_index_map=param_index_map, + bucket_indices=bucket_indices, + per_bucket_numel_unpadded=per_bucket_numel_unpadded, + param_indices=param_indices if param_indices is not None else [], + ) + + +# ====================================================================================== +# NTP Configuration +# ====================================================================================== + + +@dataclass +class NonuniformTPConfig: + """Configuration for Nonuniform Tensor Parallelism (NTP). + + NTP provides fault tolerance for tensor-parallel training by designating + a subset of TP ranks as "spares" that can handle GPU failures. + """ + + tp_base: int = 8 + """Base for tensor parallelism. This is the number of ranks in healthy tensor parallel groups. + Used for nonuniform tensor parallelism.""" + + tp_spares: int = 0 + """Number of spares for nonuniform tensor parallelism. + + When > 0, (tp_base - tp_spares) ranks handle computation and tp_spares ranks + provide fault tolerance. + """ + + num_reduced_tp_dp_ranks: int = 1 + """Number of DP ranks that use reduced TP (tp_base - tp_spares). The remaining DP ranks use + full tp_base. Reduced TP ranks are assumed to come first in the global rank ordering.""" + + non_active_ranks_per_dp: Optional[Dict[Tuple[int, int, int], List[int]]] = None + """Mapping of (DP rank, CP rank, PP rank) to list of non-active (spare) local TP rank IDs. + This allows specifying arbitrary GPU failures across all parallelism dimensions. + Example: {(0,0,0): [0,3], (0,1,0): [1,2], (1,0,0): [0,3]} means: + - DP rank 0, CP rank 0, PP rank 0 has local TP ranks 0,3 as spares + - DP rank 0, CP rank 1, PP rank 0 has local TP ranks 1,2 as spares + - DP rank 1, CP rank 0, PP rank 0 has local TP ranks 0,3 as spares + The number of non-active ranks must be consistent across CP replicas within each DP rank. + If None, defaults to last tp_spares ranks as non-active.""" + + +# ====================================================================================== +# Utility Functions for NTP Configuration +# ====================================================================================== + + +def compute_uniform_tp_spares_with_parity( + faulty_gpu_map: Dict[int, List[int]], tp_base: int +) -> Tuple[int, Dict[int, List[int]]]: + """ + Compute uniform tp_spares across all faulty DP ranks and add additional + non-active ranks to achieve parity. + + Strategy: + 1. Find the maximum number of failed GPUs across all affected DP ranks + 2. Use this as tp_spares (smallest reduced_tp that works for all) + 3. 
For DP ranks with fewer failures, pad with additional healthy GPUs + to reach uniform tp_spares + + Args: + faulty_gpu_map: Mapping of DP rank -> list of failed GPU IDs + tp_base: Base tensor parallel size + + Returns: + Tuple of (tp_spares, non_active_ranks_per_dp) + where non_active_ranks_per_dp includes both failed and padded GPUs + + Example: + Input: {0: [2, 5], 1: [1]} # DP rank 0 has 2 failures, DP rank 1 has 1 + Output: (2, {0: [2, 5], 1: [1, 7]}) # Pad DP rank 1 with GPU 7 to reach 2 + """ + if not faulty_gpu_map: + return 0, {} + + # Find maximum number of failures + max_failures = max(len(gpu_ids) for gpu_ids in faulty_gpu_map.values()) + tp_spares = max_failures + + non_active_ranks_per_dp = {} + + for dp_rank, failed_gpus in faulty_gpu_map.items(): + non_active = list(failed_gpus) # Start with actually failed GPUs + num_to_pad = tp_spares - len(failed_gpus) + + if num_to_pad > 0: + # Need to add more non-active ranks for parity + # Find healthy GPUs to mark as non-active + failed_set = set(failed_gpus) + healthy_gpus = [i for i in range(tp_base) if i not in failed_set] + + # Take from the end of healthy GPUs (prefer keeping lower ranks active) + gpus_to_deactivate = healthy_gpus[-num_to_pad:] + non_active.extend(gpus_to_deactivate) + + non_active_ranks_per_dp[dp_rank] = sorted(non_active) + + return tp_spares, non_active_ranks_per_dp + + +def get_active_ranks_for_dp( + dp_rank: int, tp_base: int, ntp_config: NonuniformTPConfig, cp_rank: int = 0, pp_rank: int = 0 +) -> List[int]: + """ + Get list of active (non-spare) local rank IDs for a given DP rank. + + Args: + dp_rank: Data parallel rank + tp_base: Base tensor parallel size + ntp_config: NTP configuration + + Returns: + List of local rank IDs that are active (not spare) + """ + non_active = _ntp_get_non_active_ranks(ntp_config, dp_rank, cp_rank, pp_rank) + if non_active is not None: + # Use explicitly specified non-active ranks + non_active_set = set(non_active) + active_ranks = [i for i in range(tp_base) if i not in non_active_set] + else: + # Default: first (tp_base - tp_spares) ranks are active + red_tp = tp_base - ntp_config.tp_spares + active_ranks = list(range(red_tp)) + + return active_ranks + + +# ====================================================================================== +# Process Group Initialization for NTP +# ====================================================================================== + + +def initialize_nonuniform_tp_process_groups( + ntp_config: NonuniformTPConfig, exit_spares: bool = True +) -> bool: + """ + Reconfigure TP and CP process groups for nonuniform tensor parallelism. + + Call this function after initialize_model_parallel() to enable NTP. + Non-active (spare) ranks will exit after group creation. 
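+    When exit_spares is False, spare ranks instead return False (active ranks return True),
+    so the caller can park or repurpose spare processes.
+
+    Example (illustrative):
+        ntp_config = NonuniformTPConfig(tp_base=8, tp_spares=1)
+        if not initialize_nonuniform_tp_process_groups(ntp_config, exit_spares=False):
+            return  # spare rank: TP groups were not rebuilt for this process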
+ + Args: + ntp_config: NTP configuration containing tp_base, tp_spares, num_reduced_tp_dp_ranks, + and optionally non_active_ranks_per_dp + """ + if ntp_config.tp_spares == 0: + # No nonuniform TP, nothing to reconfigure + return True + + tp_base = ntp_config.tp_base + cp_size = parallel_state.get_context_parallel_world_size() + rank = dist.get_rank() + + # Calculate which DP replicas use reduced TP + dp_replica_size = tp_base * cp_size + num_reduced_dp_ranks = ntp_config.num_reduced_tp_dp_ranks + + # Determine if current rank is in a reduced TP DP replica + dp_replica_id = rank // dp_replica_size + if dp_replica_id >= num_reduced_dp_ranks: + # This rank is in a normal TP DP replica, no reconfiguration needed + logger.info( + "[NTP] Rank %s is in normal TP DP replica %s, skipping reconfiguration", + rank, + dp_replica_id, + ) + return True + + local_rank_in_dp = rank % dp_replica_size + cp_rank_in_dp = local_rank_in_dp // tp_base if cp_size > 1 else 0 + + # This rank is in a reduced TP DP replica - need to reconfigure + # Get active ranks for this DP replica (supports non-contiguous) + active_local_ranks = get_active_ranks_for_dp( + dp_replica_id, tp_base, ntp_config, cp_rank=cp_rank_in_dp + ) + + logger.info( + "[NTP] Rank %s in DP replica %s: active_local_ranks=%s", + rank, + dp_replica_id, + active_local_ranks, + ) + + if cp_size > 1: + # With CP enabled: recreate TP, CP, and TP-CP groups + dp_replica_start = dp_replica_id * dp_replica_size + + # Create new TP groups (one per CP slice in this DP replica) + for cp_rank in range(cp_size): + cp_slice_start = dp_replica_start + cp_rank * tp_base + tp_group_ranks = [cp_slice_start + local_tp for local_tp in active_local_ranks] + tp_group = dist.new_group(ranks=tp_group_ranks) + + if rank in tp_group_ranks: + parallel_state._TENSOR_MODEL_PARALLEL_GROUP = tp_group + parallel_state._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = tp_group_ranks + parallel_state._MODEL_PARALLEL_GROUP = tp_group + parallel_state._MODEL_PARALLEL_GLOBAL_RANKS = tp_group_ranks + logger.info("[NTP] Rank %s created TP group: %s", rank, tp_group_ranks) + + # Create new CP groups (one per active TP position) + for tp_rank_in_slice in active_local_ranks: + cp_group_ranks = [ + dp_replica_start + tp_rank_in_slice + i * tp_base for i in range(cp_size) + ] + cp_group = dist.new_group(ranks=cp_group_ranks) + + if rank in cp_group_ranks: + parallel_state._CONTEXT_PARALLEL_GROUP = cp_group + parallel_state._CONTEXT_PARALLEL_GLOBAL_RANKS = cp_group_ranks + logger.info("[NTP] Rank %s created CP group: %s", rank, cp_group_ranks) + + # Update TENSOR_AND_CONTEXT_PARALLEL_GROUP + tp_rank_in_slice = local_rank_in_dp % tp_base + if tp_rank_in_slice in active_local_ranks: + tp_cp_group_ranks = [] + for cp_r in range(cp_size): + for active_tp in active_local_ranks: + tp_cp_group_ranks.append(dp_replica_start + cp_r * tp_base + active_tp) + tp_cp_group = dist.new_group(ranks=tp_cp_group_ranks) + parallel_state._TENSOR_AND_CONTEXT_PARALLEL_GROUP = tp_cp_group + logger.info("[NTP] Rank %s created TP-CP group: %s", rank, tp_cp_group_ranks) + else: + # Non-active (spare) rank - exit + logger.info("[NTP] Rank %s is a spare rank with CP, exiting", rank) + if exit_spares: + sys.exit(0) + return False + else: + # No CP: simpler case + dp_replica_start = dp_replica_id * dp_replica_size + tp_group_ranks = [dp_replica_start + local_tp for local_tp in active_local_ranks] + + if rank in tp_group_ranks: + tp_group = dist.new_group(ranks=tp_group_ranks) + parallel_state._TENSOR_MODEL_PARALLEL_GROUP = tp_group + 
parallel_state._MODEL_PARALLEL_GROUP = tp_group + parallel_state._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = tp_group_ranks + parallel_state._MODEL_PARALLEL_GLOBAL_RANKS = tp_group_ranks + logger.info("[NTP] Rank %s created TP group: %s", rank, tp_group_ranks) + else: + # Non-active (spare) rank - exit + logger.info("[NTP] Rank %s is a spare rank, exiting", rank) + if exit_spares: + sys.exit(0) + return False + + return True + + +# ====================================================================================== +# Parameter Resharding for NTP +# ====================================================================================== + + +def ntp_map(module: torch.nn.Module, ntp_config: NonuniformTPConfig, num_shards: int): + """ + Initialize TP-sharded params with mapping between healthy and unhealthy TP sizes. + + Only healthy (full TP) ranks need send_splits and recv_splits to know how to reshard + parameters when synchronizing with unhealthy (reduced TP) ranks. + Unhealthy ranks synchronize directly without resharding. + + Args: + module: Module containing parameters to initialize (e.g., self_attention or mlp) + ntp_config: NTP configuration containing tp_base and tp_spares + num_shards: Number of shards (e.g., num_attention_heads or ffn_hidden_size) + """ + if ntp_config.tp_spares == 0: + # No nonuniform TP, skip initialization + return + + # Determine which ranks are active (non-spare) for the current DP rank + rank = dist.get_rank() + dp_rank = parallel_state.get_data_parallel_rank() + cp_rank = parallel_state.get_context_parallel_rank() + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + + logger.debug( + f"[NTP] Rank {rank} [DP {dp_rank}, CP {cp_rank}, PP {pp_rank}] " + f"ntp_map called with module={type(module).__name__}, num_shards={num_shards}" + ) + + # Check if this (dp, cp, pp) combination has non-active ranks specified + # If it does, it's an unhealthy rank that uses reduced TP + if _ntp_get_non_active_ranks(ntp_config, dp_rank, cp_rank, pp_rank) is not None: + # This is an unhealthy rank with reduced TP - skip + logger.debug( + "[NTP] Rank %s [DP %s, CP %s, PP %s] Unhealthy rank, skipping", + rank, + dp_rank, + cp_rank, + pp_rank, + ) + return + + # This is a healthy rank (full TP) - it needs send/recv splits to communicate + # with unhealthy ranks that have reduced TP + logger.debug( + "[NTP] Rank %s [DP %s] Setting up send/recv splits for healthy rank", rank, dp_rank + ) + + for param in module.parameters(): + # Handle both tensor parallel parameters and vocabulary-parallel parameters that only + # carry partition_dim metadata. 
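+        # Worked illustration (assumed sizes, not taken from the patch): with tp_base=4,
+        # tp_spares=1 and num_shards=8, the reduced replica chunks shards as [0,1,2] [3,4,5]
+        # [6,7]; the split computation below assumes healthy ranks hold [0,1] [3,4] [6,7]
+        # [2,5], so the spare-position rank gets send_splits[3] = [s, s, 0, 0] (s = extent of
+        # one shard along partition_dim) and core ranks 0 and 1 each receive one shard of
+        # gradient into side_grad.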
+ if (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or ( + hasattr(param, 'partition_dim') and not hasattr(param, 'tensor_model_parallel') + ): + # For healthy ranks, compute send/recv splits for communication with unhealthy ranks + # We need to know how to reshard to match the reduced TP size + reduced_tp_size = ntp_config.tp_base - ntp_config.tp_spares + + shard_ids = torch.arange(num_shards) + # Partitions for reduced TP (what unhealthy ranks have) + sync_partitions = list(shard_ids.chunk(reduced_tp_size)) + + # Full partitions for healthy ranks (tp_base ranks) + comp_partitions = sync_partitions + [ + torch.empty(int(len(shard_ids) / ntp_config.tp_base), dtype=torch.int) + for _ in range(ntp_config.tp_spares) + ] + + # Build comp_2_sync: for spare positions, which reduced TP ranks do they map to + comp_2_sync = [[] for _ in range(ntp_config.tp_base)] + sync_part_idx = 0 + + for spare_part_idx in range(reduced_tp_size, ntp_config.tp_base): + for shard_part_idx in range(len(comp_partitions[spare_part_idx])): + # Take the last shard from the current reduced TP rank + comp_partitions[spare_part_idx][shard_part_idx] = comp_partitions[ + sync_part_idx + ][-1] + comp_partitions[sync_part_idx] = comp_partitions[sync_part_idx][:-1] + comp_2_sync[spare_part_idx].append(sync_part_idx) + sync_part_idx = (sync_part_idx + 1) % reduced_tp_size + + # Compute param_splits: how many shards each rank sends to each other rank + param_splits = [ + torch.bincount(torch.tensor(c2s, dtype=torch.int), minlength=ntp_config.tp_base) + for c2s in comp_2_sync + ] + + shard_size = int(param.shape[param.partition_dim] * ntp_config.tp_base / len(shard_ids)) + send_splits = [(p_split * shard_size).tolist() for p_split in param_splits] + recv_splits = [ + [send_splits[send_idx][recv_idx] for send_idx in range(len(send_splits))] + for recv_idx in range(ntp_config.tp_base) + ] + param.send_splits = send_splits + param.recv_splits = recv_splits + logger.debug( + f"[NTP] Rank {rank} [DP {dp_rank}] Set send_splits and recv_splits " + f"on parameter id={id(param)}, shape={param.shape}" + ) + + +def ntp_init(layer: torch.nn.Module, ntp_config: NonuniformTPConfig): + """ + Initialize nonuniform TP mappings for a TransformerLayer. + + This should be called after the layer is created to set up the send_splits + and recv_splits attributes on tensor-parallel parameters. + + Args: + layer: TransformerLayer instance + ntp_config: NTP configuration containing tp_base and tp_spares + """ + if ntp_config.tp_spares == 0: + # No nonuniform TP, skip initialization + return + + # Initialize self-attention parameters + if hasattr(layer, 'self_attention'): + ntp_map(layer.self_attention, ntp_config, layer.self_attention.config.num_attention_heads) + + # Initialize MLP parameters + if hasattr(layer, 'mlp'): + ntp_map(layer.mlp, ntp_config, layer.mlp.config.ffn_hidden_size) + + +# ====================================================================================== +# NTP-aware ParamAndGradBuffer +# ====================================================================================== + + +class NonuniformTPParamAndGradBucketGroup(_ParamAndGradBucketGroup): + """ + NTP-aware version of _ParamAndGradBucketGroup. + Skips gradient synchronization for spare GPUs. 
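+    Also waits on any pending NTP all-to-all reshard handles attached to this group's
+    parameters before starting or finishing data-parallel grad sync.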
+ """ + + def __init__(self, *args, ntp_config: Optional[NonuniformTPConfig] = None, **kwargs): + super().__init__(*args, **kwargs) + self.ntp_config = ntp_config or NonuniformTPConfig() + + def _wait_ntp_reshard_handles(self): + """Wait for NTP all-to-all reshard work touching params in this bucket group.""" + for bucket in self.buckets: + for param in bucket.params: + handle = getattr(param, 'ntp_reshard_handle', None) + if handle is not None: + handle.wait() + param.ntp_reshard_handle = None + + def start_grad_sync(self, force_all_reduce: Optional[bool] = False): + """Start DP grad sync after any pending NTP reshard for this bucket is complete.""" + self._wait_ntp_reshard_handles() + if not _ntp_current_rank_should_dp_sync(self.ntp_config): + self.grad_reduce_handle = None + return + return super().start_grad_sync(force_all_reduce=force_all_reduce) + + def finish_grad_sync(self, force_all_reduce: Optional[bool] = False): + """Finish DP grad sync, treating folded-away healthy spare ranks as no-ops.""" + self.param_gather_dispatched = False + self._wait_ntp_reshard_handles() + if not _ntp_current_rank_should_dp_sync(self.ntp_config): + self._copy_back_extra_main_grads() + return + return super().finish_grad_sync(force_all_reduce=force_all_reduce) + + def register_grad_ready( + self, param: torch.nn.Parameter, force_all_reduce: Optional[bool] = False + ): + """Skip DP-ready bookkeeping on ranks that are folded into core TP ranks.""" + if not _ntp_current_rank_should_dp_sync(self.ntp_config): + return + return super().register_grad_ready(param, force_all_reduce=force_all_reduce) + + +class NonuniformTPParamAndGradBuffer(_ParamAndGradBuffer): + """ + NTP-aware version of _ParamAndGradBuffer. + Adjusts buffer sizes and splits gradients for NTP. + """ + + def __init__(self, *args, ntp_config: Optional[NonuniformTPConfig] = None, **kwargs): + self.ntp_config = ntp_config or NonuniformTPConfig() + if self.ntp_config.tp_spares > 0: + ddp_config = args[0] if len(args) > 0 else kwargs['ddp_config'] + params_with_names = args[3] if len(args) > 3 else kwargs['params_with_names'] + data_parallel_group = args[4] if len(args) > 4 else kwargs['data_parallel_group'] + bucket_size = args[5] if len(args) > 5 else kwargs['bucket_size'] + param_indices = args[8] if len(args) > 8 else kwargs['param_indices'] + params = [param for param, _ in params_with_names] + kwargs['param_layout'] = _compute_ntp_per_buffer_param_layout( + params, + bucket_size, + data_parallel_group.size(), + ddp_config, + self.ntp_config, + param_indices, + ) + + super().__init__(*args, **kwargs) + + if self.ntp_config.tp_spares > 0: + for param in self.params: + if not _ntp_should_expand_param_grad(param, self.ntp_config): + continue + param_start_index, param_end_index, _ = self.param_index_map[param] + main_numel = param.data.nelement() + side_numel = param_end_index - param_start_index - main_numel + if side_numel <= 0: + continue + + side_shape = list(param.data.shape) + side_shape[param.partition_dim] = _ntp_extra_partition_dim(param, self.ntp_config) + assert torch.Size(side_shape).numel() == side_numel + side_start = param_start_index + main_numel + param.side_grad = self.grad_data[side_start:param_end_index].view(side_shape) + + +# ====================================================================================== +# NTP-aware DistributedDataParallel +# ====================================================================================== + + +class NonuniformTPDistributedDataParallel(DistributedDataParallel): + """ + 
NTP-aware version of DistributedDataParallel. + Adds gradient synchronization logic for spare GPUs. + """ + + def __init__( + self, + config: TransformerConfig, + ddp_config: DistributedDataParallelConfig, + module: torch.nn.Module, + disable_bucketing: bool = False, + pg_collection: Optional[ProcessGroupCollection] = None, + ntp_config: Optional[NonuniformTPConfig] = None, + full_param_layout: Optional[FullParamLayout] = None, + ): + self.ntp_config = ntp_config or NonuniformTPConfig() + + # Use NTP-aware buffer class + if self.ntp_config.tp_spares > 0: + # DDP imports _ParamAndGradBuffer into its module namespace, so patch that binding + # while the parent constructor allocates buffers. + original_buffer_class = ddp_module._ParamAndGradBuffer + ddp_module._ParamAndGradBuffer = functools.partial( + NonuniformTPParamAndGradBuffer, ntp_config=self.ntp_config + ) + try: + super().__init__( + config=config, + ddp_config=ddp_config, + module=module, + disable_bucketing=disable_bucketing, + pg_collection=pg_collection, + full_param_layout=full_param_layout, + ) + finally: + ddp_module._ParamAndGradBuffer = original_buffer_class + self._wrap_bucket_groups_for_ntp() + else: + super().__init__( + config=config, + ddp_config=ddp_config, + module=module, + disable_bucketing=disable_bucketing, + pg_collection=pg_collection, + full_param_layout=full_param_layout, + ) + + def _wrap_bucket_groups_for_ntp(self): + """Replace DDP bucket groups with NTP-aware groups and rebuild param lookup.""" + + def wrap_groups(bucket_groups): + wrapped_groups = [] + old_to_new = {} + for bucket_group in bucket_groups: + if ( + self.ddp_config.use_distributed_optimizer + or self.ddp_config.overlap_param_gather + ): + collective_group = bucket_group.intra_distributed_optimizer_instance_group + collective_group_size = bucket_group.intra_distributed_optimizer_instance_size + else: + collective_group = bucket_group.data_parallel_group + collective_group_size = bucket_group.data_parallel_group.size() + + wrapped_group = NonuniformTPParamAndGradBucketGroup( + bucket_group.buckets, + bucket_group.ddp_config, + collective_group, + collective_group_size, + ntp_config=self.ntp_config, + ) + if hasattr(bucket_group, 'inter_distributed_optimizer_instance_group'): + wrapped_group.inter_distributed_optimizer_instance_group = ( + bucket_group.inter_distributed_optimizer_instance_group + ) + if hasattr(bucket_group, 'communication_stream'): + wrapped_group.communication_stream = bucket_group.communication_stream + old_to_new[bucket_group] = wrapped_group + wrapped_groups.append(wrapped_group) + + for bucket_group, wrapped_group in old_to_new.items(): + next_group = getattr(bucket_group, 'next_param_gather_bucket_group', None) + if next_group is not None: + wrapped_group.next_param_gather_bucket_group = old_to_new[next_group] + + return wrapped_groups + + self.bucket_groups = wrap_groups(self.bucket_groups) + self.expert_parallel_bucket_groups = wrap_groups(self.expert_parallel_bucket_groups) + self.param_to_bucket_group = {} + for bucket_groups in [self.bucket_groups, self.expert_parallel_bucket_groups]: + for bucket_group in bucket_groups: + for bucket in bucket_group.buckets: + for param in bucket.params_list: + self.param_to_bucket_group[param] = bucket_group + + def _make_backward_post_hook(self, param: torch.nn.Parameter): + """ + Override to add NTP gradient synchronization between spare and core GPUs. 
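+
+        On the last microbatch of a healthy full-TP replica, folded spare ranks split
+        main_grad with send_splits and core ranks receive the corresponding shards for
+        side_grad through an async all_to_all over the TP group; the handle is stored as
+        param.ntp_reshard_handle and awaited before data-parallel grad sync runs.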
+ """ + + def ntp_hook(*unused): + if is_graph_capturing(): + return + + bucket_group = self.param_to_bucket_group.get(param) + is_last_microbatch = bucket_group is None or bucket_group.is_last_microbatch + if param in self.param_to_bucket_group: + assert param.requires_grad + if self.ddp_config.overlap_grad_reduce: + assert ( + param.grad is not None + ), 'param.grad being None is not safe when overlap_grad_reduce is True' + if param.grad is not None and ( + not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False) + ): + param.main_grad.add_(param.grad.data) + param.grad = None + + # Add NTP-specific logic + if ( + self.ntp_config.tp_spares > 0 + and _ntp_param_can_reshard(param) + and is_last_microbatch + and not _ntp_current_rank_is_reduced_dp(self.ntp_config) + and parallel_state.get_tensor_model_parallel_world_size() == self.ntp_config.tp_base + ): + empty_shape = list(param.shape) + empty_shape[param.partition_dim] = 0 + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + if tp_rank < self.ntp_config.tp_base - self.ntp_config.tp_spares: + # Core GPU: receive grads from spare GPUs + input = [ + torch.empty( + empty_shape, device=param.device, dtype=param.side_grad.dtype + ).contiguous() + for _ in range(parallel_state.get_tensor_model_parallel_world_size()) + ] + # Split side_grad and send to core GPUs + output = [ + torch.empty( + empty_shape, device=param.device, dtype=param.side_grad.dtype + ).contiguous() + for _ in range(self.ntp_config.tp_base - self.ntp_config.tp_spares) + ] + [ + t.contiguous() + for t in torch.split( + param.side_grad, param.recv_splits[tp_rank], dim=param.partition_dim + ) + ][ + -self.ntp_config.tp_spares : + ] + else: + # Spare GPU: send grads to core GPUs + input = [ + t.contiguous() + for t in torch.split( + param.main_grad, param.send_splits[tp_rank], dim=param.partition_dim + ) + ] + output = [ + torch.empty( + empty_shape, device=param.device, dtype=param.main_grad.dtype + ).contiguous() + for _ in range(parallel_state.get_tensor_model_parallel_world_size()) + ] + + try: + handle = dist.all_to_all( + output, + input, + group=parallel_state.get_tensor_model_parallel_group(), + async_op=True, + ) + param.ntp_reshard_handle = handle + except Exception as e: + logger.error("[NTP] Rank %s all_to_all error: %s", tp_rank, e) + input_contiguity = [i.is_contiguous() for i in input] + output_contiguity = [o.is_contiguous() for o in output] + logger.error( + "[NTP] Rank %s input element contiguity: %s", tp_rank, input_contiguity + ) + logger.error( + "[NTP] Rank %s output element contiguity: %s", tp_rank, output_contiguity + ) + raise e + + if param in self.param_to_bucket_group and self.ddp_config.overlap_grad_reduce: + self.param_to_bucket_group[param].register_grad_ready(param, self.force_all_reduce) + + return ntp_hook + + +# ====================================================================================== +# NTP-aware Optimizer Wrapper +# ====================================================================================== + + +class NonuniformTPOptimizer: + """ + Wrapper for optimizers to make gradients contiguous for NTP. + """ + + def __init__(self, optimizer, ntp_config: NonuniformTPConfig): + self.optimizer = optimizer + self.ntp_config = ntp_config + + def __getattr__(self, name): + """Delegate attribute access to wrapped optimizer.""" + return getattr(self.optimizer, name) + + def prepare_grads(self, *args, **kwargs): + """ + Override prepare_grads to make gradients contiguous for NTP. 
+ """ + # Call original prepare_grads if it exists + if hasattr(self.optimizer, 'prepare_grads'): + result = self.optimizer.prepare_grads(*args, **kwargs) + else: + result = False + + # Make gradients contiguous for NTP + if self.ntp_config.tp_spares > 0: + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + if hasattr(param, 'main_grad') and param.main_grad is not None: + if not param.main_grad.is_contiguous(): + param.grad = param.main_grad.contiguous() + else: + param.grad = param.main_grad + + return result diff --git a/megatron/core/extensions/nonuniform_tp_transformer_engine.py b/megatron/core/extensions/nonuniform_tp_transformer_engine.py new file mode 100644 index 00000000000..3ebda95c16c --- /dev/null +++ b/megatron/core/extensions/nonuniform_tp_transformer_engine.py @@ -0,0 +1,157 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +"""Transformer Engine adapter helpers for nonuniform tensor parallelism.""" + +from contextlib import contextmanager +from typing import Iterable, Optional, Sequence, Tuple + +TPDomains = Tuple[Tuple[int, ...], ...] + + +def normalize_tp_domains(tp_domains: Sequence[Sequence[int]]) -> TPDomains: + """Return deterministic, validated TP domains for all ranks. + + Transformer Engine userbuffer initialization creates one process group per TP domain. + Every rank must create those groups in the same order, so callers may pass domains in + any order and this helper normalizes them by first rank. + """ + normalized = [] + seen_ranks = set() + + for domain in tp_domains: + domain_tuple = tuple(int(rank) for rank in domain) + if not domain_tuple: + raise ValueError("NTP TP domains must not be empty") + if len(set(domain_tuple)) != len(domain_tuple): + raise ValueError(f"NTP TP domain contains duplicate ranks: {domain_tuple}") + overlap = seen_ranks.intersection(domain_tuple) + if overlap: + raise ValueError(f"NTP TP domains overlap on ranks: {sorted(overlap)}") + seen_ranks.update(domain_tuple) + normalized.append(domain_tuple) + + if not normalized: + raise ValueError("At least one NTP TP domain is required") + + return tuple(sorted(normalized, key=lambda domain: (domain[0], len(domain), domain))) + + +def _subgroup_arg( + args: Tuple[object, ...], kwargs: dict, index: int, name: str, default: object = None +) -> object: + if name in kwargs: + return kwargs[name] + if len(args) > index: + return args[index] + return default + + +def _new_group_kwargs(args: Tuple[object, ...], kwargs: dict, domain_index: int) -> dict: + timeout = _subgroup_arg(args, kwargs, 0, "timeout") + backend = _subgroup_arg(args, kwargs, 1, "backend") + pg_options = _subgroup_arg(args, kwargs, 2, "pg_options") + use_local_synchronization = _subgroup_arg(args, kwargs, 3, "use_local_synchronization", False) + group_desc = _subgroup_arg(args, kwargs, 4, "group_desc") + + new_group_kwargs = {} + if timeout is not None: + new_group_kwargs["timeout"] = timeout + if backend is not None: + new_group_kwargs["backend"] = backend + if pg_options is not None: + new_group_kwargs["pg_options"] = pg_options + if use_local_synchronization: + new_group_kwargs["use_local_synchronization"] = use_local_synchronization + if group_desc is not None: + new_group_kwargs["group_desc"] = f"{group_desc}_ntp_{domain_index}" + return new_group_kwargs + + +@contextmanager +def transformer_engine_userbuffer_tp_domains( + tp_domains: Sequence[Sequence[int]], + *, + distributed: Optional[object] = None, + tp_group: Optional[object] = None, +): + """Use explicit TP 
domains while Transformer Engine initializes userbuffers. + + Transformer Engine currently accepts a scalar ``tp_size`` and derives TP domains by + chunking its bootstrap process group. NTP needs mixed-size TP domains, so this context + manager redirects TE's no-ranks bootstrap group to the current NTP TP group and keeps + subgroup enumeration on the caller-provided domains for TE versions or paths that need it. + """ + if distributed is None: + import torch.distributed as distributed # type: ignore[no-redef] + + domains = normalize_tp_domains(tp_domains) + rank = distributed.get_rank() + if not any(rank in domain for domain in domains): + raise RuntimeError(f"Rank {rank} is not present in any NTP TP domain: {domains}") + + original_new_group = distributed.new_group + original_new_subgroups = distributed.new_subgroups_by_enumeration + + def ntp_new_group(*args, **kwargs): + ranks = kwargs.get("ranks") + if ranks is None and args: + ranks = args[0] + if ranks is None and tp_group is not None: + return tp_group + return original_new_group(*args, **kwargs) + + def ntp_new_subgroups_by_enumeration( + _ranks_per_subgroup_list: Iterable[Iterable[int]], *args, **kwargs + ): + current_group = None + groups = [] + for domain_index, domain in enumerate(domains): + group = distributed.new_group( + ranks=list(domain), **_new_group_kwargs(args, kwargs, domain_index) + ) + groups.append(group) + if rank in domain: + current_group = group + + if current_group is None: + raise RuntimeError(f"Rank {rank} did not get an NTP TP domain") + return current_group, groups + + distributed.new_group = ntp_new_group + distributed.new_subgroups_by_enumeration = ntp_new_subgroups_by_enumeration + try: + yield domains + finally: + distributed.new_group = original_new_group + distributed.new_subgroups_by_enumeration = original_new_subgroups + + +def initialize_transformer_engine_userbuffers_for_nonuniform_tp( + *, + shape: Sequence[int], + tp_size: int, + tp_domains: Sequence[Sequence[int]], + bootstrap_backend: str, + tp_group: Optional[object] = None, + **initialize_ub_kwargs, +) -> TPDomains: + """Initialize Transformer Engine userbuffers on explicit NTP TP domains.""" + try: + from transformer_engine.pytorch import module as te_module + except ImportError as exc: + raise RuntimeError("NTP TP communication overlap requires Transformer Engine") from exc + + if tp_group is None: + from megatron.core import parallel_state + + tp_group = parallel_state.get_tensor_model_parallel_group() + + with transformer_engine_userbuffer_tp_domains( + tp_domains, tp_group=tp_group + ) as normalized_domains: + te_module.base.initialize_ub( + shape=list(shape), + tp_size=tp_size, + bootstrap_backend=bootstrap_backend, + **initialize_ub_kwargs, + ) + return normalized_domains diff --git a/tests/unit_tests/distributed/test_nonuniform_tp.py b/tests/unit_tests/distributed/test_nonuniform_tp.py new file mode 100644 index 00000000000..8a4ca6f971c --- /dev/null +++ b/tests/unit_tests/distributed/test_nonuniform_tp.py @@ -0,0 +1,886 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +""" +Unit tests for Nonuniform Tensor Parallelism (NTP). + +Tests the fault-tolerance mechanism that allows training to continue +when GPU failures occur within a tensor-parallel group. 
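+
+The mock-based tests run in a single process. TestNonuniformTPIntegration and
+TestNonuniformTPEndToEnd expect a torch.distributed launch; the end-to-end test skips
+itself unless 8 ranks are available.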
+""" + +import functools +import os +from datetime import timedelta +from unittest.mock import Mock, patch + +import pytest +import torch +import torch.distributed as dist + +import megatron.core.distributed.distributed_data_parallel as ddp_module +from megatron.core import parallel_state +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.distributed.distributed_data_parallel import DistributedDataParallel +from megatron.core.distributed.nonuniform_tp import ( + NonuniformTPConfig, + NonuniformTPDistributedDataParallel, + NonuniformTPOptimizer, + NonuniformTPParamAndGradBucketGroup, + NonuniformTPParamAndGradBuffer, + _compute_ntp_per_buffer_param_layout, + compute_uniform_tp_spares_with_parity, + get_active_ranks_for_dp, + initialize_nonuniform_tp_process_groups, + ntp_init, + ntp_map, +) +from megatron.core.extensions.nonuniform_tp_transformer_engine import ( + initialize_transformer_engine_userbuffers_for_nonuniform_tp, +) +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer import TransformerConfig +from megatron.core.utils import is_te_min_version +from tests.unit_tests.test_utilities import Utils + + +class TestNonuniformTPUtilities: + """Test utility functions for NTP configuration.""" + + def test_compute_uniform_tp_spares_with_parity_no_failures(self): + """Test with no GPU failures.""" + faulty_gpu_map = {} + tp_base = 8 + + tp_spares, non_active_ranks = compute_uniform_tp_spares_with_parity(faulty_gpu_map, tp_base) + + assert tp_spares == 0 + assert non_active_ranks == {} + + def test_compute_uniform_tp_spares_with_parity_uniform_failures(self): + """Test with uniform failures across DP ranks.""" + faulty_gpu_map = { + 0: [2, 5], # DP rank 0 has 2 failures + 1: [1, 3], # DP rank 1 has 2 failures + } + tp_base = 8 + + tp_spares, non_active_ranks = compute_uniform_tp_spares_with_parity(faulty_gpu_map, tp_base) + + assert tp_spares == 2 + assert non_active_ranks[0] == [2, 5] + assert non_active_ranks[1] == [1, 3] + + def test_compute_uniform_tp_spares_with_parity_non_uniform_failures(self): + """Test with non-uniform failures (requires padding).""" + faulty_gpu_map = {0: [2, 5], 1: [1]} # DP rank 0 has 2 failures # DP rank 1 has 1 failure + tp_base = 8 + + tp_spares, non_active_ranks = compute_uniform_tp_spares_with_parity(faulty_gpu_map, tp_base) + + assert tp_spares == 2 + assert non_active_ranks[0] == [2, 5] + # DP rank 1 should be padded with 1 additional GPU (prefer high ranks) + assert len(non_active_ranks[1]) == 2 + assert 1 in non_active_ranks[1] + # Second non-active rank should be from the end (e.g., 7) + assert non_active_ranks[1][1] == 7 + + def test_get_active_ranks_for_dp_default(self): + """Test get_active_ranks_for_dp with default (no explicit non_active_ranks_per_dp).""" + ntp_config = NonuniformTPConfig(tp_base=8, tp_spares=2) + dp_rank = 0 + tp_base = 8 + + active_ranks = get_active_ranks_for_dp(dp_rank, tp_base, ntp_config) + + # Should return first (tp_base - tp_spares) ranks + assert active_ranks == [0, 1, 2, 3, 4, 5] + + def test_get_active_ranks_for_dp_explicit(self): + """Test get_active_ranks_for_dp with explicit non_active_ranks_per_dp.""" + ntp_config = NonuniformTPConfig(tp_base=8, tp_spares=2, non_active_ranks_per_dp={0: [2, 5]}) + 
dp_rank = 0 + tp_base = 8 + + active_ranks = get_active_ranks_for_dp(dp_rank, tp_base, ntp_config) + + # Should exclude ranks 2 and 5 + assert active_ranks == [0, 1, 3, 4, 6, 7] + + def test_get_active_ranks_for_dp_tuple_key(self): + """Test get_active_ranks_for_dp with DP/CP/PP-scoped non-active ranks.""" + ntp_config = NonuniformTPConfig( + tp_base=8, tp_spares=2, non_active_ranks_per_dp={(1, 2, 0): [0, 7]} + ) + + active_ranks = get_active_ranks_for_dp(1, 8, ntp_config, cp_rank=2, pp_rank=0) + + assert active_ranks == [1, 2, 3, 4, 5, 6] + + +class TestNonuniformTPBufferLayout: + """Test NTP gradient-buffer layout compatibility with DDP buffer features.""" + + @patch('megatron.core.distributed.nonuniform_tp.parallel_state') + def test_layout_expands_side_grad_and_pads_for_distributed_optimizer(self, mock_parallel_state): + """Healthy core ranks need side_grad storage and DistOpt-compatible padding.""" + mock_parallel_state.get_data_parallel_rank.return_value = 1 + mock_parallel_state.get_context_parallel_rank.return_value = 0 + mock_parallel_state.get_pipeline_model_parallel_rank.return_value = 0 + mock_parallel_state.get_tensor_model_parallel_world_size.return_value = 4 + mock_parallel_state.get_tensor_model_parallel_rank.return_value = 0 + + param = torch.nn.Parameter(torch.randn(4, 2)) + param.tensor_model_parallel = True + param.partition_dim = 0 + param.send_splits = [[0, 0, 0, 0] for _ in range(4)] + param.recv_splits = [[0, 0, 2, 2] for _ in range(4)] + + ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True) + ntp_config = NonuniformTPConfig( + tp_base=4, tp_spares=2, non_active_ranks_per_dp={(0, 0, 0): [2, 3]} + ) + + layout = _compute_ntp_per_buffer_param_layout( + [param], + bucket_size=None, + data_parallel_world_size=2, + ddp_config=ddp_config, + ntp_config=ntp_config, + param_indices=[0], + ) + + assert layout.param_index_map[param] == (0, 16, 0) + assert layout.per_bucket_numel_unpadded == [16] + assert layout.bucket_indices == [(0, 128)] + assert layout.param_indices == [0] + + +class TestNonuniformTPParameterResharding: + """Test parameter resharding logic for NTP.""" + + def test_ntp_map_no_spares(self): + """Test ntp_map when tp_spares=0 (should be no-op).""" + # Create mock module with parameter + module = Mock() + param = torch.nn.Parameter(torch.randn(10, 10)) + param.tensor_model_parallel = True + param.partition_dim = 1 + module.parameters = Mock(return_value=[param]) + + ntp_config = NonuniformTPConfig(tp_base=8, tp_spares=0) + + # Should not raise error and not add send_splits/recv_splits + ntp_map(module, ntp_config, num_shards=24) + + assert not hasattr(param, 'send_splits') + assert not hasattr(param, 'recv_splits') + + @patch('megatron.core.distributed.nonuniform_tp.parallel_state') + @patch('megatron.core.distributed.nonuniform_tp.dist') + def test_ntp_map_with_spares_healthy_rank(self, mock_dist, mock_parallel_state): + """Test ntp_map for a healthy rank (should add send/recv splits).""" + # Mock parallel state + mock_dist.get_rank.return_value = 0 + mock_parallel_state.get_data_parallel_rank.return_value = 0 + mock_parallel_state.get_context_parallel_rank.return_value = 0 + mock_parallel_state.get_pipeline_model_parallel_rank.return_value = 0 + + # Create mock module with parameter + class MockConfig: + num_attention_heads = 24 + + module = Mock() + param = torch.nn.Parameter(torch.randn(384, 128)) # 384 = 24 heads * 16 dim + param.tensor_model_parallel = True + param.partition_dim = 0 + # Note: param.shape is already (384, 128) from 
the tensor, no need to set it + module.parameters = Mock(return_value=[param]) + module.config = MockConfig() + + ntp_config = NonuniformTPConfig( + tp_base=8, + tp_spares=2, + non_active_ranks_per_dp={}, # No explicit non-active ranks, so this is healthy + ) + + # Execute + ntp_map(module, ntp_config, num_shards=24) + + # Should have added send_splits and recv_splits + assert hasattr(param, 'send_splits') + assert hasattr(param, 'recv_splits') + assert len(param.send_splits) == 8 + assert len(param.recv_splits) == 8 + + @patch('megatron.core.distributed.nonuniform_tp.parallel_state') + @patch('megatron.core.distributed.nonuniform_tp.dist') + def test_ntp_map_with_spares_unhealthy_rank(self, mock_dist, mock_parallel_state): + """Test ntp_map for an unhealthy rank (should skip).""" + # Mock parallel state + mock_dist.get_rank.return_value = 0 + mock_parallel_state.get_data_parallel_rank.return_value = 0 + mock_parallel_state.get_context_parallel_rank.return_value = 0 + mock_parallel_state.get_pipeline_model_parallel_rank.return_value = 0 + + # Create mock module + module = Mock() + param = torch.nn.Parameter(torch.randn(10, 10)) + param.tensor_model_parallel = True + param.partition_dim = 1 + module.parameters = Mock(return_value=[param]) + + ntp_config = NonuniformTPConfig( + tp_base=8, + tp_spares=2, + non_active_ranks_per_dp={(0, 0, 0): [2, 5]}, # This rank is unhealthy + ) + + # Execute + ntp_map(module, ntp_config, num_shards=24) + + # Should NOT have added send_splits and recv_splits + assert not hasattr(param, 'send_splits') + assert not hasattr(param, 'recv_splits') + + def test_ntp_init_no_spares(self): + """Test ntp_init when tp_spares=0 (should be no-op).""" + # Create mock layer + layer = Mock() + layer.self_attention = Mock() + layer.mlp = Mock() + + ntp_config = NonuniformTPConfig(tp_base=8, tp_spares=0) + + # Should not raise error + ntp_init(layer, ntp_config) + + @patch('megatron.core.distributed.nonuniform_tp.ntp_map') + def test_ntp_init_with_attention_and_mlp(self, mock_ntp_map): + """Test ntp_init calls ntp_map for both attention and MLP.""" + + class MockConfig: + num_attention_heads = 24 + ffn_hidden_size = 4096 + + # Create mock layer + layer = Mock() + layer.self_attention = Mock() + layer.self_attention.config = MockConfig() + layer.mlp = Mock() + layer.mlp.config = MockConfig() + + ntp_config = NonuniformTPConfig(tp_base=8, tp_spares=2) + + # Execute + ntp_init(layer, ntp_config) + + # Should call ntp_map twice + assert mock_ntp_map.call_count == 2 + # First call for self_attention + assert mock_ntp_map.call_args_list[0][0][0] == layer.self_attention + assert mock_ntp_map.call_args_list[0][0][2] == 24 + # Second call for mlp + assert mock_ntp_map.call_args_list[1][0][0] == layer.mlp + assert mock_ntp_map.call_args_list[1][0][2] == 4096 + + +class TestNonuniformTPOptimizer: + """Test NTP optimizer wrapper.""" + + def test_optimizer_wrapper_delegates_attributes(self): + """Test that optimizer wrapper delegates attribute access.""" + mock_optimizer = Mock() + mock_optimizer.param_groups = [] + mock_optimizer.state = {} + + ntp_config = NonuniformTPConfig(tp_base=8, tp_spares=2) + ntp_optimizer = NonuniformTPOptimizer(mock_optimizer, ntp_config) + + # Should delegate attribute access + assert ntp_optimizer.param_groups == [] + assert ntp_optimizer.state == {} + + def test_optimizer_prepare_grads_no_spares(self): + """Test prepare_grads when tp_spares=0 (should be no-op).""" + mock_optimizer = Mock() + mock_optimizer.param_groups = [{'params': []}] + 
mock_optimizer.prepare_grads = Mock(return_value=False) + + ntp_config = NonuniformTPConfig(tp_base=8, tp_spares=0) + ntp_optimizer = NonuniformTPOptimizer(mock_optimizer, ntp_config) + + result = ntp_optimizer.prepare_grads() + + # Should call original prepare_grads + mock_optimizer.prepare_grads.assert_called_once() + assert result == False + + def test_optimizer_prepare_grads_makes_contiguous(self): + """Test prepare_grads makes gradients contiguous for NTP.""" + # Create parameter with non-contiguous main_grad + param = torch.nn.Parameter(torch.randn(10, 10)) + param.main_grad = torch.randn(10, 10).t() # Transposed = non-contiguous + assert not param.main_grad.is_contiguous() + + mock_optimizer = Mock() + mock_optimizer.param_groups = [{'params': [param]}] + mock_optimizer.prepare_grads = Mock(return_value=False) + + ntp_config = NonuniformTPConfig(tp_base=8, tp_spares=2) + ntp_optimizer = NonuniformTPOptimizer(mock_optimizer, ntp_config) + + ntp_optimizer.prepare_grads() + + # Should have made grad contiguous + assert hasattr(param, 'grad') + assert param.grad.is_contiguous() + + def test_optimizer_prepare_grads_already_contiguous(self): + """Test prepare_grads when gradient is already contiguous.""" + # Create parameter with contiguous main_grad + param = torch.nn.Parameter(torch.randn(10, 10)) + param.main_grad = torch.randn(10, 10) + assert param.main_grad.is_contiguous() + + mock_optimizer = Mock() + mock_optimizer.param_groups = [{'params': [param]}] + mock_optimizer.prepare_grads = Mock(return_value=False) + + ntp_config = NonuniformTPConfig(tp_base=8, tp_spares=2) + ntp_optimizer = NonuniformTPOptimizer(mock_optimizer, ntp_config) + + ntp_optimizer.prepare_grads() + + # Should have set grad directly (no copy) + assert hasattr(param, 'grad') + assert param.grad is param.main_grad + + +class TestNonuniformTPDDPCompatibility: + """Test compatibility with current Megatron DDP construction and bucket state.""" + + def test_ddp_patches_imported_buffer_binding_and_accepts_full_param_layout(self): + """DDP imports _ParamAndGradBuffer directly, so NTP must patch that binding.""" + seen = {} + original_buffer_class = ddp_module._ParamAndGradBuffer + + def fake_parent_init( + self, + *, + config, + ddp_config, + module, + disable_bucketing=False, + pg_collection=None, + full_param_layout=None, + ): + seen['buffer_binding'] = ddp_module._ParamAndGradBuffer + seen['full_param_layout'] = full_param_layout + self.ddp_config = ddp_config + self.bucket_groups = [] + self.expert_parallel_bucket_groups = [] + self.param_to_bucket_group = {} + + config = TransformerConfig( + num_layers=1, hidden_size=8, num_attention_heads=1, context_parallel_size=1 + ) + ddp_config = DistributedDataParallelConfig() + ntp_config = NonuniformTPConfig(tp_base=2, tp_spares=1) + full_param_layout = object() + + with patch.object(DistributedDataParallel, '__init__', new=fake_parent_init): + NonuniformTPDistributedDataParallel( + config=config, + ddp_config=ddp_config, + module=torch.nn.Linear(8, 8), + disable_bucketing=True, + ntp_config=ntp_config, + full_param_layout=full_param_layout, + ) + + patched_binding = seen['buffer_binding'] + assert isinstance(patched_binding, functools.partial) + assert patched_binding.func is NonuniformTPParamAndGradBuffer + assert seen['full_param_layout'] is full_param_layout + assert ddp_module._ParamAndGradBuffer is original_buffer_class + + def test_bucket_wrapping_preserves_overlap_param_gather_and_partial_distopt_state(self): + """NTP bucket wrappers should not drop DDP state set 
before wrapping.""" + + class FakeGroup: + def __init__(self, size=2, rank=0): + self._size = size + self._rank = rank + + def size(self): + return self._size + + def rank(self): + return self._rank + + class FakeBucketGroup: + def __init__(self, ddp_config): + self.buckets = [] + self.ddp_config = ddp_config + self.intra_distributed_optimizer_instance_group = FakeGroup() + self.intra_distributed_optimizer_instance_size = 2 + self.inter_distributed_optimizer_instance_group = 'inter-group' + self.communication_stream = 'comm-stream' + self.next_param_gather_bucket_group = None + + ddp_config = DistributedDataParallelConfig( + use_distributed_optimizer=True, + overlap_param_gather=True, + num_distributed_optimizer_instances=2, + ) + first_group = FakeBucketGroup(ddp_config) + second_group = FakeBucketGroup(ddp_config) + second_group.next_param_gather_bucket_group = first_group + + ntp_ddp = object.__new__(NonuniformTPDistributedDataParallel) + ntp_ddp.ddp_config = ddp_config + ntp_ddp.ntp_config = NonuniformTPConfig(tp_base=2, tp_spares=1) + ntp_ddp.bucket_groups = [first_group, second_group] + ntp_ddp.expert_parallel_bucket_groups = [] + ntp_ddp.param_to_bucket_group = {} + + ntp_ddp._wrap_bucket_groups_for_ntp() + + wrapped_first, wrapped_second = ntp_ddp.bucket_groups + assert isinstance(wrapped_first, NonuniformTPParamAndGradBucketGroup) + assert isinstance(wrapped_second, NonuniformTPParamAndGradBucketGroup) + assert wrapped_second.next_param_gather_bucket_group is wrapped_first + assert wrapped_second.inter_distributed_optimizer_instance_group == 'inter-group' + assert wrapped_second.communication_stream == 'comm-stream' + + +class TestNonuniformTPIntegration: + """Integration tests for NTP with DDP - run with torchrun.""" + + @classmethod + def setup_class(cls): + Utils.initialize_model_parallel(tensor_model_parallel_size=1) + + @classmethod + def teardown_class(cls): + Utils.destroy_model_parallel() + + def test_ntp_ddp_initialization(self): + """Test NonuniformTPDistributedDataParallel can be instantiated.""" + model = torch.nn.Linear(10, 10) + config = TransformerConfig( + num_layers=1, hidden_size=10, num_attention_heads=1, context_parallel_size=1 + ) + ddp_config = DistributedDataParallelConfig() + ntp_config = NonuniformTPConfig(tp_base=1, tp_spares=0) + + ntp_ddp = NonuniformTPDistributedDataParallel( + config, ddp_config, model, disable_bucketing=True, ntp_config=ntp_config + ) + from megatron.core.distributed import DistributedDataParallel + + assert isinstance(ntp_ddp, DistributedDataParallel) + + def test_ntp_backward_hook_created(self): + """Test that NTP backward hook is created without error.""" + model = torch.nn.Linear(10, 10) + model.weight.tensor_model_parallel = True + model.weight.partition_dim = 1 + + config = TransformerConfig( + num_layers=1, hidden_size=10, num_attention_heads=1, context_parallel_size=1 + ) + ddp_config = DistributedDataParallelConfig() + ntp_config = NonuniformTPConfig(tp_base=1, tp_spares=0) + + ntp_ddp = NonuniformTPDistributedDataParallel( + config, ddp_config, model, disable_bucketing=True, ntp_config=ntp_config + ) + # Verify the hook is registered on the parameter + assert model.weight._backward_hooks or ntp_ddp is not None + + +class TestNonuniformTPEndToEnd: + """ + End-to-end test for NTP without mocking. 
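+    Spare ranks mark themselves as skipped with pytest.skip instead of letting
+    initialize_nonuniform_tp_process_groups call sys.exit, so every rank reports a result.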
+ + Tests NTP with 8 GPUs configured as: + - 2 data-parallel workers + - DP rank 0: TP=2 (reduced, using 2 out of 4 GPUs) + - DP rank 1: TP=4 (healthy, using all 4 GPUs) + - Total: 2 + 4 = 6 active GPUs out of 8 + """ + + @classmethod + def setup_class(cls): + """Initialize model parallel for NTP testing.""" + if Utils.world_size != 8: + pytest.skip(f"This test requires 8 GPUs, but only {Utils.world_size} are available") + # Initialize with tp_base=4 + Utils.initialize_model_parallel(tensor_model_parallel_size=4) + + @classmethod + def teardown_class(cls): + """Clean up model parallel.""" + Utils.destroy_model_parallel() + + def test_ntp_end_to_end_with_8_gpus(self): + """ + End-to-end test using 8 GPUs with 2 DP workers: + - DP rank 0: uses TP=2 (reduced from tp_base=4) + - DP rank 1: uses TP=4 (healthy, full tp_base) + """ + import torch.distributed as dist + + from megatron.core import parallel_state + + # Check we have 8 GPUs + world_size = dist.get_world_size() if dist.is_initialized() else 1 + if world_size != 8: + pytest.skip(f"This test requires 8 GPUs, but only {world_size} are available") + + # Get current rank info + rank = dist.get_rank() + tp_rank = parallel_state.get_tensor_model_parallel_rank() + tp_size = parallel_state.get_tensor_model_parallel_world_size() + dp_rank = parallel_state.get_data_parallel_rank() + + # Configure NTP: first DP rank uses reduced TP=2 + ntp_config = NonuniformTPConfig( + tp_base=4, + tp_spares=2, + num_reduced_tp_dp_ranks=1, + non_active_ranks_per_dp={(0, 0, 0): [2, 3]}, # DP=0: GPUs 2,3 are spares + ) + + # Check if this rank is a spare (will exit during initialization) + # Spare ranks: DP=0 with tp_rank=2,3 + is_spare = dp_rank == 0 and tp_rank in [2, 3] + + # Reconfigure process groups for NTP + # Note: spare ranks will call sys.exit(0) in initialize_nonuniform_tp_process_groups + from megatron.core.distributed.nonuniform_tp import initialize_nonuniform_tp_process_groups + + if is_spare: + # For spare ranks in test, just mark as passed and exit gracefully + pytest.skip(f"Rank {rank} is a spare rank, skipping test gracefully") + + initialize_nonuniform_tp_process_groups(ntp_config) + + # After reconfiguration, check TP size + tp_size_after = parallel_state.get_tensor_model_parallel_world_size() + + # Verify the configuration + if dp_rank == 0: + # First DP rank should have reduced TP=2 + assert tp_size_after == 2, f"DP rank 0 should have TP=2, got {tp_size_after}" + assert tp_rank < 2, f"DP rank 0 should have tp_rank < 2, got {tp_rank}" + else: + # Other DP ranks keep TP=4 + assert tp_size_after == 4, f"DP rank {dp_rank} should have TP=4, got {tp_size_after}" + assert tp_rank < 4, f"DP rank {dp_rank} should have tp_rank < 4, got {tp_rank}" + + # Create a simple model with tensor-parallel parameters + hidden_size = 128 + model = torch.nn.Linear(hidden_size, hidden_size, bias=False).cuda() + + # Mark it as tensor-parallel + model.weight.tensor_model_parallel = True + model.weight.partition_dim = 0 + + # Initialize NTP mappings + from megatron.core.distributed.nonuniform_tp import ntp_map + + # For healthy ranks (DP=1), initialize send/recv splits + if dp_rank == 1: + # Create a mock module to test ntp_map + class MockModule: + def __init__(self, param): + self.param = param + + def parameters(self): + return [self.param] + + mock_module = MockModule(model.weight) + ntp_map(mock_module, ntp_config, num_shards=hidden_size) + + # Verify send_splits and recv_splits were added + assert hasattr(model.weight, 'send_splits'), "Healthy rank should have 
send_splits" + assert hasattr(model.weight, 'recv_splits'), "Healthy rank should have recv_splits" + assert len(model.weight.send_splits) == 4, "Should have splits for all tp_base ranks" + + # Test forward pass + batch_size = 4 + input_tensor = torch.randn(batch_size, hidden_size, device='cuda') + output = model(input_tensor) + + # Verify output shape + assert output.shape == (batch_size, hidden_size), f"Unexpected output shape: {output.shape}" + + # Verify gradients work + loss = output.sum() + loss.backward() + assert model.weight.grad is not None, "Gradients should be computed" + + +def _new_group_for_current_rank(group_ranks, rank): + group = dist.new_group(ranks=group_ranks) + return group if rank in group_ranks else None + + +def _initialize_packed_tp2_tp4_groups(): + """Initialize a 6-rank packed NTP layout with no spare processes.""" + rank = dist.get_rank() + world_size = dist.get_world_size() + if world_size != 6: + raise RuntimeError(f"Packed TP2/TP4 NTP test requires WORLD_SIZE=6, got {world_size}") + + reduced_ranks = [0, 1] + healthy_ranks = [2, 3, 4, 5] + tp_domains = [reduced_ranks, healthy_ranks] + dp_domains = [[0, 2], [1, 3], [4], [5]] + singleton_domains = [[group_rank] for group_rank in range(world_size)] + + tp_groups = {} + for group_ranks in tp_domains: + group = _new_group_for_current_rank(group_ranks, rank) + for group_rank in group_ranks: + tp_groups[group_rank] = (group, group_ranks) + + dp_groups = {} + for group_ranks in dp_domains: + group = _new_group_for_current_rank(group_ranks, rank) + for group_rank in group_ranks: + dp_groups[group_rank] = (group, group_ranks) + + singleton_groups = {} + for group_ranks in singleton_domains: + group = _new_group_for_current_rank(group_ranks, rank) + singleton_groups[group_ranks[0]] = (group, group_ranks) + + if rank in reduced_ranks: + dp_rank = 0 + tp_rank = rank + tp_size = 2 + else: + dp_rank = 1 + tp_rank = rank - healthy_ranks[0] + tp_size = 4 + + tp_group, tp_ranks = tp_groups[rank] + dp_group, dp_ranks = dp_groups[rank] + singleton_group, singleton_ranks = singleton_groups[rank] + + parallel_state._TENSOR_MODEL_PARALLEL_GROUP = tp_group + parallel_state._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = tp_ranks + parallel_state._MODEL_PARALLEL_GROUP = tp_group + parallel_state._MODEL_PARALLEL_GLOBAL_RANKS = tp_ranks + parallel_state._PIPELINE_MODEL_PARALLEL_GROUP = singleton_group + parallel_state._PIPELINE_GLOBAL_RANKS = singleton_ranks + parallel_state._CONTEXT_PARALLEL_GROUP = singleton_group + parallel_state._CONTEXT_PARALLEL_GLOBAL_RANKS = singleton_ranks + parallel_state._HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = [singleton_group] + parallel_state._TENSOR_AND_CONTEXT_PARALLEL_GROUP = tp_group + parallel_state._DATA_PARALLEL_GROUP = dp_group + parallel_state._DATA_PARALLEL_GROUP_WITH_CP = dp_group + parallel_state._INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = dp_group + parallel_state._DATA_PARALLEL_GLOBAL_RANKS = dp_ranks + parallel_state._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = dp_ranks + parallel_state._TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = dist.group.WORLD + parallel_state._EMBEDDING_GROUP = singleton_group + parallel_state._EMBEDDING_GLOBAL_RANKS = singleton_ranks + parallel_state._POSITION_EMBEDDING_GROUP = singleton_group + parallel_state._POSITION_EMBEDDING_GLOBAL_RANKS = singleton_ranks + parallel_state._EXPERT_MODEL_PARALLEL_GROUP = singleton_group + parallel_state._EXPERT_MODEL_PARALLEL_RANKS = singleton_ranks + parallel_state._EXPERT_TENSOR_PARALLEL_GROUP = tp_group + 
parallel_state._EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP = tp_group + parallel_state._EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP = tp_group + parallel_state._EXPERT_DATA_PARALLEL_GROUP = dp_group + parallel_state._INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP = dp_group + parallel_state._INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP = dp_group + + parallel_state.set_tensor_model_parallel_world_size(tp_size) + parallel_state.set_tensor_model_parallel_rank(tp_rank) + parallel_state.set_pipeline_model_parallel_world_size(1) + parallel_state.set_pipeline_model_parallel_rank(0) + parallel_state.set_data_parallel_rank(dp_rank) + parallel_state._set_global_memory_buffer() + + return ProcessGroupCollection( + tp=tp_group, + pp=singleton_group, + mp=tp_group, + embd=singleton_group, + pos_embd=singleton_group, + cp=singleton_group, + tp_cp=tp_group, + hcp=[singleton_group], + ep=singleton_group, + expt_tp=tp_group, + tp_ep=tp_group, + tp_ep_pp=tp_group, + dp=dp_group, + dp_cp=dp_group, + dp_cp_ag=None, + expt_dp=dp_group, + expt_dp_ag=None, + intra_dp_cp=dp_group, + intra_expt_dp=dp_group, + inter_dist_opt=None, + intra_dist_opt=dp_group, + tp_dp_cp=dist.group.WORLD, + ) + + +def _apply_ntp_mappings_to_gpt(model, ntp_config): + for module in model.modules(): + if module.__class__.__name__ == "TransformerLayer": + ntp_init(module, ntp_config) + if hasattr(model, "embedding") and hasattr(model.embedding, "word_embeddings"): + ntp_map(model.embedding.word_embeddings, ntp_config, 512) + if hasattr(model, "output_layer"): + ntp_map(model.output_layer, ntp_config, 512) + + +@pytest.mark.skipif(not is_te_min_version("1.9.0"), reason="TE userbuffers require TE >= 1.9") +class TestNonuniformTPPackedTEEndToEnd: + """End-to-end NTP Megatron test using only active TP2 + TP4 ranks.""" + + @classmethod + def setup_class(cls): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + if world_size != 6: + pytest.skip(f"This test requires 6 GPUs, but only {world_size} are available") + + rank = int(os.environ["RANK"]) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank % torch.cuda.device_count()) + + if not dist.is_initialized(): + init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}" + dist.init_process_group( + backend="nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + timeout=timedelta(minutes=5), + device_id=torch.device("cuda", torch.cuda.current_device()), + ) + dist.barrier(device_ids=[torch.cuda.current_device()]) + + Utils.world_size = world_size + Utils.rank = rank + Utils.inited = True + parallel_state.destroy_model_parallel() + + @classmethod + def teardown_class(cls): + try: + from transformer_engine.pytorch import module as te_module + + te_module.base.destroy_ub() + except Exception: + pass + if dist.is_initialized(): + try: + torch.cuda.synchronize() + dist.barrier() + except Exception: + pass + parallel_state.destroy_model_parallel() + dist.destroy_process_group() + Utils.inited = False + + def test_ntp_te_tp_comm_overlap_with_packed_tp2_tp4(self): + pg_collection = _initialize_packed_tp2_tp4_groups() + model_parallel_cuda_manual_seed(123) + + ntp_config = NonuniformTPConfig( + tp_base=4, + tp_spares=2, + num_reduced_tp_dp_ranks=1, + non_active_ranks_per_dp={(0, 0, 0): [2, 3]}, + ) + + seq_len = 16 + micro_batch_size = 1 + hidden_size = 256 + tp_size = parallel_state.get_tensor_model_parallel_world_size() + + tp_domains = [[0, 1], [2, 3, 4, 5]] + normalized_domains = 
initialize_transformer_engine_userbuffers_for_nonuniform_tp( + shape=[seq_len * micro_batch_size, hidden_size], + tp_size=tp_size, + tp_domains=tp_domains, + bootstrap_backend="nccl", + ) + assert normalized_domains == ((0, 1), (2, 3, 4, 5)) + + for env_var in ("NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN"): + os.environ.pop(env_var, None) + config = TransformerConfig( + num_layers=1, + hidden_size=hidden_size, + ffn_hidden_size=1024, + num_attention_heads=8, + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=1, + sequence_parallel=True, + tp_comm_overlap=True, + tp_comm_overlap_rs_dgrad=True, + bf16=True, + params_dtype=torch.bfloat16, + pipeline_dtype=torch.bfloat16, + use_cpu_initialization=False, + attention_dropout=0.0, + hidden_dropout=0.0, + ) + model = GPTModel( + config=config, + transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(), + vocab_size=512, + max_sequence_length=seq_len, + parallel_output=True, + share_embeddings_and_output_weights=False, + position_embedding_type="none", + pg_collection=pg_collection, + ).cuda() + model.bfloat16() + _apply_ntp_mappings_to_gpt(model, ntp_config) + + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=False, overlap_grad_reduce=True, bucket_size=100_000_000 + ) + ddp_model = NonuniformTPDistributedDataParallel( + config=config, + ddp_config=ddp_config, + module=model, + disable_bucketing=False, + pg_collection=pg_collection, + ntp_config=ntp_config, + ) + + tokens = torch.randint( + low=0, high=512, size=(micro_batch_size, seq_len), dtype=torch.long, device="cuda" + ) + labels = torch.randint( + low=0, high=512, size=(micro_batch_size, seq_len), dtype=torch.long, device="cuda" + ) + position_ids = torch.arange(seq_len, dtype=torch.long, device="cuda").unsqueeze(0) + loss_mask = torch.ones((micro_batch_size, seq_len), dtype=torch.float32, device="cuda") + + ddp_model.zero_grad_buffer() + losses = ddp_model(tokens, position_ids, None, labels=labels, loss_mask=loss_mask) + loss = torch.sum(losses.float() * loss_mask) / loss_mask.sum() + loss.backward() + ddp_model.finish_grad_sync() + + assert torch.isfinite(loss.detach()).item() + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/unit_tests/extension/test_nonuniform_tp_transformer_engine.py b/tests/unit_tests/extension/test_nonuniform_tp_transformer_engine.py new file mode 100644 index 00000000000..24462f813ca --- /dev/null +++ b/tests/unit_tests/extension/test_nonuniform_tp_transformer_engine.py @@ -0,0 +1,78 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ +import pytest + +from megatron.core.extensions.nonuniform_tp_transformer_engine import ( + normalize_tp_domains, + transformer_engine_userbuffer_tp_domains, +) + + +class FakeDistributed: + def __init__(self, rank): + self.rank = rank + self.created_groups = [] + + def get_rank(self): + return self.rank + + def new_group(self, ranks, **kwargs): + group = {"ranks": tuple(ranks), "kwargs": kwargs} + self.created_groups.append(group) + return group + + def new_subgroups_by_enumeration(self, ranks_per_subgroup_list, *args, **kwargs): + return "original", ranks_per_subgroup_list, args, kwargs + + +def test_normalize_tp_domains_sorts_domains_for_collective_creation_order(): + domains = normalize_tp_domains([[4, 5, 6, 7], [0, 1], [2, 3]]) + + assert domains == ((0, 1), (2, 3), (4, 5, 6, 7)) + + +def test_normalize_tp_domains_rejects_overlapping_domains(): + with pytest.raises(ValueError, match="overlap"): + normalize_tp_domains([[0, 1], [1, 2]]) + + +def test_userbuffer_tp_domains_overrides_and_restores_subgroup_enumeration(): + fake_dist = FakeDistributed(rank=5) + original_new_subgroups = fake_dist.new_subgroups_by_enumeration + + with transformer_engine_userbuffer_tp_domains( + [[4, 5, 6, 7], [0, 1]], distributed=fake_dist + ) as domains: + current_group, groups = fake_dist.new_subgroups_by_enumeration( + [[0, 1], [2, 3]], backend="nccl", group_desc="UB" + ) + + assert domains == ((0, 1), (4, 5, 6, 7)) + assert current_group is groups[1] + assert [group["ranks"] for group in groups] == [(0, 1), (4, 5, 6, 7)] + assert groups[0]["kwargs"] == {"backend": "nccl", "group_desc": "UB_ntp_0"} + assert groups[1]["kwargs"] == {"backend": "nccl", "group_desc": "UB_ntp_1"} + assert fake_dist.new_subgroups_by_enumeration is original_new_subgroups + + +def test_userbuffer_tp_domains_redirects_default_group_to_current_tp_group(): + fake_dist = FakeDistributed(rank=5) + tp_group = {"ranks": (4, 5, 6, 7)} + original_new_group = fake_dist.new_group + + with transformer_engine_userbuffer_tp_domains( + [[4, 5, 6, 7], [0, 1]], distributed=fake_dist, tp_group=tp_group + ): + assert fake_dist.new_group(backend="nccl") is tp_group + explicit_group = fake_dist.new_group(ranks=[0, 1], backend="nccl") + + assert explicit_group == {"ranks": (0, 1), "kwargs": {"backend": "nccl"}} + assert fake_dist.new_group is original_new_group + + +def test_userbuffer_tp_domains_requires_current_rank_to_be_in_a_domain(): + fake_dist = FakeDistributed(rank=8) + + with pytest.raises(RuntimeError, match="not present"): + with transformer_engine_userbuffer_tp_domains([[0, 1], [4, 5]], distributed=fake_dist): + pass
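
Putting the pieces together: the packed TP2/TP4 end-to-end test exercises the full wiring, and a condensed sketch of that wiring for a GPT-style model follows. It is illustrative only: the helper name wrap_model_for_ntp is invented here, and it assumes that NonuniformTPConfig, NonuniformTPOptimizer, NonuniformTPDistributedDataParallel and ntp_init are exported from megatron.core.distributed.nonuniform_tp alongside initialize_nonuniform_tp_process_groups and ntp_map, which the tests import from that module.

    from megatron.core.distributed.nonuniform_tp import (
        NonuniformTPConfig,
        NonuniformTPDistributedDataParallel,
        NonuniformTPOptimizer,
        initialize_nonuniform_tp_process_groups,
        ntp_init,
        ntp_map,
    )

    def wrap_model_for_ntp(model, config, ddp_config, optimizer, vocab_size):
        """Illustrative helper: reconfigure process groups and wrap model/optimizer for NTP."""
        ntp_config = NonuniformTPConfig(
            tp_base=4,
            tp_spares=2,
            num_reduced_tp_dp_ranks=1,
            non_active_ranks_per_dp={(0, 0, 0): [2, 3]},  # DP replica 0 drops TP ranks 2 and 3
        )

        # Spare ranks exit inside this call; surviving ranks see the reduced/healthy TP groups.
        initialize_nonuniform_tp_process_groups(ntp_config)

        # Register NTP mappings, mirroring _apply_ntp_mappings_to_gpt above: per-layer init for
        # transformer layers, vocab-sharded mappings for the embedding and output layer.
        for module in model.modules():
            if module.__class__.__name__ == "TransformerLayer":
                ntp_init(module, ntp_config)
        if hasattr(model, "embedding") and hasattr(model.embedding, "word_embeddings"):
            ntp_map(model.embedding.word_embeddings, ntp_config, vocab_size)
        if hasattr(model, "output_layer"):
            ntp_map(model.output_layer, ntp_config, vocab_size)

        ddp_model = NonuniformTPDistributedDataParallel(
            config=config,
            ddp_config=ddp_config,
            module=model,
            disable_bucketing=False,
            ntp_config=ntp_config,
        )
        return ddp_model, NonuniformTPOptimizer(optimizer, ntp_config)

After that, training follows the same sequence as in the packed test: zero_grad_buffer(), forward, backward, finish_grad_sync(), with gradient preparation handled by the NonuniformTPOptimizer wrapper as exercised in the prepare_grads tests.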
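
On the Transformer Engine side, the contract the unit tests pin down is that every active rank must create the per-domain userbuffer communicators in the same order, which is why normalize_tp_domains canonicalizes the (non-overlapping) domains before any group is created. A minimal sketch using only behavior asserted above:

    from megatron.core.extensions.nonuniform_tp_transformer_engine import normalize_tp_domains

    # The caller may list the healthy TP=4 replica before the reduced TP=2 replica; the helper
    # sorts the domains so the collective creation order matches on every rank.
    tp_domains = [[2, 3, 4, 5], [0, 1]]
    assert normalize_tp_domains(tp_domains) == ((0, 1), (2, 3, 4, 5))

transformer_engine_userbuffer_tp_domains builds on the same normalization: while the context is active it redirects subgroup enumeration to the per-domain groups, resolves the default group to the caller's own TP group, and raises if the calling rank is not present in any domain, as the last test checks.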