From 3a097b07ff2ca0c3572c7188f7da7df14e8503a2 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 25 Jul 2025 13:45:29 +0800 Subject: [PATCH] Use Python 3.9 syntax and typing Signed-off-by: cyy --- muon.py | 188 +++++++++++++++++++++++++++++++++++++++++-------------- setup.py | 2 - 2 files changed, 142 insertions(+), 48 deletions(-) diff --git a/muon.py b/muon.py index 8f11732..7d94d89 100644 --- a/muon.py +++ b/muon.py @@ -1,8 +1,11 @@ import torch +from typing import Any, Callable, Optional +from collections.abc import Iterable import torch.distributed as dist +from torch.optim.optimizer import ParamsT -def zeropower_via_newtonschulz5(G, steps: int): +def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int) -> torch.Tensor: """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose @@ -12,8 +15,10 @@ def zeropower_via_newtonschulz5(G, steps: int): where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model performance at all relative to UV^T, where USV^T = G is the SVD. """ - assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng - a, b, c = (3.4445, -4.7750, 2.0315) + assert ( + G.ndim >= 2 + ) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng + a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() if G.size(-2) > G.size(-1): X = X.mT @@ -23,21 +28,29 @@ def zeropower_via_newtonschulz5(G, steps: int): # Perform the NS iterations for _ in range(steps): A = X @ X.mT - B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = ( + b * A + c * A @ A + ) # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng X = a * X + B @ X - + if G.size(-2) > G.size(-1): X = X.mT return X -def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True): +def muon_update( + grad: torch.Tensor, + momentum: torch.Tensor, + beta: float = 0.95, + ns_steps: int = 5, + nesterov: bool = True, +) -> torch.Tensor: momentum.lerp_(grad, 1 - beta) update = grad.lerp_(momentum, beta) if nesterov else momentum - if update.ndim == 4: # for the case of conv filters + if update.ndim == 4: # for the case of conv filters update = update.view(len(update), -1) update = zeropower_via_newtonschulz5(update, steps=ns_steps) - update *= max(1, grad.size(-2) / grad.size(-1))**0.5 + update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5 return update @@ -62,15 +75,27 @@ class Muon(torch.optim.Optimizer): weight_decay: The AdamW-style weight decay. momentum: The momentum. A value of 0.95 here is usually fine. """ - def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter) - params = sorted(params, key=lambda x: x.size(), reverse=True) + + def __init__( + self, + params: ParamsT, + lr=0.02, + weight_decay: float = 0, + momentum: float = 0.95, + ): + defaults: dict[str, Any] = dict( + lr=lr, weight_decay=weight_decay, momentum=momentum + ) + assert ( + isinstance(params, list) + and len(params) >= 1 + and isinstance(params[0], torch.nn.Parameter) + ) + # params = sorted(params, key=lambda x: x.size(), reverse=True) super().__init__(params, defaults) @torch.no_grad() - def step(self, closure=None): - + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: loss = None if closure is not None: with torch.enable_grad(): @@ -78,8 +103,10 @@ def step(self, closure=None): for group in self.param_groups: params = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size()) - for base_i in range(len(params))[::dist.get_world_size()]: + params_pad = params + [torch.empty_like(params[-1])] * ( + dist.get_world_size() - len(params) % dist.get_world_size() + ) + for base_i in range(len(params))[:: dist.get_world_size()]: if base_i + dist.get_rank() < len(params): p = params[base_i + dist.get_rank()] if p.grad is None: @@ -88,10 +115,15 @@ def step(self, closure=None): state = self.state[p] if len(state) == 0: state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + update = muon_update( + p.grad, state["momentum_buffer"], beta=group["momentum"] + ) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) - dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + dist.all_gather( + params_pad[base_i : base_i + dist.get_world_size()], + params_pad[base_i + dist.get_rank()], + ) return loss @@ -100,13 +132,19 @@ class SingleDeviceMuon(torch.optim.Optimizer): """ Muon variant for usage in non-distributed settings. """ - def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): + + def __init__( + self, + params: ParamsT, + lr: float = 0.02, + weight_decay: float = 0, + momentum: float = 0.95, + ) -> None: defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) super().__init__(params, defaults) @torch.no_grad() - def step(self, closure=None): - + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: loss = None if closure is not None: with torch.enable_grad(): @@ -120,18 +158,27 @@ def step(self, closure=None): state = self.state[p] if len(state) == 0: state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + update = muon_update( + p.grad, state["momentum_buffer"], beta=group["momentum"] + ) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) return loss -def adam_update(grad, buf1, buf2, step, betas, eps): +def adam_update( + grad: torch.Tensor, + buf1: torch.Tensor, + buf2: torch.Tensor, + step: int, + betas: tuple[float, float], + eps: float, +) -> torch.Tensor: buf1.lerp_(grad, 1 - betas[0]) buf2.lerp_(grad.square(), 1 - betas[1]) - buf1c = buf1 / (1 - betas[0]**step) - buf2c = buf2 / (1 - betas[1]**step) + buf1c = buf1 / (1 - betas[0] ** step) + buf2c = buf2 / (1 - betas[1] ** step) return buf1c / (buf2c.sqrt() + eps) @@ -162,28 +209,43 @@ class MuonWithAuxAdam(torch.optim.Optimizer): optimizer = MuonWithAuxAdam(param_groups) ``` """ - def __init__(self, param_groups): + + def __init__(self, param_groups: Iterable[dict[str, Any]]) -> None: for group in param_groups: assert "use_muon" in group if group["use_muon"]: - group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True) + group["params"] = sorted( + group["params"], key=lambda x: x.size(), reverse=True + ) # defaults group["lr"] = group.get("lr", 0.02) group["momentum"] = group.get("momentum", 0.95) group["weight_decay"] = group.get("weight_decay", 0) - assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"]) + assert set(group.keys()) == { + "params", + "lr", + "momentum", + "weight_decay", + "use_muon", + } else: # defaults group["lr"] = group.get("lr", 3e-4) group["betas"] = group.get("betas", (0.9, 0.95)) group["eps"] = group.get("eps", 1e-10) group["weight_decay"] = group.get("weight_decay", 0) - assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"]) + assert set(group.keys()) == { + "params", + "lr", + "betas", + "eps", + "weight_decay", + "use_muon", + } super().__init__(param_groups, dict()) @torch.no_grad() - def step(self, closure=None): - + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: loss = None if closure is not None: with torch.enable_grad(): @@ -192,8 +254,10 @@ def step(self, closure=None): for group in self.param_groups: if group["use_muon"]: params = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size()) - for base_i in range(len(params))[::dist.get_world_size()]: + params_pad = params + [torch.empty_like(params[-1])] * ( + dist.get_world_size() - len(params) % dist.get_world_size() + ) + for base_i in range(len(params))[:: dist.get_world_size()]: if base_i + dist.get_rank() < len(params): p = params[base_i + dist.get_rank()] if p.grad is None: @@ -202,10 +266,15 @@ def step(self, closure=None): state = self.state[p] if len(state) == 0: state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + update = muon_update( + p.grad, state["momentum_buffer"], beta=group["momentum"] + ) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) - dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + dist.all_gather( + params_pad[base_i : base_i + dist.get_world_size()], + params_pad[base_i + dist.get_rank()], + ) else: for p in group["params"]: if p.grad is None: @@ -217,8 +286,14 @@ def step(self, closure=None): state["exp_avg_sq"] = torch.zeros_like(p) state["step"] = 0 state["step"] += 1 - update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], - state["step"], group["betas"], group["eps"]) + update = adam_update( + p.grad, + state["exp_avg"], + state["exp_avg_sq"], + state["step"], + group["betas"], + group["eps"], + ) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update, alpha=-group["lr"]) @@ -229,7 +304,8 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer): """ Non-distributed variant of MuonWithAuxAdam. """ - def __init__(self, param_groups): + + def __init__(self, param_groups: Iterable[dict[str, Any]]) -> None: for group in param_groups: assert "use_muon" in group if group["use_muon"]: @@ -237,20 +313,32 @@ def __init__(self, param_groups): group["lr"] = group.get("lr", 0.02) group["momentum"] = group.get("momentum", 0.95) group["weight_decay"] = group.get("weight_decay", 0) - assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"]) + assert set(group.keys()) == { + "params", + "lr", + "momentum", + "weight_decay", + "use_muon", + } else: # defaults group["lr"] = group.get("lr", 3e-4) group["betas"] = group.get("betas", (0.9, 0.95)) group["eps"] = group.get("eps", 1e-10) group["weight_decay"] = group.get("weight_decay", 0) - assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"]) + assert set(group.keys()) == { + "params", + "lr", + "betas", + "eps", + "weight_decay", + "use_muon", + } super().__init__(param_groups, dict()) @torch.no_grad() - def step(self, closure=None): - - loss = None + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + loss: Optional[float] = None if closure is not None: with torch.enable_grad(): loss = closure() @@ -264,7 +352,9 @@ def step(self, closure=None): state = self.state[p] if len(state) == 0: state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + update = muon_update( + p.grad, state["momentum_buffer"], beta=group["momentum"] + ) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) else: @@ -278,8 +368,14 @@ def step(self, closure=None): state["exp_avg_sq"] = torch.zeros_like(p) state["step"] = 0 state["step"] += 1 - update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], - state["step"], group["betas"], group["eps"]) + update = adam_update( + p.grad, + state["exp_avg"], + state["exp_avg_sq"], + state["step"], + group["betas"], + group["eps"], + ) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update, alpha=-group["lr"]) diff --git a/setup.py b/setup.py index 8ae0848..c56a22b 100644 --- a/setup.py +++ b/setup.py @@ -19,8 +19,6 @@ "Topic :: Scientific/Engineering :: Information Analysis", 'License :: OSI Approved :: MIT License', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11',