Conversation

RangiLyu commented Feb 6, 2026

It might be more desirable to introduce some mechanism to separate fused params IMHO? For example, MoE projections are implemented with …

I'll analyze this and get back to you.
```python
def build(self, model):
    params = [p for p in model.parameters() if p.requires_grad]

    trainable_parameters_names = model.trainable_parameters()
    trainable_names = [name for name, _ in trainable_parameters_names]
    untrainable_names = []
    num_total_requires_grad = 0
    num_total = 0
    for name, params_ in model.named_parameters():
        num_total += params_.numel()
        num_total_requires_grad += params_.numel() if name in trainable_names else 0
        if name not in trainable_names:
            untrainable_names.append(name)

    if dist.get_rank() == 0:
        logger.info(
```
Claude: [Warning] The parameter counting and logging logic is now duplicated between `AdamWConfig.build` and `MuonConfig.build`. Consider extracting it into a shared method on `OptimConfig` (or a standalone helper) to avoid the repetition; when a third optimizer is added, this code would otherwise need to be copy-pasted again.
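A shared helper along these lines could absorb the duplicated loop; the name `summarize_params` and the `(name, numel)` input shape are assumptions for illustration, not the PR's actual API:

```python
from typing import Iterable, List, Set, Tuple


def summarize_params(
    named_numels: Iterable[Tuple[str, int]],
    trainable_names: Set[str],
) -> Tuple[int, int, List[str]]:
    """Count total/trainable numels and collect untrainable names in one pass."""
    num_total = 0
    num_trainable = 0
    untrainable_names: List[str] = []
    for name, numel in named_numels:
        num_total += numel
        if name in trainable_names:
            num_trainable += numel
        else:
            untrainable_names.append(name)
    return num_total, num_trainable, untrainable_names
```

Both `build` methods could then call this once and pass the results to `logger.info`, so a third optimizer config would only need one extra call rather than another copy of the loop.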
```python
num_muon = 0
num_adamw = 0

for name, p in model.named_parameters():
```
Claude: [Warning] The Muon/AdamW parameter split uses hardcoded layer names `"embed_tokens"` and `"lm_head"`. This is fragile and model-architecture specific: different model families use different names for embedding/output layers (e.g., `"wte"`, `"embed"`, `"output"`, `"head"`).

Consider making these configurable via `MuonConfig` fields (e.g., `adamw_name_patterns: list[str] = ["embed_tokens", "lm_head"]`) so users can customize which parameters use AdamW vs Muon without modifying the code.

This is also the concern raised by @nil0x9 about `GroupedLinear`: the current string-matching approach has no mechanism to handle fused parameters that structurally look like 2D matrices but semantically represent bundled weights.
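A sketch of what the configurable predicate could look like; `uses_adamw` and `DEFAULT_ADAMW_PATTERNS` are hypothetical names, assuming simple substring matching as in the current code:

```python
from typing import List

# Hypothetical default, mirroring the currently hardcoded names.
DEFAULT_ADAMW_PATTERNS: List[str] = ["embed_tokens", "lm_head"]


def uses_adamw(name: str, ndim: int, patterns: List[str] = DEFAULT_ADAMW_PATTERNS) -> bool:
    """A parameter falls back to AdamW if it is not a 2D+ matrix
    or its name matches one of the configured patterns."""
    return ndim < 2 or any(pat in name for pat in patterns)
```

A user with a GPT-2-style model could then pass `patterns=["wte", "head"]` without touching the optimizer code.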
```python
    )
    logger.info(f"Muon params: {num_muon // 1e6}M, AdamW params: {num_adamw // 1e6}M (counts by numel)")
    logger.info(f"Untrainable parameters names: {untrainable_names}")
    logger.info(
```
Claude: [Nit] Chinese comment in the source code. For consistency in an open-source project, comments should be in English.
```python
logger.info(
distributed_mesh=model.fsdp_mesh,  # TODO: EP>1 not supported yet
```
```python
# @@ -0,0 +1,839 @@
# type: ignore
```
Claude: [Warning] File-level `# type: ignore` suppresses all mypy checking for 838 lines. Per CLAUDE.md, mypy is strict for `xtuner/v1/*`. While this is understandable pragmatism for a complex optimizer built on `torch.compile` and DTensor, it masks real type errors (e.g., `step: int` receiving a `Tensor`; see my comment below).

Consider either:
- Adding targeted `# type: ignore[specific-error]` annotations only where needed, or
- Adding this file to `mypy.ini`'s overrides with specific relaxations rather than a blanket ignore.
```python
variances = [s["variance"] for s in states]

# Wrap hyperparameters in tensors for torch.compile
lr = torch.tensor(group["lr"])
```
Claude: [Warning] Type annotations are incorrect for `step` and `epsilon`. In `_create_adamw_tasks`, both are wrapped with `torch.tensor(...)`, so they arrive as `Tensor`, not `int`/`float`:

```python
step = torch.tensor(group["step"])
epsilon = torch.tensor(group["epsilon"])
```

The same mismatch exists in `adamw_update_foreach` (line 218). These annotations should be `Tensor` to match actual usage. The file-level `# type: ignore` currently hides this.
```python
)

yield AsyncTask(
    muon_update_batch_async(
```
Claude: [Warning] Missing whitespace between concatenated f-strings produces a garbled error message. For example, the output would read: "...optimizer's device meshparam device mesh: ...device meshparam group: ...".

Each f-string needs a separator (space, period, or newline):
```python
raise RuntimeError(
    f"Got DTensor sharded over mesh dimension {sharded_mesh_dim} different from the optimizer's device mesh. "
    f"param device mesh: {params[0].device_mesh}, optimizer's device mesh: {self._distributed_mesh}. "
    f"param group: {params[0].device_mesh.get_group(sharded_mesh_dim)}, optimizer's group: {self._process_group}"
)
```
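A minimal illustration of the bug, with placeholder values: adjacent string literals are concatenated with no separator unless one is written in.

```python
mesh = "mesh0"  # placeholder value for illustration
broken = (
    f"different from the optimizer's device mesh"
    f"param device mesh: {mesh}"
)
fixed = (
    f"different from the optimizer's device mesh. "
    f"param device mesh: {mesh}"
)
assert "meshparam" in broken   # words run together
assert "mesh. param" in fixed  # separator restored
```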
```python
    return tensor.to_local() if isinstance(tensor, DTensor) else tensor
return [t.to_local() if isinstance(t, DTensor) else t for t in tensor]


def dtensor_from_local(tensor: Union[Tensor, List[Tensor]], ref: Tensor) -> Union[DTensor, List[DTensor]]:  # type: ignore
    """Convert a single local Tensor or list of local Tensors to DTensor.

    The reference tensor's device mesh and placements are used to create the DTensor. If the reference
    tensor is not a DTensor, we return the input unmodified.
    """
    if not isinstance(ref, DTensor):
        assert isinstance(ref, Tensor)
        return tensor

    device_mesh = ref.device_mesh
    placements = ref.placements

    # If we have a single tensor
    if isinstance(tensor, Tensor):
        assert not isinstance(tensor, DTensor)
        return DTensor.from_local(tensor, device_mesh=device_mesh, placements=placements)

    # We have a list of tensors
    assert not isinstance(tensor[0], DTensor)
    return [DTensor.from_local(t, device_mesh=device_mesh, placements=placements) for t in tensor]
```
Claude: [Nit] `dtensor_from_local` is defined but never called anywhere in this PR. If it is intended for future use, consider removing it for now per YAGNI and adding it back when actually needed; dead code adds maintenance burden.
```python
    # Weight update
    # X = X - adj_lr * M / denom
    X.addcdiv_(M, denom, value=-adj_lr)


@torch.compile(fullgraph=True)
def adamw_update_foreach(  # type: ignore
    X: List[Tensor],  # Model weights (modified in place)
    G: List[Tensor],  # Gradient
    M: List[Tensor],  # Momentum buffer (modified in place)
    V: List[Tensor],  # Variance buffer (modified in place)
    lr: Tensor,  # Learning rate (scalar tensor)
    beta1: Tensor,  # Beta 1 (scalar tensor)
    beta2: Tensor,  # Beta 2 (scalar tensor)
    weight_decay: Tensor,  # Weight decay (scalar tensor)
    step: int,
    epsilon: float,
):
    """AdamW optimizer algorithm (foreach implementation)."""
    batch_size = len(X)
    assert batch_size == len(G)
    assert batch_size == len(M)
    assert batch_size == len(V)

    M_dtype = M[0].dtype
    V_dtype = V[0].dtype

    # Update momentum and variance
    # M = beta1 * M + (1 - beta1) * G
    G = [g.to(dtype=M_dtype) for g in G]
    torch._foreach_lerp_(M, G, [1 - beta1] * batch_size)

    # V = beta2 * V + (1 - beta2) * G * G
    G_square = torch._foreach_mul(G, G)
    G_square = [g.to(dtype=V_dtype) for g in G_square]
    torch._foreach_lerp_(V, G_square, [1 - beta2] * batch_size)

    # Bias correction
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    bias_correction2_sqrt = bias_correction2.sqrt()

    # The goal is to compute the following in-place:
    # M = M / bias_correction1
    # V = V / bias_correction2
    # X = X - lr * M / (sqrt(V) + epsilon)
```
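The commented equations can be checked against a scalar reference. This is a plain-Python sketch of one AdamW step; decoupled weight decay is assumed per the AdamW paper, since the reviewed excerpt does not show where decay is applied:

```python
import math


def adamw_step_scalar(x, g, m, v, lr, beta1, beta2, weight_decay, step, eps):
    """One AdamW step on a single scalar parameter, mirroring the comments:
    M = beta1*M + (1-beta1)*G;  V = beta2*V + (1-beta2)*G*G;
    X = X - lr * (M/bc1) / (sqrt(V/bc2) + eps)."""
    x = x * (1 - lr * weight_decay)       # decoupled weight decay (assumed)
    m = beta1 * m + (1 - beta1) * g       # momentum EMA
    v = beta2 * v + (1 - beta2) * g * g   # variance EMA
    m_hat = m / (1 - beta1 ** step)       # bias correction 1
    v_hat = v / (1 - beta2 ** step)       # bias correction 2
    x = x - lr * m_hat / (math.sqrt(v_hat) + eps)
    return x, m, v


# At step 1 the bias corrections cancel the (1 - beta) factors exactly,
# so with fresh state the update is simply -lr * g / (|g| + eps).
x, m, v = adamw_step_scalar(1.0, 0.5, 0.0, 0.0, lr=0.1, beta1=0.9,
                            beta2=0.999, weight_decay=0.0, step=1, eps=1e-8)
```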
Claude: [Nit] `adamw_update` (the single-element version) is defined but never called; only the foreach variant is used, via `adamw_update_foreach_async`. Consider removing it to reduce dead code, or adding a comment explaining when it would be used.
```python
def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7):
    """Reference implementation of Newton-Schulz without Triton."""
    # Newton-Schulz constants
    ns_consts = [
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ]

    X = G.to(dtype=torch.bfloat16)
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + epsilon)

    for a, b, c in ns_consts:
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
```
Claude: [Nit] `zeropower_via_newtonschulz5` in this file appears to be a reference/test implementation (it is an exact duplicate of the one in muon.py). It is not called by `newton_schulz_triton` nor exported, and two commented-out `@torch.compile` decorators (lines 286, 315) suggest this is WIP code. Consider removing it or adding a clear comment that it is a reference implementation for testing.
```python
# @@ -26,12 +33,100 @@ class AdamWConfig(OptimConfig):
eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Adam optimizer")] = 1e-8
foreach: Annotated[Optional[bool], Parameter(help="Use foreach implementation for AdamW")] = None
```
Claude: [Critical] The base class `OptimConfig.build(self, params)` abstract method signature was not updated, but both `AdamWConfig.build` and `MuonConfig.build` now take `model` instead of `params`. This is a breaking API inconsistency: any third-party or future subclass of `OptimConfig` following the abstract method's parameter name would expect to receive `params`, not a `model` object.

The abstract method in `OptimConfig` (line 20) should be updated to `build(self, model)` to match the new contract.
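A sketch of the corrected contract, simplified with `abc` rather than the project's actual config base; the class bodies here are illustrative assumptions:

```python
from abc import ABC, abstractmethod
from typing import Any, List


class OptimConfig(ABC):
    @abstractmethod
    def build(self, model: Any):  # was: build(self, params)
        """Build an optimizer from the model; subclasses derive their
        own parameter groups from it."""


class AdamWConfig(OptimConfig):
    def build(self, model: Any) -> List[Any]:
        # Placeholder body: a real implementation would construct
        # torch.optim.AdamW from these params.
        return [p for p in model.parameters() if p.requires_grad]
```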
Claude:

## Summary

Adds Muon optimizer support with FSDP2, including an async runtime for overlapping communication with compute and Newton-Schulz orthogonalization (both Triton and reference implementations).

## Issues

See the inline comments above for the Critical, Warning, and Nit items.

## Verdict

REQUEST_CHANGES: the abstract method signature mismatch is a real API contract bug that should be fixed before merge. The hardcoded layer name filtering and dead code are worth addressing too.
@claude Please modify this PR as you commented.
```python
params = [p for p in model.parameters() if p.requires_grad]

trainable_parameters_names = model.trainable_parameters()
trainable_names = [name for name, _ in trainable_parameters_names]
```
Claude: [Nit] `trainable_names` is constructed as a list, but it is used only for `in` membership checks (lines 46-47). This is O(n) per lookup, making the entire loop O(n²) for models with many parameters.

`MuonConfig.build` (line 68) correctly uses a set comprehension for the same purpose. This should be consistent:

```python
trainable_names = {name for name, _ in trainable_parameters_names}
```
```python
muon_params = [
    p
    for name, p in model.named_parameters()
    if name in trainable_names and p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
]
adamw_params = [
    p
    for name, p in model.named_parameters()
    if name in trainable_names and not (p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name)
]
```
Claude: [Warning] The Muon/AdamW parameter filtering condition (`p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name`) is duplicated three times: once for counting (line 81), once for `muon_params` (line 92), and once negated for `adamw_params` (line 97). This also iterates `model.named_parameters()` three separate times.

If one instance is updated but the others are not, parameters will be misclassified. Consider collecting all three lists in a single pass:

```python
muon_params = []
adamw_params = []
for name, p in model.named_parameters():
    if name not in trainable_names:
        ...
        continue
    n = p.numel()
    is_muon = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
    if is_muon:
        muon_params.append(p)
        num_muon += n
    else:
        adamw_params.append(p)
        num_adamw += n
```

This eliminates the duplication and reduces the three passes over the parameters to one.
```python
nesterov=True,
adjust_lr="rms_norm",
use_triton=False,
```
Claude: [Nit] `nesterov=True`, `adjust_lr="rms_norm"`, and `use_triton=False` are hardcoded here rather than exposed as config fields on `MuonConfig`. Users who want `nesterov=False`, a different `adjust_lr` strategy, or Triton kernels would need to modify this code. Consider whether these should be configurable on `MuonConfig` like the other hyperparameters, or at minimum add a brief comment explaining why these defaults were chosen.
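One way the fields could be exposed, sketched with a plain dataclass rather than the project's `Annotated`/`Parameter` config style; the field names are assumptions:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class MuonConfigSketch:
    nesterov: bool = True          # currently hardcoded at the call site
    adjust_lr: str = "rms_norm"    # currently hardcoded
    use_triton: bool = False       # currently hardcoded
    adamw_name_patterns: List[str] = field(
        default_factory=lambda: ["embed_tokens", "lm_head"]
    )
```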
Claude:

## Review Summary

This PR adds Muon optimizer support with FSDP2 via all-to-all communication for distributed Newton-Schulz orthogonalization, along with both Triton and reference implementations.

## Key issues from prior review (still open)

The previous review identified important issues that should be addressed before merge; see the inline comments above.

## New observations

See the inline comments on this revision.

## Verdict

REQUEST_CHANGES: the abstract method signature mismatch (#1) is a real API contract bug. The dead code, duplicated filter logic, and missing tests are worth addressing for long-term maintainability.
