188 changes: 142 additions & 46 deletions muon.py
@@ -1,8 +1,11 @@
import torch
+from typing import Any, Callable, Optional
+from collections.abc import Iterable
import torch.distributed as dist
+from torch.optim.optimizer import ParamsT


-def zeropower_via_newtonschulz5(G, steps: int):
+def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int) -> torch.Tensor:
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
@@ -12,8 +15,10 @@ def zeropower_via_newtonschulz5(G, steps: int):
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
-    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
-    a, b, c = (3.4445, -4.7750,  2.0315)
+    assert (
+        G.ndim >= 2
+    )  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+    a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
@@ -23,21 +28,29 @@ def zeropower_via_newtonschulz5(G, steps: int):
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
-        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        B = (
+            b * A + c * A @ A
+        )  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X

if G.size(-2) > G.size(-1):
X = X.mT
return X


-def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
+def muon_update(
+    grad: torch.Tensor,
+    momentum: torch.Tensor,
+    beta: float = 0.95,
+    ns_steps: int = 5,
+    nesterov: bool = True,
+) -> torch.Tensor:
momentum.lerp_(grad, 1 - beta)
update = grad.lerp_(momentum, beta) if nesterov else momentum
-    if update.ndim == 4: # for the case of conv filters
+    if update.ndim == 4:  # for the case of conv filters
update = update.view(len(update), -1)
update = zeropower_via_newtonschulz5(update, steps=ns_steps)
-    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
+    update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
return update
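Reviewer aside: the retyped `zeropower_via_newtonschulz5` and `muon_update` above are behavior-preserving, so a quick sanity check still applies. A minimal sketch, assuming `muon.py` is importable as `muon` (shapes and prints are illustrative only):

```python
import torch

from muon import muon_update, zeropower_via_newtonschulz5

G = torch.randn(256, 128)  # a tall, gradient-shaped matrix
X = zeropower_via_newtonschulz5(G, steps=5)

# Per the docstring, X ~ US'V^T with S'_{ii} roughly in (0.5, 1.5), so
# X.mT @ X is only loosely close to the identity (bfloat16 internals, too).
dev = (X.mT.float() @ X.float() - torch.eye(128)).abs().max()
print(f"max deviation from I: {dev:.2f}")

# muon_update mutates both arguments in place (lerp_ on momentum, and on
# grad when nesterov=True), so clone anything you want to keep.
momentum = torch.zeros_like(G)
update = muon_update(G.clone(), momentum, beta=0.95, ns_steps=5)
assert update.shape == G.shape
```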


@@ -62,24 +75,38 @@ class Muon(torch.optim.Optimizer):
weight_decay: The AdamW-style weight decay.
momentum: The momentum. A value of 0.95 here is usually fine.
"""
-    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
-        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
-        assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
-        params = sorted(params, key=lambda x: x.size(), reverse=True)
+
+    def __init__(
+        self,
+        params: ParamsT,
+        lr=0.02,
+        weight_decay: float = 0,
+        momentum: float = 0.95,
+    ):
+        defaults: dict[str, Any] = dict(
+            lr=lr, weight_decay=weight_decay, momentum=momentum
+        )
+        assert (
+            isinstance(params, list)
+            and len(params) >= 1
+            and isinstance(params[0], torch.nn.Parameter)
+        )
+        # params = sorted(params, key=lambda x: x.size(), reverse=True)
super().__init__(params, defaults)

@torch.no_grad()
-    def step(self, closure=None):
-
+    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
params = group["params"]
-            params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
-            for base_i in range(len(params))[::dist.get_world_size()]:
+            params_pad = params + [torch.empty_like(params[-1])] * (
+                dist.get_world_size() - len(params) % dist.get_world_size()
+            )
+            for base_i in range(len(params))[:: dist.get_world_size()]:
if base_i + dist.get_rank() < len(params):
p = params[base_i + dist.get_rank()]
if p.grad is None:
@@ -88,10 +115,15 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
-                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+                    update = muon_update(
+                        p.grad, state["momentum_buffer"], beta=group["momentum"]
+                    )
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
-                dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
+                dist.all_gather(
+                    params_pad[base_i : base_i + dist.get_world_size()],
+                    params_pad[base_i + dist.get_rank()],
+                )

return loss
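To illustrate the block above: each rank computes the Muon update only for `params[base_i + rank]`, and the `all_gather` then syncs one `world_size`-sized slice of `params_pad`; the padding exists solely so that slice is always in range. A pure-Python sketch of the index arithmetic (no `torch.distributed`; the counts are hypothetical):

```python
world_size = 4
n_params = 10  # hypothetical; pad = 4 - 10 % 4 = 2 dummy tensors

for base_i in range(0, n_params, world_size):  # same as range(n)[::world_size]
    for rank in range(world_size):
        i = base_i + rank
        if i < n_params:
            print(f"rank {rank} updates params[{i}]")
        else:
            print(f"rank {rank} only contributes padding for slot {i}")
    # afterwards every rank all_gathers params_pad[base_i : base_i + world_size]
```

One observation: when `len(params)` is already a multiple of the world size, the padding expression allocates a full extra block of `world_size` dummy tensors that the loop never reaches; this looks like wasted allocation rather than a correctness issue.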

@@ -100,13 +132,19 @@ class SingleDeviceMuon(torch.optim.Optimizer):
"""
Muon variant for usage in non-distributed settings.
"""
-    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
+
+    def __init__(
+        self,
+        params: ParamsT,
+        lr: float = 0.02,
+        weight_decay: float = 0,
+        momentum: float = 0.95,
+    ) -> None:
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
super().__init__(params, defaults)

@torch.no_grad()
-    def step(self, closure=None):
-
+    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
loss = None
if closure is not None:
with torch.enable_grad():
@@ -120,18 +158,27 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
-                update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+                update = muon_update(
+                    p.grad, state["momentum_buffer"], beta=group["momentum"]
+                )
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])

return loss


-def adam_update(grad, buf1, buf2, step, betas, eps):
+def adam_update(
+    grad: torch.Tensor,
+    buf1: torch.Tensor,
+    buf2: torch.Tensor,
+    step: int,
+    betas: tuple[float, float],
+    eps: float,
+) -> torch.Tensor:
buf1.lerp_(grad, 1 - betas[0])
buf2.lerp_(grad.square(), 1 - betas[1])
-    buf1c = buf1 / (1 - betas[0]**step)
-    buf2c = buf2 / (1 - betas[1]**step)
+    buf1c = buf1 / (1 - betas[0] ** step)
+    buf2c = buf2 / (1 - betas[1] ** step)
return buf1c / (buf2c.sqrt() + eps)
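Note that `adam_update` is the standard Adam estimator without the step size (the caller scales by `lr`). On the first step with zero-initialized buffers, bias correction cancels the `1 - beta` factors exactly, so the update reduces to `g / (|g| + eps)`. A minimal check, again assuming `muon.py` is importable as `muon`:

```python
import torch

from muon import adam_update

g = torch.randn(8)
exp_avg, exp_avg_sq = torch.zeros_like(g), torch.zeros_like(g)

# step=1: buf1 = 0.1*g, corrected by 1 - 0.9**1 = 0.1 -> g; likewise buf2 -> g**2.
upd = adam_update(g, exp_avg, exp_avg_sq, step=1, betas=(0.9, 0.95), eps=1e-10)
print(torch.allclose(upd, g / (g.abs() + 1e-10)))  # True
```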


@@ -162,28 +209,43 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
optimizer = MuonWithAuxAdam(param_groups)
```
"""
-    def __init__(self, param_groups):
+
+    def __init__(self, param_groups: Iterable[dict[str, Any]]) -> None:
for group in param_groups:
assert "use_muon" in group
if group["use_muon"]:
group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
group["params"] = sorted(
group["params"], key=lambda x: x.size(), reverse=True
)
# defaults
group["lr"] = group.get("lr", 0.02)
group["momentum"] = group.get("momentum", 0.95)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
assert set(group.keys()) == {
"params",
"lr",
"momentum",
"weight_decay",
"use_muon",
}
else:
# defaults
group["lr"] = group.get("lr", 3e-4)
group["betas"] = group.get("betas", (0.9, 0.95))
group["eps"] = group.get("eps", 1e-10)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
assert set(group.keys()) == {
"params",
"lr",
"betas",
"eps",
"weight_decay",
"use_muon",
}
super().__init__(param_groups, dict())

@torch.no_grad()
-    def step(self, closure=None):
-
+    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
loss = None
if closure is not None:
with torch.enable_grad():
@@ -192,8 +254,10 @@ def step(self, closure=None):
for group in self.param_groups:
if group["use_muon"]:
params = group["params"]
-                params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
-                for base_i in range(len(params))[::dist.get_world_size()]:
+                params_pad = params + [torch.empty_like(params[-1])] * (
+                    dist.get_world_size() - len(params) % dist.get_world_size()
+                )
+                for base_i in range(len(params))[:: dist.get_world_size()]:
if base_i + dist.get_rank() < len(params):
p = params[base_i + dist.get_rank()]
if p.grad is None:
@@ -202,10 +266,15 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
-                        update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+                        update = muon_update(
+                            p.grad, state["momentum_buffer"], beta=group["momentum"]
+                        )
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
-                    dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
+                    dist.all_gather(
+                        params_pad[base_i : base_i + dist.get_world_size()],
+                        params_pad[base_i + dist.get_rank()],
+                    )
else:
for p in group["params"]:
if p.grad is None:
@@ -217,8 +286,14 @@ def step(self, closure=None):
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
state["step"], group["betas"], group["eps"])
update = adam_update(
p.grad,
state["exp_avg"],
state["exp_avg_sq"],
state["step"],
group["betas"],
group["eps"],
)
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])
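For context on the key-set asserts in the constructor: after defaulting, each group must carry exactly the listed keys, keyed off `use_muon`. A sketch of one group layout that satisfies them; the `model.body` / `model.head` / `model.embed` split is a placeholder for your own partition, with hidden ≥2-D weights routed to Muon and everything else to the aux Adam:

```python
hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
nonhidden_params = [*model.head.parameters(), *model.embed.parameters()]

param_groups = [
    dict(params=hidden_weights, use_muon=True,
         lr=0.02, weight_decay=0.01),
    dict(params=hidden_gains_biases + nonhidden_params, use_muon=False,
         lr=3e-4, betas=(0.9, 0.95), weight_decay=0.01),
]
optimizer = MuonWithAuxAdam(param_groups)
```

The omitted keys (`momentum` for the Muon group, `eps` for the Adam group) are filled by the `group.get(...)` defaults before the asserts run, so this passes.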

@@ -229,28 +304,41 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
"""
Non-distributed variant of MuonWithAuxAdam.
"""
-    def __init__(self, param_groups):
+
+    def __init__(self, param_groups: Iterable[dict[str, Any]]) -> None:
for group in param_groups:
assert "use_muon" in group
if group["use_muon"]:
# defaults
group["lr"] = group.get("lr", 0.02)
group["momentum"] = group.get("momentum", 0.95)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
assert set(group.keys()) == {
"params",
"lr",
"momentum",
"weight_decay",
"use_muon",
}
else:
# defaults
group["lr"] = group.get("lr", 3e-4)
group["betas"] = group.get("betas", (0.9, 0.95))
group["eps"] = group.get("eps", 1e-10)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
assert set(group.keys()) == {
"params",
"lr",
"betas",
"eps",
"weight_decay",
"use_muon",
}
super().__init__(param_groups, dict())

@torch.no_grad()
-    def step(self, closure=None):
-
-        loss = None
+    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+        loss: Optional[float] = None
if closure is not None:
with torch.enable_grad():
loss = closure()
@@ -264,7 +352,9 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
-                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+                    update = muon_update(
+                        p.grad, state["momentum_buffer"], beta=group["momentum"]
+                    )
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
else:
@@ -278,8 +368,14 @@ def step(self, closure=None):
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
-                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
-                                         state["step"], group["betas"], group["eps"])
+                    update = adam_update(
+                        p.grad,
+                        state["exp_avg"],
+                        state["exp_avg_sq"],
+                        state["step"],
+                        group["betas"],
+                        group["eps"],
+                    )
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])
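Finally, a minimal single-device smoke test for the classes above; the model, shapes, and hyperparameters are illustrative. Only ≥2-D parameters go to `SingleDeviceMuon`, since `muon_update` ultimately asserts `ndim >= 2`; in real use the 1-D gains and biases would go to the aux-Adam side:

```python
import torch

from muon import SingleDeviceMuon

model = torch.nn.Sequential(
    torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8)
)
matrix_params = [p for p in model.parameters() if p.ndim >= 2]
opt = SingleDeviceMuon(matrix_params, lr=0.02, weight_decay=0.01, momentum=0.95)

x, y = torch.randn(16, 32), torch.randn(16, 8)
for _ in range(3):
    loss = torch.nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())  # should trend downward over the three steps
```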

2 changes: 0 additions & 2 deletions setup.py
@@ -19,8 +19,6 @@
"Topic :: Scientific/Engineering :: Information Analysis",
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
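The dropped 3.7/3.8 classifiers are consistent with the new annotations in muon.py: PEP 585 builtin generics such as `tuple[float, float]` appear in function signatures, which are evaluated at definition time and therefore need Python 3.9+ (and `ParamsT` requires a torch version recent enough to export it). For instance:

```python
# Evaluated when the module defines the function, so on Python 3.8 this
# raises "TypeError: 'type' object is not subscriptable" at import time.
def adam_update_sig(betas: tuple[float, float]) -> None:
    ...
```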