2 changes: 2 additions & 0 deletions distributed_shampoo/README.md
@@ -51,6 +51,7 @@ Key distinctives of this implementation include:
 - Option to (approximately) correct the eigenvalues/run Adam in the eigenbasis of Shampoo's preconditioner (SOAP) [2,6,7].
 - Option to use an adaptive preconditioner update frequency when symmetric eigendecompositions or the QR algorithm is used [8].
 - Spectral descent via reduced SVD or Newton-Schulz iteration for 2D gradients, or gradients that have been reshaped to 2D [9,10]. This can be used to implement Muon [11], see [Example 6](#example-6-muon).
+- KL-Shampoo (without per-factor matrix eigenvalue correction) [12].
 
 ## Requirements
 
@@ -784,3 +785,4 @@ If you use PyTorch Distributed Shampoo in your work, please use the following BibTeX
 9. [Preconditioned Spectral Descent for Deep Learning](https://papers.nips.cc/paper_files/paper/2015/hash/f50a6c02a3fc5a3a5d4d9391f05f3efc-Abstract.html). David E. Carlson, Edo Collins, Ya-Ping Hsieh, Lawrence Carin, Volkan Cevher. NeurIPS, 2015.
 10. [Old Optimizer, New Norm: An Anthology](https://arxiv.org/abs/2409.20325). Jeremy Bernstein, Laker Newhouse. Tech report, 2024.
 11. [Muon: An optimizer for hidden layers in neural networks](https://kellerjordan.github.io/posts/muon/). Keller Jordan, Yuchen Jin, Vlado Boza, Jiacheng You, Franz Cesista, Laker Newhouse, Jeremy Bernstein. Blog post, 2024.
+12. [Understanding and Improving Shampoo and SOAP via Kullback-Leibler Minimization](https://arxiv.org/abs/2509.03378). Wu Lin, Scott C. Lowe, Felix Dangel, Runa Eschenhagen, Zikun Xu, Roger B. Grosse. Tech report, 2025.
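
Editor's note: a minimal sketch of how the new KL-Shampoo option might be enabled, mirroring the README's existing examples. Only `RootInvKLShampooPreconditionerConfig` comes from this PR; the toy model and hyperparameter values below are illustrative assumptions, and the config is assumed to be default-constructible like `RootInvShampooPreconditionerConfig`.

```python
import torch

from distributed_shampoo import DistributedShampoo, RootInvKLShampooPreconditionerConfig

model = torch.nn.Linear(128, 10)  # hypothetical toy model
optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    epsilon=1e-8,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    # New in this PR: accumulate KL-Shampoo statistics in the root-inverse implementation.
    preconditioner_config=RootInvKLShampooPreconditionerConfig(),
)
```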
4 changes: 4 additions & 0 deletions distributed_shampoo/__init__.py
@@ -43,6 +43,7 @@
     DefaultSOAPConfig,
     DefaultSpectralDescentPreconditionerConfig,
     DistributedConfig,
+    EigendecomposedKLShampooPreconditionerConfig,
     EigendecomposedShampooPreconditionerConfig,
     EigenvalueCorrectedShampooPreconditionerConfig,
     FSDPDistributedConfig,
@@ -52,6 +53,7 @@
     HybridShardDistributedConfig,
     PreconditionerConfig,
     RMSpropPreconditionerConfig,
+    RootInvKLShampooPreconditionerConfig,
     RootInvShampooPreconditionerConfig,
     SGDPreconditionerConfig,
     ShampooPreconditionerConfig,
@@ -83,7 +85,9 @@
"ShampooPreconditionerConfig", # Abstract base class (based on `AmortizedPreconditionerConfig`).
"RootInvShampooPreconditionerConfig", # Based on `ShampooPreconditionerConfig`.
"DefaultShampooConfig", # Default `RootInvShampooPreconditionerConfig` using `EigenConfig`.
"RootInvKLShampooPreconditionerConfig", # Based on `RootInvShampooPreconditionerConfig`.
"EigendecomposedShampooPreconditionerConfig", # Based on `ShampooPreconditionerConfig`.
"EigendecomposedKLShampooPreconditionerConfig", # Based on `EigendecomposedShampooPreconditionerConfig`.
"EigenvalueCorrectedShampooPreconditionerConfig", # Based on `AmortizedPreconditionerConfig`.
"DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigendecompositionConfig`.
"DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QREigendecompositionConfig`.
8 changes: 8 additions & 0 deletions distributed_shampoo/distributed_shampoo.py
@@ -48,8 +48,10 @@
     SGDPreconditionerList,
 )
 from distributed_shampoo.preconditioner.shampoo_preconditioner_list import (
+    EigendecomposedKLShampooPreconditionerList,
     EigendecomposedShampooPreconditionerList,
     EigenvalueCorrectedShampooPreconditionerList,
+    RootInvKLShampooPreconditionerList,
     RootInvShampooPreconditionerList,
 )
 from distributed_shampoo.preconditioner.sign_descent_preconditioner_list import (
@@ -72,6 +74,7 @@
     DISTRIBUTED_CONFIG,
     DistributedConfig,
     DISTRIBUTOR,
+    EigendecomposedKLShampooPreconditionerConfig,
     EigendecomposedShampooPreconditionerConfig,
     EigenvalueCorrectedShampooPreconditionerConfig,
     EPSILON,
@@ -98,6 +101,7 @@
     PreconditionerConfig,
     PREVIOUS_GRAD_SELECTOR,
     RMSpropPreconditionerConfig,
+    RootInvKLShampooPreconditionerConfig,
     RootInvShampooPreconditionerConfig,
     SGDPreconditionerConfig,
     SHAMPOO_PRECONDITIONER_LIST,
@@ -638,13 +642,17 @@ def _preconditioner_config_to_list_cls(
             RootInvShampooPreconditionerConfig()
             | EigendecomposedShampooPreconditionerConfig()
             | EigenvalueCorrectedShampooPreconditionerConfig()
+            | RootInvKLShampooPreconditionerConfig()
+            | EigendecomposedKLShampooPreconditionerConfig()
         ):
             preconditioner_config_to_list_cls: dict[
                 type[PreconditionerConfig], Callable[..., PreconditionerList]
             ] = {
                 RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList,
                 EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList,
                 EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList,
+                RootInvKLShampooPreconditionerConfig: RootInvKLShampooPreconditionerList,
+                EigendecomposedKLShampooPreconditionerConfig: EigendecomposedKLShampooPreconditionerList,
             }
             return preconditioner_config_to_list_cls[type(preconditioner_config)](
                 block_list=state_lists[DISTRIBUTOR].local_blocked_params,
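
Editor's note: the mapping above dispatches on the exact runtime type via `preconditioner_config_to_list_cls[type(preconditioner_config)]`, so the new KL configs must be registered explicitly even though they subclass existing configs; inheritance alone would not route them to the KL list classes. A tiny illustration, assuming the config is default-constructible:

```python
# Exact-type lookup: a RootInvKLShampooPreconditionerConfig instance selects
# RootInvKLShampooPreconditionerList, never its parent's list class.
assert type(RootInvKLShampooPreconditionerConfig()) is RootInvKLShampooPreconditionerConfig
assert type(RootInvKLShampooPreconditionerConfig()) is not RootInvShampooPreconditionerConfig
```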
2 changes: 2 additions & 0 deletions distributed_shampoo/examples/argument_parser.py
@@ -34,6 +34,8 @@ class PreconditionerComputationType(enum.Enum):
     ADAM = enum.auto()
     EIGEN_ROOT_INV = enum.auto()
     EIGENDECOMPOSED_ROOT_INV = enum.auto()
+    KL_EIGEN_ROOT_INV = enum.auto()
+    KL_EIGENDECOMPOSED_ROOT_INV = enum.auto()
     COUPLED_NEWTON_ROOT_INV = enum.auto()
     COUPLED_HIGHER_ORDER_ROOT_INV = enum.auto()
     EIGH_EIGENVALUE_CORRECTION = enum.auto()
14 changes: 14 additions & 0 deletions distributed_shampoo/examples/trainer_utils.py
@@ -31,9 +31,11 @@
     DefaultSpectralDescentPreconditionerConfig,
     DistributedConfig,
     DistributedShampoo,
+    EigendecomposedKLShampooPreconditionerConfig,
     EigendecomposedShampooPreconditionerConfig,
     PreconditionerConfig,
     RMSpropPreconditionerConfig,
+    RootInvKLShampooPreconditionerConfig,
     RootInvShampooPreconditionerConfig,
     SGDPreconditionerConfig,
 )
@@ -119,6 +121,18 @@ def instantiate_preconditioner_config(
         == PreconditionerComputationType.EIGENDECOMPOSED_ROOT_INV
     ):
         return EigendecomposedShampooPreconditionerConfig()
+    elif (
+        preconditioner_computation_type
+        == PreconditionerComputationType.KL_EIGEN_ROOT_INV
+    ):
+        return RootInvKLShampooPreconditionerConfig(
+            amortized_computation_config=EigenConfig()
+        )
+    elif (
+        preconditioner_computation_type
+        == PreconditionerComputationType.KL_EIGENDECOMPOSED_ROOT_INV
+    ):
+        return EigendecomposedKLShampooPreconditionerConfig()
     elif (
         preconditioner_computation_type
         == PreconditionerComputationType.COUPLED_NEWTON_ROOT_INV
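
Editor's note: the two new enum branches above return the configs below. A quick sketch of constructing them directly; the `EigenConfig` import path is an assumption based on how the repository's matrix-function configs are typically exposed, and is not shown in this diff.

```python
from distributed_shampoo import (
    EigendecomposedKLShampooPreconditionerConfig,
    RootInvKLShampooPreconditionerConfig,
)
from matrix_functions_types import EigenConfig  # assumed import path

# KL_EIGEN_ROOT_INV -> root-inverse KL-Shampoo with an eigendecomposition-based root inverse.
kl_root_inv = RootInvKLShampooPreconditionerConfig(
    amortized_computation_config=EigenConfig()
)
# KL_EIGENDECOMPOSED_ROOT_INV -> eigendecomposed KL-Shampoo with its defaults.
kl_eigendecomposed = EigendecomposedKLShampooPreconditionerConfig()
```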
132 changes: 122 additions & 10 deletions distributed_shampoo/preconditioner/shampoo_preconditioner_list.py
@@ -1270,6 +1270,25 @@ def compress_preconditioner_list(
             self._local_preconditioned_dims_selector_list, local_grad_selector
         )
 
+    @profile_decorator
+    def _compute_outer_product_list(
+        self,
+        grad: Tensor,
+        order: int,
+        preconditioned_dims_selector: tuple[bool, ...],
+        kronecker_factors: _ShampooKroneckerFactorsUnwrappedType,
+    ) -> tuple[Tensor, ...]:
+        # Construct outer product list for updating Kronecker factors.
+        return tuple(
+            torch.tensordot(
+                grad,
+                grad,
+                # Contracts across all dimensions except for k.
+                dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type]
+            )
+            for k in compress_list(range(order), preconditioned_dims_selector)
+        )
+
     @profile_decorator
     def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None:
         # NOTE: Unlike AdagradPreconditionerList, we will loop through each gradient individually.
@@ -1286,15 +1305,8 @@ def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None:
             if not kronecker_factors.factor_matrices:
                 continue
 
-            # Construct outer product list for updating Kronecker factors.
-            outer_product_list = tuple(
-                torch.tensordot(
-                    grad,
-                    grad,
-                    # Contracts across all dimensions except for k.
-                    dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type]
-                )
-                for k in compress_list(range(order), preconditioned_dims_selector)
+            outer_product_list = self._compute_outer_product_list(
+                grad, order, preconditioned_dims_selector, kronecker_factors
             )
 
             if self._beta2 != 1.0:
@@ -1673,7 +1685,6 @@ def update_preconditioners(
                 kronecker_factors.corrected_eigenvalues.mul_(self._beta2)
 
             # NOTE: The case when self._weighting_factor == 1.0 is not well tested and might not be stable.
-            print(f"{self._weighting_factor=}")
             kronecker_factors.corrected_eigenvalues.addcmul_(
                 grad, grad, value=self._weighting_factor
             )
@@ -1707,3 +1718,104 @@ def _compute_preconditioned_gradient(
             preconditioner_list=kronecker_factors.factor_matrices_eigenvectors,
             dims=([0], [1]),
         )
+
+
+class RootInvKLShampooPreconditionerList(RootInvShampooPreconditionerList):
+    """Root inverse KL-Shampoo preconditioners for list of parameters."""
+
+    @profile_decorator
+    def _compute_outer_product_list(
+        self,
+        grad: Tensor,
+        order: int,
+        preconditioned_dims_selector: tuple[bool, ...],
+        kronecker_factors: RootInvShampooKroneckerFactorsUnwrapped,
+    ) -> tuple[Tensor, ...]:
+        # Construct outer product list for updating Kronecker factors.
+        outer_product_list = []
+        for idx_of_k, k in enumerate(
+            compress_list(range(order), preconditioned_dims_selector)
+        ):
+            # KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products.
+            local_preconditioned_dims_selector = list(preconditioned_dims_selector)
+            local_preconditioned_dims_selector[k] = False
+            preconditioned_grad = self._precondition_grad(
+                grad=grad,
+                preconditioned_dims_selector=tuple(local_preconditioned_dims_selector),
+                preconditioner_list=tuple(
+                    inv_factor_matrix
+                    for idx, inv_factor_matrix in enumerate(
+                        kronecker_factors.inv_factor_matrices
+                    )
+                    if idx != idx_of_k
+                ),
+            )
+            outer_product_list.append(
+                torch.tensordot(
+                    preconditioned_grad,
+                    preconditioned_grad,
+                    # Contracts across all dimensions except for k.
+                    dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type]
+                )
+            )
+        return tuple(outer_product_list)
+
+
+class EigendecomposedKLShampooPreconditionerList(
+    EigendecomposedShampooPreconditionerList
+):
+    """Eigendecomposed KL-Shampoo preconditioners for list of parameters."""
+
+    @profile_decorator
+    def _compute_outer_product_list(
+        self,
+        grad: Tensor,
+        order: int,
+        preconditioned_dims_selector: tuple[bool, ...],
+        kronecker_factors: EigendecomposedShampooKroneckerFactorsUnwrapped,
+    ) -> tuple[Tensor, ...]:
+        # TODO: remove assertion when rank_deficient_stability_config is generalized to MatrixFunctionConfig
+        assert isinstance(
+            self._preconditioner_config.amortized_computation_config,
+            EigendecompositionConfig,
+        )
+        rank_deficient_stability_config = self._preconditioner_config.amortized_computation_config.rank_deficient_stability_config
+
+        # Construct outer product list for updating Kronecker factors.
+        outer_product_list = []
+        for idx_of_k, k in enumerate(
+            compress_list(range(order), preconditioned_dims_selector)
+        ):
+            # KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products.
+            local_preconditioned_dims_selector = list(preconditioned_dims_selector)
+            local_preconditioned_dims_selector[k] = False
+            preconditioned_grad = self._precondition_grad(
+                grad=grad,
+                preconditioned_dims_selector=tuple(local_preconditioned_dims_selector),
+                preconditioner_list=tuple(
+                    matrix_inverse_root_from_eigendecomposition(
+                        L=eigenvalues,
+                        Q=eigenvectors,
+                        root=Fraction(root),
+                        epsilon=self._epsilon,
+                        rank_deficient_stability_config=rank_deficient_stability_config,
+                    )
+                    for idx, (eigenvalues, eigenvectors, root) in enumerate(
+                        zip(
+                            kronecker_factors.factor_matrices_eigenvalues,
+                            kronecker_factors.factor_matrices_eigenvectors,
+                            kronecker_factors.roots,
+                            strict=True,
+                        )
+                    )
+                    if idx != idx_of_k
+                ),
+            )
+            outer_product = torch.tensordot(
+                preconditioned_grad,
+                preconditioned_grad,
+                # Contracts across all dimensions except for k.
+                dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type]
+            )
+            outer_product_list.append(outer_product)
+        return tuple(outer_product_list)
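
Editor's note on what the two `_compute_outer_product_list` overrides compute: standard Shampoo accumulates raw outer products of the gradient, while the KL variants first precondition the gradient along every contracted mode with the current inverse-root factors (materialized directly as `inv_factor_matrices` in the root-inverse class, and reconstructed from the stored eigendecomposition in the eigendecomposed class). As a sketch for the 2D case with a gradient $G$ and factors $(L, R)$, assuming the default inverse-root exponent $1/(2 \cdot \text{order}) = 1/4$ with no exponent override, and up to the $\beta_2$ weighting handled in `_update_factor_matrices`:

$$
\tilde{G}_L = G R^{-1/4} \;\Rightarrow\; \tilde{G}_L \tilde{G}_L^{\top} = G R^{-1/2} G^{\top},
\qquad
\tilde{G}_R = L^{-1/4} G \;\Rightarrow\; \tilde{G}_R^{\top} \tilde{G}_R = G^{\top} L^{-1/2} G,
$$

in place of Shampoo's $G G^{\top}$ and $G^{\top} G$, consistent with the KL-minimization view of reference [12] in the README.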