feat: support fsdp2 muon optimizer #1486

base: main · Changes from all commits
```diff
@@ -2,10 +2,17 @@
 from typing import Literal, Optional, Tuple

 import torch
+import torch.distributed as dist
 from cyclopts import Parameter
 from pydantic import BaseModel, ConfigDict
 from typing_extensions import Annotated

+from xtuner.v1.optim import Muon
+from xtuner.v1.utils import get_logger
+
+logger = get_logger()
+
+

 class OptimConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
```
```diff
@@ -26,12 +33,100 @@ class AdamWConfig(OptimConfig):
     eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Adam optimizer")] = 1e-8
     foreach: Annotated[Optional[bool], Parameter(help="Use foreach implementation for AdamW")] = None

-    def build(self, params):
+    def build(self, model):
+        params = [p for p in model.parameters() if p.requires_grad]
+
+        trainable_parameters_names = model.trainable_parameters()
+        trainable_names = [name for name, _ in trainable_parameters_names]
+
+        untrainable_names = []
+        num_total_requires_grad = 0
+        num_total = 0
+        for name, params_ in model.named_parameters():
+            num_total += params_.numel()
+            num_total_requires_grad += params_.numel() if name in trainable_names else 0
+            if name not in trainable_names:
+                untrainable_names.append(name)
+
+        if dist.get_rank() == 0:
+            logger.info(
+                f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M"
+            )
+            logger.info(f"Untrainable parameters names: {untrainable_names}")
         return torch.optim.AdamW(
             params, lr=self.lr, betas=self.betas, eps=self.eps, weight_decay=self.weight_decay, foreach=self.foreach
         )
```

> **Claude [Nit]** (on the `trainable_names` list) — Suggested change.

> **Claude [Warning]** (lines +36 to +51): The parameter counting and logging logic is now duplicated between `AdamWConfig.build` and `MuonConfig.build`.
```diff
+class MuonConfig(OptimConfig):
+    weight_decay: Annotated[float, Parameter(help="Weight decay coefficient for L2 regularization")] = 0.1
+    momentum: Annotated[float, Parameter(help="Momentum coefficient for Muon optimizer")] = 0.95
+    betas: Annotated[Tuple[float, float], Parameter(help="Beta coefficients for AdamW optimizer")] = (0.9, 0.95)
+    eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Muon optimizer")] = 1e-8
+
+    def build(self, model):
+        trainable_parameters_names = model.trainable_parameters()
+        trainable_names = {name for name, _ in trainable_parameters_names}
+
+        untrainable_names = []
+        num_total = 0
+        num_total_requires_grad = 0
+        num_muon = 0
+        num_adamw = 0
+
+        for name, p in model.named_parameters():
+            n = p.numel()
+            num_total += n
+            if name in trainable_names:
+                num_total_requires_grad += n
+                is_muon_tensor = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
+                if is_muon_tensor:
+                    num_muon += n
+                else:
+                    num_adamw += n
+            else:
+                untrainable_names.append(name)
+
+        muon_params = [
+            p
+            for name, p in model.named_parameters()
+            if name in trainable_names and p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
+        ]
+        adamw_params = [
+            p
+            for name, p in model.named_parameters()
+            if name in trainable_names and not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name)
+        ]
+
+        param_groups = [
+            dict(params=muon_params),
+            dict(params=adamw_params, algorithm="adamw"),
+        ]
+
+        if dist.get_rank() == 0:
+            logger.info(
+                f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M"
+            )
+            logger.info(f"Muon params: {num_muon // 1e6}M, AdamW params: {num_adamw // 1e6}M (counts by numel)")
+            logger.info(f"Untrainable parameters names: {untrainable_names}")
+            logger.info(
+                f"using Muon optimizer distributed_mesh_size: {model.fsdp_mesh.size()}, "
+                f"distributed_mesh: {model.fsdp_mesh}"
+            )
+
+        optimizer = Muon(
+            param_groups,
+            distributed_mesh=model.fsdp_mesh,  # TODO: 暂不支持 EP>1
+            lr=self.lr,
+            mu=self.momentum,
+            betas=self.betas,
+            weight_decay=self.weight_decay,
+            nesterov=True,
+            adjust_lr="rms_norm",
+            use_triton=False,
+            epsilon=self.eps,
+        )
+        return optimizer


 class LRConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
     lr_type: Annotated[Literal["cosine", "linear", "constant"], Parameter(help="Type of learning rate schedule")] = (
```

> **Claude [Warning]** (on the `named_parameters()` loop): The Muon/AdamW parameter split uses hardcoded layer names (`embed_tokens`, `lm_head`). Consider making these configurable. This is also the concern raised by @nil0x9.

> **Claude [Warning]** (lines +89 to +98): The Muon/AdamW parameter filtering condition (`p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name`) is repeated three times. If one instance is updated but the others aren't, parameters will be misclassified. Consider collecting all three lists in a single pass:
>
> ```python
> muon_params = []
> adamw_params = []
> for name, p in model.named_parameters():
>     n = p.numel()
>     if name not in trainable_names:
>         ...
>         continue
>     is_muon = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
>     if is_muon:
>         muon_params.append(p)
>         num_muon += n
>     else:
>         adamw_params.append(p)
>         num_adamw += n
> ```
>
> This eliminates the duplication and reduces the three passes over parameters to one.

> **Claude [Nit]** (lines +122 to +124).

> **Claude [Nit]**: Chinese comment in source code. For consistency in an open-source project, comments should be in English. Suggested change (the comment `# TODO: 暂不支持 EP>1` translates to "EP>1 not yet supported").
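The hardcoded-name warning above could be resolved by routing the split through a configurable list of name patterns. A sketch under assumptions: the free function `split_muon_adamw` and the `adamw_name_patterns` field it implies are hypothetical, not part of this PR, which inlines the equivalent logic in `MuonConfig.build`.

```python
import re
from typing import List, Tuple


def split_muon_adamw(named_params, adamw_name_patterns: List[str]) -> Tuple[list, list]:
    """Route every trainable parameter with ndim >= 2 to Muon unless its
    name matches one of the configured exclusion patterns; everything else
    (1-D tensors and matched names) goes to AdamW.

    Hypothetical helper: the PR hardcodes the patterns as
    "embed_tokens" and "lm_head" instead of taking them as input.
    """
    exclude = [re.compile(pat) for pat in adamw_name_patterns]
    muon_params, adamw_params = [], []
    for name, p in named_params:
        is_muon = p.ndim >= 2 and not any(rx.search(name) for rx in exclude)
        (muon_params if is_muon else adamw_params).append(p)
    return muon_params, adamw_params
```

With a default of `["embed_tokens", "lm_head"]` this reproduces the PR's behavior while letting users of other architectures (e.g. different embedding layer names) adjust the split from the config.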
```diff
@@ -0,0 +1,4 @@
+from .muon import Muon  # type: ignore
+
+
+__all__ = ["Muon"]
```

> **Claude [Critical]** The base class `OptimConfig.build(self, params)` abstract method signature was not updated, but both `AdamWConfig.build` and `MuonConfig.build` now take `model` instead of `params`. This is a breaking API inconsistency: any third-party or future subclass of `OptimConfig` following the abstract method's parameter name would expect to receive `params`, not a model object. The abstract method in `OptimConfig` (line 20) should be updated to `build(self, model)` to match the new contract.