From 9b48e2d67dbaf2204f245e2f76ffef8a32b556cb Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Fri, 6 Feb 2026 16:44:14 +0800 Subject: [PATCH 1/4] support fsdp2 muon optimizer --- xtuner/v1/config/__init__.py | 3 +- xtuner/v1/config/optim.py | 97 ++- xtuner/v1/engine/train_engine.py | 20 +- xtuner/v1/optim/__init__.py | 4 + xtuner/v1/optim/muon.py | 838 ++++++++++++++++++++++++ xtuner/v1/optim/newton_schulz_triton.py | 350 ++++++++++ xtuner/v1/train/arguments/arguments.py | 12 +- 7 files changed, 1299 insertions(+), 25 deletions(-) create mode 100644 xtuner/v1/optim/__init__.py create mode 100644 xtuner/v1/optim/muon.py create mode 100644 xtuner/v1/optim/newton_schulz_triton.py diff --git a/xtuner/v1/config/__init__.py b/xtuner/v1/config/__init__.py index 2fff8d16d..64eb447d8 100644 --- a/xtuner/v1/config/__init__.py +++ b/xtuner/v1/config/__init__.py @@ -1,6 +1,6 @@ from .fsdp import FSDPConfig from .generate import GenerateConfig -from .optim import AdamWConfig, LRConfig, OptimConfig +from .optim import AdamWConfig, LRConfig, MuonConfig, OptimConfig __all__ = [ @@ -9,4 +9,5 @@ "AdamWConfig", "LRConfig", "GenerateConfig", + "MuonConfig", ] diff --git a/xtuner/v1/config/optim.py b/xtuner/v1/config/optim.py index c165118e7..5827edd8c 100644 --- a/xtuner/v1/config/optim.py +++ b/xtuner/v1/config/optim.py @@ -2,10 +2,17 @@ from typing import Literal, Optional, Tuple import torch +import torch.distributed as dist from cyclopts import Parameter from pydantic import BaseModel, ConfigDict from typing_extensions import Annotated +from xtuner.v1.optim import Muon +from xtuner.v1.utils import get_logger + + +logger = get_logger() + class OptimConfig(BaseModel): model_config = ConfigDict(extra="forbid") @@ -26,12 +33,100 @@ class AdamWConfig(OptimConfig): eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Adam optimizer")] = 1e-8 foreach: Annotated[Optional[bool], Parameter(help="Use foreach implementation for AdamW")] = None - def build(self, params): + def build(self, model): + params = [p for p in model.parameters() if p.requires_grad] + + trainable_parameters_names = model.trainable_parameters() + trainable_names = [name for name, _ in trainable_parameters_names] + untrainable_names = [] + num_total_requires_grad = 0 + num_total = 0 + for name, params_ in model.named_parameters(): + num_total += params_.numel() + num_total_requires_grad += params_.numel() if name in trainable_names else 0 + if name not in trainable_names: + untrainable_names.append(name) + + if dist.get_rank() == 0: + logger.info( + f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M" + ) + logger.info(f"Untrainable parameters names: {untrainable_names}") return torch.optim.AdamW( params, lr=self.lr, betas=self.betas, eps=self.eps, weight_decay=self.weight_decay, foreach=self.foreach ) +class MuonConfig(OptimConfig): + weight_decay: Annotated[float, Parameter(help="Weight decay coefficient for L2 regularization")] = 0.1 + momentum: Annotated[float, Parameter(help="Momentum coefficients for Muon optimizer")] = 0.95 + betas: Annotated[Tuple[float, float], Parameter(help="Beta coefficients for AdamW optimizer")] = (0.9, 0.95) + eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Muon optimizer")] = 1e-8 + + def build(self, model): + trainable_parameters_names = model.trainable_parameters() + trainable_names = {name for name, _ in trainable_parameters_names} + + untrainable_names = [] + num_total = 0 + num_total_requires_grad = 0 + 
num_muon = 0 + num_adamw = 0 + + for name, p in model.named_parameters(): + n = p.numel() + num_total += n + if name in trainable_names: + num_total_requires_grad += n + is_muon_tensor = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + if is_muon_tensor: + num_muon += n + else: + num_adamw += n + else: + untrainable_names.append(name) + + muon_params = [ + p + for name, p in model.named_parameters() + if name in trainable_names and p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name + ] + adamw_params = [ + p + for name, p in model.named_parameters() + if name in trainable_names and not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name) + ] + param_groups = [ + dict(params=muon_params), + dict(params=adamw_params, algorithm="adamw"), + ] + + if dist.get_rank() == 0: + logger.info( + f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M" + ) + logger.info(f"Muon params: {num_muon // 1e6}M, AdamW params: {num_adamw // 1e6}M (counts by numel)") + logger.info(f"Untrainable parameters names: {untrainable_names}") + logger.info( + f"using Muon optimizer distributed_mesh_size: {model.fsdp_mesh.size()}, " + f"distributed_mesh: {model.fsdp_mesh}" + ) + + optimizer = Muon( + param_groups, + distributed_mesh=model.fsdp_mesh, # TODO: ζš‚δΈζ”―ζŒ EP>1 + lr=self.lr, + mu=self.momentum, + betas=self.betas, + weight_decay=self.weight_decay, + nesterov=True, + adjust_lr="rms_norm", + use_triton=False, + epsilon=self.eps, + ) + return optimizer + + class LRConfig(BaseModel): model_config = ConfigDict(extra="forbid") lr_type: Annotated[Literal["cosine", "linear", "constant"], Parameter(help="Type of learning rate schedule")] = ( diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index f725627f1..b5e77dbe7 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -184,25 +184,7 @@ def build_model(self) -> BaseModel: return model def build_optimizer(self, optim_cfg: OptimConfig) -> torch.optim.Optimizer: - params = [p for p in self.model.parameters() if p.requires_grad] - - trainable_parameters_names = self.model.trainable_parameters() - trainable_names = [name for name, _ in trainable_parameters_names] - untrainable_names = [] - num_total_requires_grad = 0 - num_total = 0 - for name, params_ in self.model.named_parameters(): - num_total += params_.numel() - num_total_requires_grad += params_.numel() if name in trainable_names else 0 - if name not in trainable_names: - untrainable_names.append(name) - - if dist.get_rank() == 0: - logger.info( - f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M" - ) - logger.info(f"Untrainable parameters names: {untrainable_names}") - return optim_cfg.build(params) + return optim_cfg.build(self.model) @property def data_replicate_size(self) -> int: diff --git a/xtuner/v1/optim/__init__.py b/xtuner/v1/optim/__init__.py new file mode 100644 index 000000000..3a80f8163 --- /dev/null +++ b/xtuner/v1/optim/__init__.py @@ -0,0 +1,4 @@ +from .muon import Muon + + +__all__ = ["Muon"] diff --git a/xtuner/v1/optim/muon.py b/xtuner/v1/optim/muon.py new file mode 100644 index 000000000..1faa2ab89 --- /dev/null +++ b/xtuner/v1/optim/muon.py @@ -0,0 +1,838 @@ +# ================================================================ +# Copyright (c) 2026 InternLM. All rights reserved. +# +# This file is developed based on and inspired by the following projects: +# +# 1. 
Muon Optimizer +# Author: Keller Jordan +# Blog: https://kellerjordan.github.io/posts/muon/ +# +# 2. Moonlight +# https://github.com/MoonshotAI/Moonlight +# Licensed under the MIT License +# Copyright (c) 2025 Moonshot AI +# +# 3. DION +# https://github.com/microsoft/dion +# Licensed under the MIT License +# Copyright (c) 2025 Microsoft +# +# Portions of this code are adapted from the above projects. +# All rights and licenses of the original works are respected. +# ================================================================ + + +import math +from collections import defaultdict +from itertools import chain +from typing import Callable, Generator, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DeviceMesh, DTensor +from torch.optim.optimizer import Optimizer, ParamsT + +from .newton_schulz_triton import newton_schulz_triton + + +def to_local(tensor: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]]: + """Convert a single DTensor or list of DTensors to local tensors. + + This is a no-op for regular tensors. + """ + if isinstance(tensor, Tensor): + return tensor.to_local() if isinstance(tensor, DTensor) else tensor + return [t.to_local() if isinstance(t, DTensor) else t for t in tensor] + + +def dtensor_from_local(tensor: Union[Tensor, List[Tensor]], ref: Tensor) -> Union[DTensor, List[DTensor]]: + """Convert a single local Tensor or list of local Tensors to DTensor. + + The reference tensor's device mesh and placements are used to create the DTensor. if the reference tensor is not a + DTensor, we return the input unmodified. + """ + if not isinstance(ref, DTensor): + assert isinstance(ref, Tensor) + return tensor + + device_mesh = ref.device_mesh + placements = ref.placements + + # If we have a single tensor + if isinstance(tensor, Tensor): + assert not isinstance(tensor, DTensor) + return DTensor.from_local(tensor, device_mesh=device_mesh, placements=placements) + + # We have a list of tensors + assert not isinstance(tensor[0], DTensor) + return [DTensor.from_local(t, device_mesh=device_mesh, placements=placements) for t in tensor] + + +def create_param_batches(params: List[Tensor], batch_size: int) -> Generator[List[Tensor], None, None]: + """Batch parameters into groups of size `batch_size`. + + Tensors in each batch will have identical shape, sharding, and dtype. + """ + # Group parameters by shape, sharding, and dtype + groups = defaultdict(list) + for p in params: + sharding = p.placements if isinstance(p, DTensor) else None + groups[(p.shape, sharding, p.dtype)].append(p) + + # Create batches from grouped parameters + for group in groups.values(): + for i in range(0, len(group), batch_size): + batch = group[i : i + batch_size] + yield batch + + +def pad_batch(batch: List[Tensor], batch_size: int) -> List[Tensor]: + """Insert dummy tensors so the batch has exactly `batch_size` elements.""" + assert len(batch) > 0 + assert len(batch) <= batch_size + while len(batch) < batch_size: + batch.append(torch.zeros_like(batch[0])) + return batch + + +class AsyncTask: + """AsyncTask wraps a Python generator to run until the next yield + statement. + + This is used to allow other tasks to run while waiting for distributed operations. + """ + + def __init__(self, generator: Generator[None, None, None]): + self._generator = generator + self.run() # Start running the generator + + def run(self) -> bool: + # Run the next step of the async task. 
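+        # (Each call resumes the wrapped generator until its next `yield`.
+        # The Muon tasks below yield right after launching a collective with
+        # async_op=True, so communication overlaps with other tasks' compute,
+        # and execution resumes at work.wait() on the following step.)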
+ # Returns True if the task is still running and False if completed. + try: + next(self._generator) + return True + except StopIteration: + pass + return False + + +class AsyncRuntime: + """Event loop for running multiple async tasks concurrently.""" + + def __init__(self, task_gen: Generator["AsyncTask", None, None], max_concurrent_tasks: int): + # Initialize runtime with a generator that produces AsyncTask objects + if max_concurrent_tasks <= 0: + raise ValueError(f"{max_concurrent_tasks=} cannot be <= 0") + self._task_gen = task_gen + self._max_concurrent_tasks = max_concurrent_tasks + + def _get_next_task(self) -> Optional["AsyncTask"]: + try: + task = next(self._task_gen) + return task + except StopIteration: + return None + + def run(self): + # Run the event loop until all tasks are completed + have_new_tasks = True + previous_tasks: List["AsyncTask"] = [] + + while have_new_tasks or previous_tasks: + # See if we can add another task + running_tasks = [] + if have_new_tasks and len(previous_tasks) < self._max_concurrent_tasks: + new_task = self._get_next_task() + if new_task is not None: + # Add new task to the queue + running_tasks.append(new_task) + else: + # No more tasks left + have_new_tasks = False + + # Run all previous tasks for one step + for task in previous_tasks: + still_running = task.run() + if still_running: + running_tasks.append(task) + + # Update task list for next iteration + previous_tasks = running_tasks + + +@torch.compile(fullgraph=True) +def adamw_update( + X: Tensor, # Model weights (modified in place) + G: Tensor, # Gradient + M: Tensor, # Momentum buffer (modified in place) + V: Tensor, # Variance buffer (modified in place) + lr: Tensor, # Learning rate (scalar tensor) + beta1: Tensor, # Beta 1 (scalar tensor) + beta2: Tensor, # Beta 2 (scalar tensor) + weight_decay: Tensor, # Weight decay (scalar tensor) + step: int, + epsilon: float, +): + """AdamW optimizer algorithm.""" + assert X.shape == G.shape + assert X.shape == M.shape + + # Update momentum and variance + # M = beta1 * M + (1 - beta1) * G + M.lerp_(G.to(M.dtype), 1 - beta1) + # V = beta2 * V + (1 - beta2) * G * G + V.mul_(beta2).addcmul_(G, G, value=1 - beta2) + + # Bias correction + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + bias_correction2_sqrt = bias_correction2.sqrt() + + # The goal is to compute the following in-place: + # M = M / bias_correction1 + # V = V / bias_correction2 + # X = X - lr * M / (sqrt(V) + epsilon) + + # sqrt(V / bias_correction2) = sqrt(V) / sqrt(bias_correction2) + denom = V.sqrt().div_(bias_correction2_sqrt).add_(epsilon) + + # Adjust learning rate to include bias correction 1 + adj_lr = lr / bias_correction1 + + # Apply weight decay + X.mul_(1 - lr * weight_decay) + + # Weight update + # X = X - adj_lr * M / denom + X.addcdiv_(M, denom, value=-adj_lr) + + +@torch.compile(fullgraph=True) +def adamw_update_foreach( + X: List[Tensor], # Model weights (modified in place) + G: List[Tensor], # Gradient + M: List[Tensor], # Momentum buffer (modified in place) + V: List[Tensor], # Variance buffer (modified in place) + lr: Tensor, # Learning rate (scalar tensor) + beta1: Tensor, # Beta 1 (scalar tensor) + beta2: Tensor, # Beta 2 (scalar tensor) + weight_decay: Tensor, # Weight decay (scalar tensor) + step: int, + epsilon: float, +): + """AdamW optimizer algorithm (foreach implementation).""" + batch_size = len(X) + assert batch_size == len(G) + assert batch_size == len(M) + assert batch_size == len(V) + + M_dtype = M[0].dtype + V_dtype = 
V[0].dtype + + # Update momentum and variance + # M = beta1 * M + (1 - beta1) * G + G = [g.to(dtype=M_dtype) for g in G] + torch._foreach_lerp_(M, G, [1 - beta1] * batch_size) + + # V = beta2 * V + (1 - beta2) * G * G + G_square = torch._foreach_mul(G, G) + G_square = [g.to(dtype=V_dtype) for g in G_square] + torch._foreach_lerp_(V, G_square, [1 - beta2] * batch_size) + + # Bias correction + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + bias_correction2_sqrt = bias_correction2.sqrt() + + # The goal is to compute the following in-place: + # M = M / bias_correction1 + # V = V / bias_correction2 + # X = X - lr * M / (sqrt(V) + epsilon) + + # Compute the denominator for the weight update + # sqrt(V / bias_correction2) = sqrt(V) / sqrt(bias_correction2) + denom = torch._foreach_sqrt(V) + torch._foreach_div_(denom, bias_correction2_sqrt) + torch._foreach_add_(denom, [epsilon] * batch_size) + + # Adjust learning rate to include bias correction 1 + adj_lr = lr / bias_correction1 + + # Apply weight decay + torch._foreach_mul_(X, 1 - lr * weight_decay) + + # Weight update + # X = X - adj_lr * M / denom + M_div = torch._foreach_div(M, denom) + torch._foreach_mul_(M_div, adj_lr) + torch._foreach_sub_(X, M_div) + + +class Muon(Optimizer): + """Distributed Muon optimizer for PyTorch FSDP2. Also compatible with DDP. + + Args: + params: Parameters for the optimizer. + distributed_mesh: DeviceMesh or ProcessGroup for distributed training. + Use DeviceMesh for FSDP2 and ProcessGroup for DistributedDataParallel. + lr: Base learning rate. For Muon, this will be scaled based on the matrix dimensions. + For element-wise update rules, this is the actual learning rate and no additional scaling is done. + mu: Momentum factor for Muon algorithm. + betas: Tuple of (beta1, beta2) for AdamW algorithms. + weight_decay: Weight decay factor. + epsilon: Small value to avoid division by zero. + nesterov: Whether to use Nesterov momentum. + adjust_lr: How to adjust the learning rate for Muon updates ("spectral_norm" or "rms_norm" or None). + "spectral_norm": Adjust based on spectral norm, for learning rate transfer across model scale. + "rms_norm": Adjust based on RMS norm, for learning rate compatibility with Adam/AdamW. + None: Do not adjust the learning rate. + flatten: Whether to flatten 3D+ tensors to 2D for Muon updates. + True: Tensors with 3+ dimensions are flattened to 2D. Use this for convolutional layers. + False: Tensors are not flattened. 3D+ tensors are treated as batches of 2D matrices. + use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided. + newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization. + Signature is `func(input: Tensor, epsilon: float) -> Tensor`. 
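+
+    Example (an illustrative sketch, not from this repo's tests; the names
+    `matrix_params`, `vector_params`, and `fsdp_mesh` are placeholders, and the
+    group split mirrors what `MuonConfig.build` does):
+
+        param_groups = [
+            dict(params=matrix_params),                     # Muon update (default)
+            dict(params=vector_params, algorithm="adamw"),  # element-wise AdamW
+        ]
+        opt = Muon(param_groups, distributed_mesh=fsdp_mesh, lr=0.01,
+                   adjust_lr="rms_norm", nesterov=True)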
+ + Muon optimizer algorithm by Keller Jordan: https://kellerjordan.github.io/posts/muon/ + FSDP2 Muon uses all-to-all communications: https://www.essential.ai/blog/infra + """ + + def __init__( + self, + params: ParamsT, + distributed_mesh: Optional[Union[DeviceMesh, ProcessGroup]] = None, + lr: float = 0.01, + mu: float = 0.95, + betas: Tuple[float, float] = (0.9, 0.95), + weight_decay: float = 0.01, + epsilon: float = 1e-8, + nesterov: bool = False, + adjust_lr: Optional[str] = "spectral_norm", + flatten: bool = False, + use_triton: bool = False, + newton_schulz_func: Optional[Callable] = None, + ): + # Check hyperparameters + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if mu < 0.0: + raise ValueError(f"Invalid momentum factor (mu): {mu}") + if len(betas) != 2 or betas[0] < 0.0 or betas[1] < 0.0: + raise ValueError(f"Invalid betas: {betas}") + if adjust_lr not in ("spectral_norm", "rms_norm", None): + raise ValueError(f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None.") + + # Default arguments for each param group + defaults = dict( + lr=lr, + mu=mu, + beta1=betas[0], + beta2=betas[1], + weight_decay=weight_decay, + algorithm="muon", + step=0, + epsilon=epsilon, + nesterov=nesterov, + flatten=flatten, + adjust_lr=adjust_lr, + ) + super().__init__(params, defaults) + + # Distributed configuration + if isinstance(distributed_mesh, DeviceMesh): + if distributed_mesh.ndim != 1: + raise ValueError( + f"Only 1D DeviceMesh is supported, but got {distributed_mesh.ndim}D. For HSDP, provide the 1D sharded sub-mesh." + ) + self._device_rank = distributed_mesh.get_local_rank() + self._world_size = distributed_mesh.size() + self._process_group = distributed_mesh.get_group() + elif isinstance(distributed_mesh, ProcessGroup): + self._device_rank = dist.get_rank(distributed_mesh) + self._world_size = dist.get_world_size(distributed_mesh) + self._process_group = distributed_mesh + elif distributed_mesh is None: + self._device_rank = 0 + self._world_size = 1 + self._process_group = None + else: + raise TypeError( + f"Invalid distributed_mesh type: {type(distributed_mesh)}. Expected DeviceMesh or ProcessGroup." 
+ ) + self._distributed_mesh = distributed_mesh + + # Newton-Schulz configuration + if newton_schulz_func is not None: + if not callable(newton_schulz_func): + raise TypeError(f"newton_schulz_func must be a callable function, got {type(newton_schulz_func)}") + self._newton_schulz_func = newton_schulz_func + elif use_triton: + self._newton_schulz_func = newton_schulz_triton + else: + self._newton_schulz_func = zeropower_via_newtonschulz5 + + @torch.no_grad() + def step(self, closure=None): + """Perform a single optimization step.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + muon_groups = [] + adamw_groups = [] + + for group in self.param_groups: + # Increment step + group["step"] += 1 + + # Split parameter groups by algorithm + algo = group["algorithm"] + if algo == "muon": + muon_groups.append(group) + elif algo == "adamw": + adamw_groups.append(group) + else: + raise ValueError(f"Unknown algorithm: {algo}") + + # Create async tasks for each algorithm + muon_tasks = self._create_muon_tasks(muon_groups) + adamw_tasks = self._create_adamw_tasks(adamw_groups) + + all_tasks = chain(muon_tasks, adamw_tasks) + runtime = AsyncRuntime(all_tasks, max_concurrent_tasks=3) + runtime.run() + + return loss + + def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: + """Get optimizer state for the given parameter tensor, or lazy- + initialize it if it doesn't exist.""" + state = self.state[param] + if not state: + state["momentum"] = torch.zeros_like(param) + if algo == "adamw": + state["variance"] = torch.zeros_like(param) + return state + + def _create_muon_tasks( + self, + param_groups: List[dict], + algo_name: str = "muon", + ) -> Generator["AsyncTask", None, None]: + """Helper function to create batches of Muon matrices and generate + AsyncTask objects so we can process multiple batches concurrently.""" + for group in param_groups: + assert group["algorithm"] == algo_name + assert all(p.ndim >= 2 for p in group["params"]), "Muon optimizer only supports matrix parameters." 
+ + group_params = [p for p in group["params"] if p.grad is not None] + if not group_params: + continue + + # Wrap hyperparameters in tensors for torch.compile + lr = torch.tensor(group["lr"]) + mu = torch.tensor(group["mu"]) + weight_decay = torch.tensor(group["weight_decay"]) + epsilon = torch.tensor(group["epsilon"]) + nesterov = group["nesterov"] + flatten = group["flatten"] + adjust_lr = group["adjust_lr"] + + # Create batches of parameters of size self._world_size + for params in create_param_batches(group_params, batch_size=self._world_size): + gradients = [p.grad for p in params] + states = [self._get_or_initialize_state(p, algo_name) for p in params] + momentums = [s["momentum"] for s in states] + + # Get sharding dimension + sharded_mesh_dim = None + sharded_tensor_dim = None + if isinstance(params[0], DTensor): + if not isinstance(self._distributed_mesh, DeviceMesh): + raise RuntimeError("Must create optimizer with DeviceMesh if using DTensor parameters.") + + # Find the sharded placement and get its mesh and tensor dimensions + # Skip any Shard() placements on size-1 mesh dimension = Replicate() + shard_placements = [ + (i, p) + for i, p in enumerate(params[0].placements) + if p.is_shard() and params[0].device_mesh.size(i) > 1 + ] + if len(shard_placements) == 1: + sharded_mesh_dim = shard_placements[0][0] + sharded_tensor_dim = shard_placements[0][1].dim + elif len(shard_placements) > 1: + raise NotImplementedError("Muon does not support parameters with multiple sharded dimensions.") + + # Check that the sharded mesh dimension matches optimizer's device mesh + if ( + sharded_mesh_dim is not None + and params[0].device_mesh.get_group(sharded_mesh_dim) != self._process_group + ): + raise RuntimeError( + f"Got DTensor sharded over mesh dimension {sharded_mesh_dim} different from the optimizer's device mesh" + f"param device mesh: {params[0].device_mesh}, optimizer's device mesh: {self._distributed_mesh}" + f"param group: {params[0].device_mesh.get_group(sharded_mesh_dim)}, optimizer's group: {self._process_group}" + ) + + yield AsyncTask( + muon_update_batch_async( + X=pad_batch(params, self._world_size), + G=pad_batch(gradients, self._world_size), + M=pad_batch(momentums, self._world_size), + lr=lr, + momentum=mu, + weight_decay=weight_decay, + epsilon=epsilon, + nesterov=nesterov, + flatten=flatten, + adjust_lr=adjust_lr, + device_rank=self._device_rank, + world_size=self._world_size, + shard_dim=sharded_tensor_dim, + process_group=self._process_group, + newton_schulz_func=self._newton_schulz_func, + ) + ) + + def _create_adamw_tasks( + self, + param_groups: List[dict], + algo_name: str = "adamw", + ) -> Generator["AsyncTask", None, None]: + """Helper function to generate AsyncTask objects for AdamW updates.""" + for group in param_groups: + assert group["algorithm"] == algo_name + + # Get parameters and optimizer states + params = [p for p in group["params"] if p.grad is not None] + if not params: + continue + gradients = [p.grad for p in params] + states = [self._get_or_initialize_state(p, algo_name) for p in params] + momentums = [s["momentum"] for s in states] + variances = [s["variance"] for s in states] + + # Wrap hyperparameters in tensors for torch.compile + lr = torch.tensor(group["lr"]) + beta1 = torch.tensor(group["beta1"]) + beta2 = torch.tensor(group["beta2"]) + weight_decay = torch.tensor(group["weight_decay"]) + epsilon = torch.tensor(group["epsilon"]) + step = torch.tensor(group["step"]) + + yield AsyncTask( + adamw_update_foreach_async( + X=to_local(params), + 
G=to_local(gradients), + M=to_local(momentums), + V=to_local(variances), + lr=lr, + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + step=step, + epsilon=epsilon, + ) + ) + + +def muon_update_batch_async( + X: List[Tensor], # Model weights (modified in place) + G: List[Tensor], # Gradient + M: List[Tensor], # Momentum buffer (modified in place) + lr: Tensor, # Learning rate (scalar tensor) + momentum: Tensor, # Momentum factor (scalar tensor) + weight_decay: Tensor, # Weight decay (scalar tensor) + epsilon: Tensor, # Epsilon (scalar tensor) + nesterov: bool, # Whether to use Nesterov momentum + flatten: bool, # Whether to flatten 3D+ tensors to 2D + adjust_lr: Optional[str], # How to adjust learning rate + device_rank: int, # Rank of the current device + world_size: int, # Total number of devices to parallelize over + shard_dim: Optional[int] = None, # Shard dimension for DTensor (if applicable) + process_group: Optional[ProcessGroup] = None, + newton_schulz_func: Optional[Callable] = None, +) -> Generator[None, None, None]: + """Batched version of Muon update. + + Batch size should be equal to number of GPUs. All tensors in a batch should have identical shape, sharding, and + dtype. Identical hyperparameters are used for all tensors in the batch. + """ + + assert len(X) == len(G) + assert len(X) == len(M) + assert len(X) == world_size + + # Update momentum and compute the inputs for orthogonalization + U = muon_update_pre_orthogonalize( + G=to_local(G), + M=to_local(M), + momentum=momentum, + nesterov=nesterov, + ) + + # Get one whole matrix for each device to orthogonalize + if shard_dim is not None: + # Use all-to-all to transform from a batch of shards to a single whole matrix + # https://www.essential.ai/blog/infra + assert process_group is not None, "process_group must be provided for sharded DTensors" + assert isinstance(X[0], DTensor), "X should contain DTensors" + assert not isinstance(U[0], DTensor), "U should contain local shards" + assert X[0].size(shard_dim) % world_size == 0, ( + f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}." + ) + + # Pack the list of shards into a single contiguous tensor + # U is currently [Shard_for_Rank0, Shard_for_Rank1, ...] + # Stack creates shape: (World_Size, *Shard_Shape) + U_packed = torch.stack(U) + + # Allocate buffer to receive parts of the "Single Matrix" + # Shape: (World_Size, *Shard_Shape) + single_matrix_parts = torch.empty_like(U_packed) + + # Perform optimized All-to-All + # This sends one large contiguous buffer instead of many small ones + work = dist.all_to_all_single(single_matrix_parts, U_packed, group=process_group, async_op=True) + yield + work.wait() + + # Reconstruct the full matrix + # single_matrix_parts has shape (World_Size, D0, D1...) + if shard_dim == 0: + # Optimization: If sharded on dim 0, we can simply flatten the batch dim + # to reconstruct the full matrix. This is a Zero-Copy View. + single_matrix = single_matrix_parts.flatten(0, 1) + else: + # General case (e.g., Col-wise sharding): We must concatenate along shard_dim. + # This requires a memory copy. + single_matrix = torch.cat(single_matrix_parts.unbind(0), dim=shard_dim) + + # 5. Perform Newton-Schulz Orthogonalization + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + + # Prepare to scatter results back + if shard_dim == 0: + # Optimization: View back to (World_Size, Shard_Size, ...) 
+ # This is a Zero-Copy View. + single_matrix_shards_packed = single_matrix.view(world_size, -1, *single_matrix.shape[1:]) + else: + # General case: Split back into chunks and stack them. + # We use stack to ensure the output is contiguous (World_Size, ...) for NCCL + single_matrix_shards_packed = torch.stack(single_matrix.chunk(world_size, dim=shard_dim)) + + # Ensure contiguity is preserved (crucial for NCCL) + if not single_matrix_shards_packed.is_contiguous(): + single_matrix_shards_packed = single_matrix_shards_packed.contiguous() + + # Allocate buffer for receiving updated gradients + U_packed_back = torch.empty_like(single_matrix_shards_packed) + + # Perform optimized All-to-All (Scatter back) + work = dist.all_to_all_single(U_packed_back, single_matrix_shards_packed, group=process_group, async_op=True) + yield + work.wait() + + # Unpack back to list form for the post-processing function + # unbind(0) is a view operation (slicing) + U = list(U_packed_back.unbind(0)) + + else: + # Matrices are not sharded, so we can directly orthogonalize + # Get a single matrix corresponding to this device + single_matrix = U[device_rank] + assert not isinstance(single_matrix, DTensor) + + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + + if process_group is not None and process_group.size() > 1: + # Ensure input is contiguous + input_tensor = single_matrix.contiguous() + + # Create output buffer + gathered_U = torch.empty( + (world_size,) + input_tensor.shape, dtype=input_tensor.dtype, device=input_tensor.device + ) + + # Use efficient all_gather + work = dist.all_gather_into_tensor(gathered_U, input_tensor, group=process_group, async_op=True) + yield + work.wait() + + # Unbind to list for compatibility + U = list(gathered_U.unbind(0)) + + else: + # Single GPU case, no need to gather + assert world_size == 1 + U = [single_matrix] + + # Compute scaled learning rate + # Do this before to_local(X) because we use the full tensor shape, not the shard shape + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) + else: + raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + + # Update model parameters with orthogonalized output + muon_update_post_orthogonalize( + X=to_local(X), + U=U, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + ) + + +def adamw_update_foreach_async( + X: List[Tensor], # Model weights (modified in place) + G: List[Tensor], # Gradient + M: List[Tensor], # Momentum buffer (modified in place) + V: List[Tensor], # Variance buffer (modified in place) + lr: Tensor, # Learning rate (scalar tensor) + beta1: Tensor, # Beta 1 (scalar tensor) + beta2: Tensor, # Beta 2 (scalar tensor) + weight_decay: Tensor, # Weight decay (scalar tensor) + step: int, + epsilon: float, +) -> Generator[None, None, None]: + """Async wrapper around foreach AdamW update.""" + adamw_update_foreach(X, G, M, V, lr, beta1, beta2, weight_decay, step, epsilon) + yield + + +@torch.compile(fullgraph=True) +def muon_update_pre_orthogonalize( + G: List[Tensor], + M: List[Tensor], + momentum: Tensor, + nesterov: bool, +) -> List[Tensor]: + """Update momentum with gradient and compute the input to + orthogonalization. + + Inputs and outputs should be lists of regular Tensor, not DTensor. 
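+    With nesterov=True the returned input is U = momentum * M + G, a look-ahead
+    on top of the updated buffer M = momentum * M_prev + G; otherwise U is M
+    itself. Either way, U is cast to bfloat16 before communication.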
This is a separate function for compatibility + with torch.compile(). + """ + dtype = M[0].dtype + G = [g.to(dtype=dtype) for g in G] + + # Update momentum with new gradient + torch._foreach_mul_(M, momentum) + torch._foreach_add_(M, G) + + if nesterov: + U = torch._foreach_mul(M, momentum) + torch._foreach_add_(U, G) + else: + U = M + + # Convert to bfloat16 before communication + U = [u.to(dtype=torch.bfloat16) for u in U] + + return U + + +@torch.compile(fullgraph=True) +def muon_update_post_orthogonalize( + X: List[Tensor], + U: List[Tensor], + base_lr: Tensor, + adjusted_lr: Tensor, + weight_decay: Tensor, +): + """Apply weight decay and weight update after orthogonalization. + + Inputs and outputs should be lists of regular Tensor, not DTensor. This is a separate function for compatibility + with torch.compile(). + """ + # Apply weight decay + torch._foreach_mul_(X, 1 - base_lr * weight_decay) + + # Weight update + U = torch._foreach_mul(U, adjusted_lr) + torch._foreach_sub_(X, U) + + +def muon_update_newton_schulz( + X: Tensor, + newton_schulz_func: Callable, + flatten: bool, + epsilon: Tensor, +) -> Tensor: + """Flatten the input tensor if needed and call the Newton-Schulz + function.""" + original_shape = X.shape + if flatten and X.ndim >= 3: + # Flatten 3D+ tensors to 2D matrix + X = X.flatten(start_dim=1) + elif X.ndim >= 4: + # Given 4D+ batch, flatten to 3D batch + X = X.flatten(end_dim=-3) + + return newton_schulz_func(X, epsilon=epsilon).reshape(original_shape) + + +def adjust_lr_rms_norm(lr, param_shape): + # Adjust learning rate for constant element-wise RMS norm + # https://arxiv.org/abs/2502.16982 + A, B = param_shape[:2] + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + +def adjust_lr_spectral_norm(lr, param_shape): + # Adjust from spectral norm 1 to RMS operator norm 1 + # https://arxiv.org/abs/2310.17813 + fan_out, fan_in = param_shape[:2] + adjusted_lr = lr * math.sqrt(fan_out / fan_in) + return adjusted_lr + + +@torch.compile(fullgraph=True) +def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7): + """Newton-Schulz iteration to approximate the orthogonalization of X.""" + # Newton-Schulz constants + ns_consts = [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ] + + X = G.to(dtype=torch.bfloat16) + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + epsilon) + + for a, b, c in ns_consts: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + + if G.size(-2) > G.size(-1): + X = X.mT + return X diff --git a/xtuner/v1/optim/newton_schulz_triton.py b/xtuner/v1/optim/newton_schulz_triton.py new file mode 100644 index 000000000..e0f4abbd1 --- /dev/null +++ b/xtuner/v1/optim/newton_schulz_triton.py @@ -0,0 +1,350 @@ +# From https://github.com/microsoft/dion +import torch +import triton +import triton.language as tl +from torch import Tensor + + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + 
GROUP_SIZE_M: tl.constexpr, +): + """Helper function to map Triton program ID to (batch, row, col) of the + output matrix.""" + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + + return batch_idx, m_idx, n_idx + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, + C_ptr, + M, + K, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + """Input A has shape (M, K) + + Output C has shape (M, M) Compute C = A @ A.T + """ + + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block(pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_1(A: Tensor, *, out: Tensor = None): + """Launch Triton kernel to compute C = A @ A.T.""" + if A.ndim > 3 or A.ndim < 2: + raise ValueError(f"Input tensor must be 2D or 3D, but got {A.ndim}D tensor.") + + M, K = A.shape[-2:] + + if out is None: + out = torch.empty((*A.shape[:-1], M), device=A.device, dtype=A.dtype) + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect 
shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: (batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + + return out + + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, + C_ptr, + M, + a_stride_b, + a_stride_r, + a_stride_c, + c_stride_b, + c_stride_r, + c_stride_c, + alpha, + beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + """Input A is square matrix with shape (M, M) + + Output C has shape (M, M) Compute C = alpha * A @ A.T + beta * A + """ + + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block(pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + + +def ns_line_2(A: Tensor, alpha: float, beta: float, *, out: 
Tensor = None): + """Launch Triton kernel to compute C = alpha * A @ A.T + beta * A.""" + if A.ndim > 3 or A.ndim < 2: + raise ValueError(f"Input tensor must be 2D or 3D, but got {A.ndim}D tensor.") + + M, K = A.shape[-2:] + if M != K: + raise ValueError(f"Input must be symmetric square matrix, but got shape {A.shape}") + + if out is None: + out = torch.empty((*A.shape[:-1], M), device=A.device, dtype=A.dtype) + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: (batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + + return out + + +# @torch.compile(dynamic=False, fullgraph=True) +def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7): + """Reference implementation of Newton-Schulz without Triton.""" + # Newton-Schulz constants + ns_consts = [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ] + + X = G.to(dtype=torch.bfloat16) + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + epsilon) + + for a, b, c in ns_consts: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +# @torch.compile(dynamic=False, fullgraph=True) +def newton_schulz_triton(G: Tensor, epsilon: float = 1e-7): + """Triton implementation of Newton-Schulz iteration.""" + # Newton-Schulz constants + ns_consts = [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ] + + X = G.to(dtype=torch.bfloat16) + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + epsilon) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for a, b, c in ns_consts: + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X diff --git a/xtuner/v1/train/arguments/arguments.py b/xtuner/v1/train/arguments/arguments.py index d9c29d076..213c7d371 100644 --- a/xtuner/v1/train/arguments/arguments.py +++ b/xtuner/v1/train/arguments/arguments.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, ConfigDict from xtuner.v1.config import FSDPConfig -from xtuner.v1.config.optim import AdamWConfig, LRConfig +from xtuner.v1.config.optim import AdamWConfig, LRConfig, MuonConfig from xtuner.v1.datasets import FTDPTokenizeFnConfig from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig, DatasetConfigList from xtuner.v1.datasets.sft_tokenize_fn import 
OpenaiTokenizeFunctionConfig @@ -73,7 +73,7 @@ class TrainingArguments(BaseModel): max_length: Annotated[int, Parameter(group=dataset_group, help="max single sequence length")] = 4096 # optimizer lr: Annotated[float, Parameter(group=optimizer_group, help="learning rate")] = 6e-5 - optim: Annotated[Literal["AdamW"], Parameter(group=optimizer_group, help="optimizer type")] = "AdamW" + optim: Annotated[Literal["AdamW", "Muon"], Parameter(group=optimizer_group, help="optimizer type")] = "AdamW" # lr-scheduler lr_min: Annotated[float, Parameter(group=lr_scheduler_group, help="minimum learning rate")] = 1e-6 scheduler_type: Annotated[ @@ -124,8 +124,12 @@ def to_trainer_config(self): # Load dataset config dataset_cfg = self._get_dataset_config() # Create optimizer config - optim_cfg = AdamWConfig(lr=self.lr, foreach=False) - + if self.optim == "AdamW": + optim_cfg = AdamWConfig(lr=self.lr, foreach=False) + elif self.optim == "Muon": + optim_cfg = MuonConfig(lr=self.lr) + else: + raise ValueError(f"Unkown optimizer type {self.optim}, only support `AdamW` and `Muon`.") # Create learning rate scheduler config lr_cfg = LRConfig(lr_type=self.scheduler_type, warmup_ratio=self.warmup_ratio, lr_min=self.lr_min) From 3d380da2ee477beca549747d288d31cb6e3e3d1a Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Mon, 9 Feb 2026 14:21:07 +0800 Subject: [PATCH 2/4] lint --- xtuner/v1/optim/muon.py | 5 +++-- xtuner/v1/optim/newton_schulz_triton.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/xtuner/v1/optim/muon.py b/xtuner/v1/optim/muon.py index 1faa2ab89..22f583125 100644 --- a/xtuner/v1/optim/muon.py +++ b/xtuner/v1/optim/muon.py @@ -1,3 +1,4 @@ +# type: ignore # ================================================================ # Copyright (c) 2026 InternLM. All rights reserved. # @@ -47,7 +48,7 @@ def to_local(tensor: Union[Tensor, List[Tensor]]) -> Union[Tensor, List[Tensor]] return [t.to_local() if isinstance(t, DTensor) else t for t in tensor] -def dtensor_from_local(tensor: Union[Tensor, List[Tensor]], ref: Tensor) -> Union[DTensor, List[DTensor]]: +def dtensor_from_local(tensor: Union[Tensor, List[Tensor]], ref: Tensor) -> Union[DTensor, List[DTensor]]: # type: ignore """Convert a single local Tensor or list of local Tensors to DTensor. The reference tensor's device mesh and placements are used to create the DTensor. 
if the reference tensor is not a @@ -211,7 +212,7 @@ def adamw_update( @torch.compile(fullgraph=True) -def adamw_update_foreach( +def adamw_update_foreach( # type: ignore X: List[Tensor], # Model weights (modified in place) G: List[Tensor], # Gradient M: List[Tensor], # Momentum buffer (modified in place) diff --git a/xtuner/v1/optim/newton_schulz_triton.py b/xtuner/v1/optim/newton_schulz_triton.py index e0f4abbd1..c9aa8c91e 100644 --- a/xtuner/v1/optim/newton_schulz_triton.py +++ b/xtuner/v1/optim/newton_schulz_triton.py @@ -128,7 +128,7 @@ def ns_line_1_kernel( tl.store(c_ptrs_t, output.T, mask=c_mask_t) -def ns_line_1(A: Tensor, *, out: Tensor = None): +def ns_line_1(A: Tensor, *, out: Tensor | None = None): """Launch Triton kernel to compute C = A @ A.T.""" if A.ndim > 3 or A.ndim < 2: raise ValueError(f"Input tensor must be 2D or 3D, but got {A.ndim}D tensor.") @@ -144,7 +144,7 @@ def ns_line_1(A: Tensor, *, out: Tensor = None): input_batch_stride = A.stride(0) if A.ndim == 3 else 0 output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - grid = lambda meta: (batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),) + grid = lambda meta: (batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),) # noqa: E731 ns_line_1_kernel[grid]( A_ptr=A, C_ptr=out, @@ -246,7 +246,7 @@ def ns_line_2_kernel( tl.store(c_ptrs_t, output.T, mask=c_mask_t) -def ns_line_2(A: Tensor, alpha: float, beta: float, *, out: Tensor = None): +def ns_line_2(A: Tensor, alpha: float, beta: float, *, out: Tensor | None = None): """Launch Triton kernel to compute C = alpha * A @ A.T + beta * A.""" if A.ndim > 3 or A.ndim < 2: raise ValueError(f"Input tensor must be 2D or 3D, but got {A.ndim}D tensor.") @@ -264,7 +264,7 @@ def ns_line_2(A: Tensor, alpha: float, beta: float, *, out: Tensor = None): input_batch_stride = A.stride(0) if A.ndim == 3 else 0 output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - grid = lambda meta: (batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),) + grid = lambda meta: (batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),) # noqa: E731 ns_line_2_kernel[grid]( A_ptr=A, C_ptr=out, From 62187fadc49f289961bc262cc2a051a93b3e42bd Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Fri, 27 Feb 2026 12:04:19 +0800 Subject: [PATCH 3/4] lint --- xtuner/v1/optim/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xtuner/v1/optim/__init__.py b/xtuner/v1/optim/__init__.py index 3a80f8163..0ea5fdeb0 100644 --- a/xtuner/v1/optim/__init__.py +++ b/xtuner/v1/optim/__init__.py @@ -1,4 +1,4 @@ -from .muon import Muon +from .muon import Muon # type: ignore __all__ = ["Muon"] From 995f7154915c3cbf1c413ae540bc6886c06b20ad Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Fri, 27 Feb 2026 13:01:08 +0800 Subject: [PATCH 4/4] lint --- xtuner/v1/train/arguments/arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xtuner/v1/train/arguments/arguments.py b/xtuner/v1/train/arguments/arguments.py index 213c7d371..f763cd4ee 100644 --- a/xtuner/v1/train/arguments/arguments.py +++ b/xtuner/v1/train/arguments/arguments.py @@ -129,7 +129,7 @@ def to_trainer_config(self): elif self.optim == "Muon": optim_cfg = MuonConfig(lr=self.lr) else: - raise ValueError(f"Unkown optimizer type {self.optim}, only support `AdamW` and `Muon`.") + raise ValueError(f"Unknown optimizer type {self.optim}, only support `AdamW` and `Muon`.") # Create 
learning rate scheduler config
         lr_cfg = LRConfig(lr_type=self.scheduler_type, warmup_ratio=self.warmup_ratio, lr_min=self.lr_min)
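
A quick way to sanity-check the Newton-Schulz kernels added in this series (a
minimal sketch, assuming a CUDA device and an installed xtuner so that
`xtuner.v1.optim.muon` is importable; the tolerance is illustrative, since Muon
only needs approximate orthogonalization):

    import torch

    from xtuner.v1.optim.muon import zeropower_via_newtonschulz5

    G = torch.randn(256, 128, device="cuda")
    X = zeropower_via_newtonschulz5(G)   # approximate orthogonal factor of G
    gram = X.float().mT @ X.float()      # for a tall input, X^T X should be near I
    err = (gram - torch.eye(128, device="cuda")).abs().max()
    print(f"max |X^T X - I| = {err.item():.3f}")

The same check exercises the Triton path by swapping in `newton_schulz_triton`
from `xtuner.v1.optim.newton_schulz_triton`, which shares the `(G, epsilon)`
signature.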