Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xtuner/v1/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .fsdp import FSDPConfig
from .generate import GenerateConfig
from .optim import AdamWConfig, LRConfig, OptimConfig
from .optim import AdamWConfig, LRConfig, MuonConfig, OptimConfig


__all__ = [
Expand All @@ -9,4 +9,5 @@
"AdamWConfig",
"LRConfig",
"GenerateConfig",
"MuonConfig",
]
97 changes: 96 additions & 1 deletion xtuner/v1/config/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@
from typing import Literal, Optional, Tuple

import torch
import torch.distributed as dist
from cyclopts import Parameter
from pydantic import BaseModel, ConfigDict
from typing_extensions import Annotated

from xtuner.v1.optim import Muon
from xtuner.v1.utils import get_logger


logger = get_logger()


class OptimConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
Expand All @@ -26,12 +33,100 @@ class AdamWConfig(OptimConfig):
eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Adam optimizer")] = 1e-8
foreach: Annotated[Optional[bool], Parameter(help="Use foreach implementation for AdamW")] = None
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: [Critical] The base class OptimConfig.build(self, params) abstract method signature was not updated, but both AdamWConfig.build and MuonConfig.build now take model instead of params. This is a breaking API inconsistency — any third-party or future subclass of OptimConfig following the abstract method's parameter name would expect to receive params, not a model object.

The abstract method in OptimConfig (line 20) should be updated to build(self, model) to match the new contract.


def build(self, model):
    """Build a ``torch.optim.AdamW`` optimizer over *model*'s trainable parameters.

    Args:
        model: A module exposing ``parameters()``, ``named_parameters()`` and
            ``trainable_parameters()`` (an iterable of ``(name, param)`` pairs).

    Returns:
        A configured ``torch.optim.AdamW`` instance covering every parameter
        with ``requires_grad=True``.
    """
    params = [p for p in model.parameters() if p.requires_grad]

    # Use a set: membership is tested once per parameter in the loop below,
    # and a list would make the whole loop O(n^2) for large models.  This
    # also matches how MuonConfig.build builds the same lookup.
    trainable_names = {name for name, _ in model.trainable_parameters()}
    untrainable_names = []
    num_total_requires_grad = 0
    num_total = 0
    for name, params_ in model.named_parameters():
        num_total += params_.numel()
        if name in trainable_names:
            num_total_requires_grad += params_.numel()
        else:
            untrainable_names.append(name)

    # Guarded so the config also works in single-process runs where the
    # default process group was never initialized (dist.get_rank() would
    # raise in that case).
    if dist.is_available() and dist.is_initialized() and dist.get_rank() == 0:
        logger.info(
            f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M"
        )
        logger.info(f"Untrainable parameters names: {untrainable_names}")
    return torch.optim.AdamW(
        params, lr=self.lr, betas=self.betas, eps=self.eps, weight_decay=self.weight_decay, foreach=self.foreach
    )


class MuonConfig(OptimConfig):
    """Configuration for the Muon optimizer with an AdamW fallback group.

    Trainable tensors with ``ndim >= 2`` are optimized with Muon, except
    parameters whose names match ``adamw_name_patterns`` (embeddings and the
    output head by default); those and all 1-D tensors fall back to AdamW
    inside the same optimizer.
    """

    weight_decay: Annotated[float, Parameter(help="Weight decay coefficient for L2 regularization")] = 0.1
    momentum: Annotated[float, Parameter(help="Momentum coefficients for Muon optimizer")] = 0.95
    betas: Annotated[Tuple[float, float], Parameter(help="Beta coefficients for AdamW optimizer")] = (0.9, 0.95)
    eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Muon optimizer")] = 1e-8
    # NOTE(review): substring matching on parameter names is architecture
    # specific; models whose embedding/head layers use other names should
    # override these patterns.
    adamw_name_patterns: Annotated[
        Tuple[str, ...], Parameter(help="Name substrings whose parameters always use the AdamW fallback")
    ] = ("embed_tokens", "lm_head")

    def build(self, model):
        """Build a ``Muon`` optimizer over *model*'s trainable parameters.

        Args:
            model: A module exposing ``named_parameters()``,
                ``trainable_parameters()`` and an ``fsdp_mesh`` attribute.

        Returns:
            A configured ``Muon`` optimizer with a Muon group and an AdamW
            fallback group.
        """
        trainable_names = {name for name, _ in model.trainable_parameters()}

        untrainable_names = []
        muon_params = []
        adamw_params = []
        num_total = 0
        num_total_requires_grad = 0
        num_muon = 0
        num_adamw = 0

        # Single pass over named_parameters(): classify each parameter and
        # accumulate the counters, so the Muon/AdamW split condition exists
        # in exactly one place and cannot drift out of sync.
        for name, p in model.named_parameters():
            n = p.numel()
            num_total += n
            if name not in trainable_names:
                untrainable_names.append(name)
                continue
            num_total_requires_grad += n
            is_muon_tensor = p.ndim >= 2 and not any(pat in name for pat in self.adamw_name_patterns)
            if is_muon_tensor:
                muon_params.append(p)
                num_muon += n
            else:
                adamw_params.append(p)
                num_adamw += n

        param_groups = [
            dict(params=muon_params),
            dict(params=adamw_params, algorithm="adamw"),
        ]

        # Guarded so the config also works in single-process runs where the
        # default process group was never initialized.
        if dist.is_available() and dist.is_initialized() and dist.get_rank() == 0:
            logger.info(
                f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M"
            )
            logger.info(f"Muon params: {num_muon // 1e6}M, AdamW params: {num_adamw // 1e6}M (counts by numel)")
            logger.info(f"Untrainable parameters names: {untrainable_names}")
            logger.info(
                f"using Muon optimizer distributed_mesh_size: {model.fsdp_mesh.size()}, "
                f"distributed_mesh: {model.fsdp_mesh}"
            )

        optimizer = Muon(
            param_groups,
            distributed_mesh=model.fsdp_mesh,  # TODO: EP > 1 is not supported yet
            lr=self.lr,
            mu=self.momentum,
            betas=self.betas,
            weight_decay=self.weight_decay,
            nesterov=True,
            adjust_lr="rms_norm",
            use_triton=False,
            epsilon=self.eps,
        )
        return optimizer


class LRConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
lr_type: Annotated[Literal["cosine", "linear", "constant"], Parameter(help="Type of learning rate schedule")] = (
Expand Down
20 changes: 1 addition & 19 deletions xtuner/v1/engine/train_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,25 +184,7 @@ def build_model(self) -> BaseModel:
return model

def build_optimizer(self, optim_cfg: OptimConfig) -> torch.optim.Optimizer:
    """Construct the optimizer for this engine's model.

    Delegates entirely to ``optim_cfg.build``, which receives the model and
    is responsible for parameter selection, grouping, and logging.
    """
    return optim_cfg.build(self.model)

@property
def data_replicate_size(self) -> int:
Expand Down
4 changes: 4 additions & 0 deletions xtuner/v1/optim/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Re-export the Muon optimizer as this package's public API.
from .muon import Muon # type: ignore


__all__ = ["Muon"]
Loading