
feat: support fsdp2 muon optimizer #1486

Open
RangiLyu wants to merge 4 commits into InternLM:main from RangiLyu:fsdp2-muon

Conversation

@RangiLyu
Contributor

@RangiLyu RangiLyu commented Feb 6, 2026

[image attachment]

@nil0x9
Contributor

nil0x9 commented Feb 9, 2026

It might be more desirable to introduce some mechanism to separate fused params, IMHO. For example, MoE projections are implemented with GroupedLinear in XTuner, where a 2D param actually means a bundle of linear weights. In this case, it might be more beneficial to split this single param tensor and treat the pieces as different projections in the NS iteration.
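The split being described could look roughly like this (a hedged sketch: `split_grouped_param` and the fused `(num_groups * out_features, in_features)` layout are illustrative assumptions, not XTuner's actual GroupedLinear storage):

```python
import torch


def split_grouped_param(fused: torch.Tensor, num_groups: int) -> list[torch.Tensor]:
    """Split a fused GroupedLinear weight into per-projection 2D matrices.

    Assumes the bundle is stacked along dim 0 as (num_groups * out_features,
    in_features); each chunk could then go through Newton-Schulz independently
    instead of orthogonalizing the whole fused tensor as one big matrix.
    """
    assert fused.ndim == 2 and fused.size(0) % num_groups == 0
    return list(torch.chunk(fused, num_groups, dim=0))


# The per-group results would be concatenated back after the NS iteration, e.g.:
# update = torch.cat([orthogonalize(w) for w in split_grouped_param(p, num_experts)], dim=0)
```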

@HAOCHENYE
Collaborator

@claude

@claude

claude bot commented Feb 28, 2026

Claude Code is working…

I'll analyze this and get back to you.


@HAOCHENYE
Collaborator

@claude

Comment on lines +36 to +51
def build(self, model):
    params = [p for p in model.parameters() if p.requires_grad]

    trainable_parameters_names = model.trainable_parameters()
    trainable_names = [name for name, _ in trainable_parameters_names]
    untrainable_names = []
    num_total_requires_grad = 0
    num_total = 0
    for name, params_ in model.named_parameters():
        num_total += params_.numel()
        num_total_requires_grad += params_.numel() if name in trainable_names else 0
        if name not in trainable_names:
            untrainable_names.append(name)

    if dist.get_rank() == 0:
        logger.info(
Claude: [Warning] The parameter counting and logging logic is now duplicated between AdamWConfig.build and MuonConfig.build. Consider extracting this into a shared method on OptimConfig (or a standalone helper) to avoid repetition. When a third optimizer is added, this would need to be copy-pasted again.
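A shared helper along those lines might look like the following (a sketch only: `count_params` is a hypothetical name, and the real helper would presumably live on OptimConfig and also emit the logging):

```python
def count_params(named_params, trainable_names):
    """Single-pass parameter accounting shared by AdamWConfig.build and
    MuonConfig.build.

    named_params: iterable of (name, param) where param exposes numel()
    trainable_names: set of trainable parameter names
    Returns (total numel, trainable numel, list of untrainable names).
    """
    num_total = 0
    num_trainable = 0
    untrainable_names = []
    for name, p in named_params:
        n = p.numel()
        num_total += n
        if name in trainable_names:
            num_trainable += n
        else:
            untrainable_names.append(name)
    return num_total, num_trainable, untrainable_names
```

Each optimizer config's build would then call this once instead of re-implementing the loop.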

num_muon = 0
num_adamw = 0

for name, p in model.named_parameters():

Claude: [Warning] The Muon/AdamW parameter split uses hardcoded layer names "embed_tokens" and "lm_head". This is fragile and model-architecture specific — different model families may use different names for embedding/output layers (e.g., "wte", "embed", "output", "head").

Consider making these configurable via MuonConfig fields (e.g., adamw_name_patterns: list[str] = ["embed_tokens", "lm_head"]) so users can customize which parameters use AdamW vs Muon without modifying the code.

This is also the concern raised by @nil0x9 about GroupedLinear — the current string matching approach has no mechanism to handle fused parameters that may structurally look like 2D matrices but semantically represent bundled weights.
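The configurable split could reduce to a small predicate like this (a sketch; `uses_adamw` and the `adamw_name_patterns` default are illustrative, mirroring the names currently hardcoded in the PR):

```python
from typing import Sequence


def uses_adamw(
    name: str,
    ndim: int,
    adamw_name_patterns: Sequence[str] = ("embed_tokens", "lm_head"),
) -> bool:
    """True if a parameter should be handled by AdamW instead of Muon.

    Non-matrix params (ndim < 2) and any param whose name contains one of
    the configured patterns fall back to AdamW; everything else uses Muon.
    """
    return ndim < 2 or any(pat in name for pat in adamw_name_patterns)
```

A model family with differently named layers would then just pass, e.g., `adamw_name_patterns=("wte", "head")` in its config instead of patching the build logic.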

)
logger.info(f"Muon params: {num_muon // 1e6}M, AdamW params: {num_adamw // 1e6}M (counts by numel)")
logger.info(f"Untrainable parameters names: {untrainable_names}")
logger.info(
Claude: [Nit] Chinese comment in source code. For consistency in an open-source project, comments should be in English.

Suggested change
logger.info(
distributed_mesh=model.fsdp_mesh, # TODO: EP>1 not supported yet

@@ -0,0 +1,839 @@
# type: ignore

Claude: [Warning] File-level # type: ignore suppresses all mypy checking for 838 lines. Per CLAUDE.md, mypy is strict for xtuner/v1/*. While I understand this is pragmatic for a complex optimizer with torch.compile and DTensor, it masks real type errors (e.g., step: int receiving a Tensor — see my comment below).

Consider either:

  1. Adding targeted # type: ignore[specific-error] annotations only where needed, or
  2. Adding this file to mypy.ini's overrides with specific relaxations rather than a blanket ignore.

Comment on lines +519 to +522
variances = [s["variance"] for s in states]

# Wrap hyperparameters in tensors for torch.compile
lr = torch.tensor(group["lr"])
Claude: [Warning] Type annotations are incorrect for step and epsilon. In _create_adamw_tasks, both are wrapped with torch.tensor(...), so they arrive as Tensor, not int/float:

step = torch.tensor(group["step"])
epsilon = torch.tensor(group["epsilon"])

The same mismatch exists in adamw_update_foreach (line 218). These annotations should be Tensor to match actual usage. The # type: ignore at file level currently hides this.

Comment on lines +481 to +484
)

yield AsyncTask(
muon_update_batch_async(
Claude: [Warning] Missing whitespace between concatenated f-strings produces a garbled error message. For example, the output would read: "...optimizer's device meshparam device mesh: ...device meshparam group: ...".

Each f-string needs a separator (space, period, or newline):

Suggested change

    raise RuntimeError(
        f"Got DTensor sharded over mesh dimension {sharded_mesh_dim} different from the optimizer's device mesh. "
        f"param device mesh: {params[0].device_mesh}, optimizer's device mesh: {self._distributed_mesh}. "
        f"param group: {params[0].device_mesh.get_group(sharded_mesh_dim)}, optimizer's group: {self._process_group}"
    )

Comment on lines +47 to +72
    return tensor.to_local() if isinstance(tensor, DTensor) else tensor
    return [t.to_local() if isinstance(t, DTensor) else t for t in tensor]


def dtensor_from_local(tensor: Union[Tensor, List[Tensor]], ref: Tensor) -> Union[DTensor, List[DTensor]]:  # type: ignore
    """Convert a single local Tensor or list of local Tensors to DTensor.

    The reference tensor's device mesh and placements are used to create the DTensor. If the reference tensor is not a
    DTensor, we return the input unmodified.
    """
    if not isinstance(ref, DTensor):
        assert isinstance(ref, Tensor)
        return tensor

    device_mesh = ref.device_mesh
    placements = ref.placements

    # If we have a single tensor
    if isinstance(tensor, Tensor):
        assert not isinstance(tensor, DTensor)
        return DTensor.from_local(tensor, device_mesh=device_mesh, placements=placements)

    # We have a list of tensors
    assert not isinstance(tensor[0], DTensor)
    return [DTensor.from_local(t, device_mesh=device_mesh, placements=placements) for t in tensor]

Claude: [Nit] dtensor_from_local is defined but never called anywhere in this PR. If it's intended for future use, consider removing it for now per YAGNI and adding it when actually needed. Dead code adds maintenance burden.

Comment on lines +209 to +254
    # Weight update
    # X = X - adj_lr * M / denom
    X.addcdiv_(M, denom, value=-adj_lr)


@torch.compile(fullgraph=True)
def adamw_update_foreach(  # type: ignore
    X: List[Tensor],  # Model weights (modified in place)
    G: List[Tensor],  # Gradient
    M: List[Tensor],  # Momentum buffer (modified in place)
    V: List[Tensor],  # Variance buffer (modified in place)
    lr: Tensor,  # Learning rate (scalar tensor)
    beta1: Tensor,  # Beta 1 (scalar tensor)
    beta2: Tensor,  # Beta 2 (scalar tensor)
    weight_decay: Tensor,  # Weight decay (scalar tensor)
    step: int,
    epsilon: float,
):
    """AdamW optimizer algorithm (foreach implementation)."""
    batch_size = len(X)
    assert batch_size == len(G)
    assert batch_size == len(M)
    assert batch_size == len(V)

    M_dtype = M[0].dtype
    V_dtype = V[0].dtype

    # Update momentum and variance
    # M = beta1 * M + (1 - beta1) * G
    G = [g.to(dtype=M_dtype) for g in G]
    torch._foreach_lerp_(M, G, [1 - beta1] * batch_size)

    # V = beta2 * V + (1 - beta2) * G * G
    G_square = torch._foreach_mul(G, G)
    G_square = [g.to(dtype=V_dtype) for g in G_square]
    torch._foreach_lerp_(V, G_square, [1 - beta2] * batch_size)

    # Bias correction
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    bias_correction2_sqrt = bias_correction2.sqrt()

    # The goal is to compute the following in-place:
    # M = M / bias_correction1
    # V = V / bias_correction2
    # X = X - lr * M / (sqrt(V) + epsilon)
Claude: [Nit] adamw_update (single-element version) is defined but never called — only the foreach variant is used via adamw_update_foreach_async. Consider removing it to reduce dead code, or adding a comment explaining when it would be used.

Comment on lines +286 to +313
def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7):
    """Reference implementation of Newton-Schulz without Triton."""
    # Newton-Schulz constants
    ns_consts = [
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ]

    X = G.to(dtype=torch.bfloat16)
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + epsilon)

    for a, b, c in ns_consts:
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X


Claude: [Nit] zeropower_via_newtonschulz5 in this file appears to be a reference/test implementation (it's an exact duplicate of the one in muon.py). It's not called by newton_schulz_triton nor exported. Two commented-out @torch.compile decorators (lines 286, 315) also suggest this is WIP code. Consider removing it or adding a clear comment that it's a reference implementation for testing.

@@ -26,12 +33,100 @@ class AdamWConfig(OptimConfig):
eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Adam optimizer")] = 1e-8
foreach: Annotated[Optional[bool], Parameter(help="Use foreach implementation for AdamW")] = None
Claude: [Critical] The base class OptimConfig.build(self, params) abstract method signature was not updated, but both AdamWConfig.build and MuonConfig.build now take model instead of params. This is a breaking API inconsistency — any third-party or future subclass of OptimConfig following the abstract method's parameter name would expect to receive params, not a model object.

The abstract method in OptimConfig (line 20) should be updated to build(self, model) to match the new contract.
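The updated contract might read as follows (a minimal sketch; only the signature is the point here, and the real OptimConfig also carries the shared hyperparameter fields):

```python
from abc import ABC, abstractmethod


class OptimConfig(ABC):
    """Base optimizer config: build receives the model, not a param list."""

    @abstractmethod
    def build(self, model):
        """Construct an optimizer from the full model, so that subclasses can
        inspect named_parameters() for name/shape-based parameter grouping."""
        ...
```

Matching the abstract signature keeps third-party subclasses and type checkers aligned with what AdamWConfig.build and MuonConfig.build now actually receive.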

@claude

claude bot commented Feb 28, 2026

Claude: ## Summary

Adds Muon optimizer support with FSDP2, including an async runtime for overlapping communication with compute, Newton-Schulz orthogonalization (both torch.compile and Triton paths), and the plumbing through config/arguments/engine. Well-structured implementation overall with good attribution to upstream projects (Moonlight, DION).

Issues

Critical

Warning

Nit

Verdict

REQUEST_CHANGES — the abstract method signature mismatch is a real API contract bug that should be fixed before merge. The hardcoded layer name filtering and dead code are worth addressing too.

@HAOCHENYE
Collaborator

@claude Please modify this PR as you commented

params = [p for p in model.parameters() if p.requires_grad]

trainable_parameters_names = model.trainable_parameters()
trainable_names = [name for name, _ in trainable_parameters_names]
Claude: [Nit] trainable_names is constructed as a list, but it's used only for in membership checks (lines 46-47). This is O(n) per lookup, making the entire loop O(n²) for models with many parameters.

MuonConfig.build (line 68) correctly uses a set comprehension for the same purpose. This should be consistent:

Suggested change
trainable_names = [name for name, _ in trainable_parameters_names]
trainable_names = {name for name, _ in trainable_parameters_names}

Comment on lines +89 to +98
muon_params = [
    p
    for name, p in model.named_parameters()
    if name in trainable_names and p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
]
adamw_params = [
    p
    for name, p in model.named_parameters()
    if name in trainable_names and not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name)
]
Claude: [Warning] The Muon/AdamW parameter filtering condition (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name) is duplicated three times: once for counting (line 81), once for muon_params (line 92), and once negated for adamw_params (line 97). This also iterates model.named_parameters() three separate times.

If one instance is updated but the others aren't, parameters will be misclassified. Consider collecting all three lists in a single pass:

muon_params = []
adamw_params = []
for name, p in model.named_parameters():
    if name not in trainable_names:
        ...
        continue
    is_muon = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
    if is_muon:
        muon_params.append(p)
        num_muon += n
    else:
        adamw_params.append(p)
        num_adamw += n

This eliminates the duplication and reduces the three passes over parameters to one.

Comment on lines +122 to +124
nesterov=True,
adjust_lr="rms_norm",
use_triton=False,
Claude: [Nit] nesterov=True, adjust_lr="rms_norm", and use_triton=False are hardcoded here rather than exposed as config fields on MuonConfig. Users who want nesterov=False, a different adjust_lr strategy, or Triton kernels would need to modify this code. Consider whether these should be configurable on MuonConfig like the other hyperparameters, or at minimum add a brief comment explaining why these defaults were chosen.
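Exposing these as fields could look like the following (sketched as a plain dataclass for brevity; the actual MuonConfig uses the Annotated[..., Parameter(help=...)] field style seen elsewhere in this PR, and `MuonConfigSketch` is a hypothetical name):

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class MuonConfigSketch:
    """Hypothetical extra MuonConfig fields, defaulting to the values
    currently hardcoded at the optimizer construction site."""

    nesterov: bool = True
    # Other strategies would extend this Literal as they are implemented.
    adjust_lr: Literal["rms_norm"] = "rms_norm"
    use_triton: bool = False
```

The construction site would then forward `cfg.nesterov`, `cfg.adjust_lr`, and `cfg.use_triton` instead of literals, so users can flip them without editing code.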

@claude

claude bot commented Feb 28, 2026

Claude: ## Review Summary

This PR adds Muon optimizer support with FSDP2 via all-to-all communication for distributed Newton-Schulz orthogonalization, along with both torch.compile and Triton kernel paths. The core algorithm implementation is solid — the all-to-all redistribution pattern for reassembling/scattering sharded matrices is correct, the padding with dummy zero tensors is safe (they don't affect real parameter updates), and the Newton-Schulz transpose logic for tall matrices is mathematically sound.

Key issues from prior review (still open)

The previous review identified important issues that should be addressed before merge:

  1. OptimConfig.build(self, params) abstract signature not updated — both subclasses now take model, creating an API contract mismatch
  2. Hardcoded "embed_tokens" / "lm_head" layer names — fragile across model architectures
  3. Missing f-string separators in error messages (line 484 of muon.py)
  4. Dead code (dtensor_from_local, adamw_update, duplicate zeropower_via_newtonschulz5 in triton file)
  5. File-level # type: ignore suppresses all mypy in strict xtuner/v1/ directory

New observations

  1. trainable_names as list in AdamWConfig.build — O(n²) membership checking vs O(1) with set (as MuonConfig.build correctly uses)
  2. Muon/AdamW filter condition duplicated 3× in MuonConfig.build — and model.named_parameters() iterated 3 separate times; a single-pass refactor eliminates both the duplication risk and extra iterations
  3. Hardcoded nesterov=True, adjust_lr="rms_norm", use_triton=False — consider exposing as MuonConfig fields or documenting why these defaults are fixed
  4. No tests for the new optimizer — while integration testing requires GPUs, unit tests for create_param_batches, pad_batch, AsyncRuntime, and the Newton-Schulz functions could run on CPU
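For instance, a CPU-only sanity check of the Newton-Schulz path could look like this (the function body is copied from the reference implementation in this PR; the orthonormal-rows input and the loose tolerance are test-design choices, since bf16 and the quintic iteration only approximate orthogonality):

```python
import torch
from torch import Tensor


def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7):
    """Reference Newton-Schulz iteration (torch-only path, as in this PR)."""
    ns_consts = [
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ]
    X = G.to(dtype=torch.bfloat16)
    if G.size(-2) > G.size(-1):
        X = X.mT
    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + epsilon)
    for a, b, c in ns_consts:
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X


def test_newtonschulz_near_orthogonal():
    # A matrix with orthonormal rows has equal singular values, so after the
    # iteration all singular values of the output should land near 1.
    G = torch.eye(8, 16)
    X = zeropower_via_newtonschulz5(G)
    s = torch.linalg.svdvals(X.float())
    assert X.shape == G.shape
    assert (s - 1).abs().max().item() < 0.2
```

Similar shape-level tests for create_param_batches and pad_batch would not need a GPU either.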

Verdict

REQUEST_CHANGES — the abstract method signature mismatch (#1) is a real API contract bug. The dead code, duplicated filter logic, and missing tests are worth addressing for long-term maintainability.
