From 9e178276317ae9adc234a4f9ca113ab87ff705b8 Mon Sep 17 00:00:00 2001 From: Runa Mikael Eschenhagen Date: Mon, 15 Dec 2025 13:19:33 -0800 Subject: [PATCH] Sync KL-Shampoo to OSS (#255) Summary: Sync KL-Shampoo implementation to OSS. Also, add KL-Shampoo reference to README. **Original summary:** Adding a simplified version of [KL-Shampoo](https://arxiv.org/abs/2509.03378). The main difference to the algorithm presented in the paper is that here only the update of the factor matrices is adjusted, but the eigenvalue correction to the individual factor matrices is not implemented. This correction can also be used for regular Shampoo and will be added in the future. KL-Shampoo is implemented by storing either the inverse root matrices directly, or storing the eigendecompositions of the factor matrices, analogous to regular Shampoo. Reviewed By: hjmshi Differential Revision: D89193149 --- distributed_shampoo/README.md | 2 + distributed_shampoo/__init__.py | 4 + distributed_shampoo/distributed_shampoo.py | 8 + .../examples/argument_parser.py | 2 + distributed_shampoo/examples/trainer_utils.py | 14 ++ .../shampoo_preconditioner_list.py | 132 +++++++++++++++-- .../tests/shampoo_preconditioner_list_test.py | 139 ++++++++++++++++++ distributed_shampoo/shampoo_types.py | 98 ++++++++++++ .../tests/distributed_shampoo_test.py | 14 ++ 9 files changed, 403 insertions(+), 10 deletions(-) diff --git a/distributed_shampoo/README.md b/distributed_shampoo/README.md index 772706e..8a3799c 100644 --- a/distributed_shampoo/README.md +++ b/distributed_shampoo/README.md @@ -51,6 +51,7 @@ Key distinctives of this implementation include: - Option to (approximately) correct the eigenvalues/run Adam in the eigenbasis of Shampoo's preconditioner (SOAP) [2,6,7]. - Option to use an adaptive preconditioner update frequency when symmetric eigendecompositions or the QR algorithm is used [8]. - Spectral descent via reduced SVD or Newton-Schulz iteration for 2D gradients, or gradients that have been reshaped to 2D [9,10]. This can be used to implement Muon [11], see [Example 6](#example-6-muon). +- KL-Shampoo (without per-factor matrix eigenvalue correction) [12]. ## Requirements @@ -784,3 +785,4 @@ If you use PyTorch Distributed Shampoo in your work, please use the following Bi 9. [Preconditioned Spectral Descent for Deep Learning](https://papers.nips.cc/paper_files/paper/2015/hash/f50a6c02a3fc5a3a5d4d9391f05f3efc-Abstract.html). David E. Carlson, Edo Collins, Ya-Ping Hsieh, Lawrence Carin, Volkan Cevher. NeurIPS, 2015. 10. [Old Optimizer, New Norm: An Anthology](https://arxiv.org/abs/2409.20325). Jeremy Bernstein, Laker Newhouse. Tech report, 2024. 11. [Muon: An optimizer for hidden layers in neural networks](https://kellerjordan.github.io/posts/muon/). Keller Jordan, Yuchen Jin, Vlado Boza, Jiacheng You, Franz Cesista, Laker Newhouse, Jeremy Bernstein. Blog post, 2024. +12. [Understanding and Improving Shampoo and SOAP via Kullback-Leibler Minimization](https://arxiv.org/abs/2509.03378). Wu Lin, Scott C. Lowe, Felix Dangel, Runa Eschenhagen, Zikun Xu, Roger B. Grosse. Tech report, 2025. diff --git a/distributed_shampoo/__init__.py b/distributed_shampoo/__init__.py index 6c6bcb4..41988c8 100644 --- a/distributed_shampoo/__init__.py +++ b/distributed_shampoo/__init__.py @@ -43,6 +43,7 @@ DefaultSOAPConfig, DefaultSpectralDescentPreconditionerConfig, DistributedConfig, + EigendecomposedKLShampooPreconditionerConfig, EigendecomposedShampooPreconditionerConfig, EigenvalueCorrectedShampooPreconditionerConfig, FSDPDistributedConfig, @@ -52,6 +53,7 @@ HybridShardDistributedConfig, PreconditionerConfig, RMSpropPreconditionerConfig, + RootInvKLShampooPreconditionerConfig, RootInvShampooPreconditionerConfig, SGDPreconditionerConfig, ShampooPreconditionerConfig, @@ -83,7 +85,9 @@ "ShampooPreconditionerConfig", # Abstract base class (based on `AmortizedPreconditionerConfig`). "RootInvShampooPreconditionerConfig", # Based on `ShampooPreconditionerConfig`. "DefaultShampooConfig", # Default `RootInvShampooPreconditionerConfig` using `EigenConfig`. + "RootInvKLShampooPreconditionerConfig", # Based on `RootInvShampooPreconditionerConfig`. "EigendecomposedShampooPreconditionerConfig", # Based on `ShampooPreconditionerConfig`. + "EigendecomposedKLShampooPreconditionerConfig", # Based on `EigendecomposedShampooPreconditionerConfig`. "EigenvalueCorrectedShampooPreconditionerConfig", # Based on `AmortizedPreconditionerConfig`. "DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigendecompositionConfig`. "DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QREigendecompositionConfig`. diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 7e21d91..c31ec91 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -48,8 +48,10 @@ SGDPreconditionerList, ) from distributed_shampoo.preconditioner.shampoo_preconditioner_list import ( + EigendecomposedKLShampooPreconditionerList, EigendecomposedShampooPreconditionerList, EigenvalueCorrectedShampooPreconditionerList, + RootInvKLShampooPreconditionerList, RootInvShampooPreconditionerList, ) from distributed_shampoo.preconditioner.sign_descent_preconditioner_list import ( @@ -72,6 +74,7 @@ DISTRIBUTED_CONFIG, DistributedConfig, DISTRIBUTOR, + EigendecomposedKLShampooPreconditionerConfig, EigendecomposedShampooPreconditionerConfig, EigenvalueCorrectedShampooPreconditionerConfig, EPSILON, @@ -98,6 +101,7 @@ PreconditionerConfig, PREVIOUS_GRAD_SELECTOR, RMSpropPreconditionerConfig, + RootInvKLShampooPreconditionerConfig, RootInvShampooPreconditionerConfig, SGDPreconditionerConfig, SHAMPOO_PRECONDITIONER_LIST, @@ -638,6 +642,8 @@ def _preconditioner_config_to_list_cls( RootInvShampooPreconditionerConfig() | EigendecomposedShampooPreconditionerConfig() | EigenvalueCorrectedShampooPreconditionerConfig() + | RootInvKLShampooPreconditionerConfig() + | EigendecomposedKLShampooPreconditionerConfig() ): preconditioner_config_to_list_cls: dict[ type[PreconditionerConfig], Callable[..., PreconditionerList] @@ -645,6 +651,8 @@ def _preconditioner_config_to_list_cls( RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList, EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList, EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList, + RootInvKLShampooPreconditionerConfig: RootInvKLShampooPreconditionerList, + EigendecomposedKLShampooPreconditionerConfig: EigendecomposedKLShampooPreconditionerList, } return preconditioner_config_to_list_cls[type(preconditioner_config)]( block_list=state_lists[DISTRIBUTOR].local_blocked_params, diff --git a/distributed_shampoo/examples/argument_parser.py b/distributed_shampoo/examples/argument_parser.py index 2d90928..3b6b45b 100644 --- a/distributed_shampoo/examples/argument_parser.py +++ b/distributed_shampoo/examples/argument_parser.py @@ -34,6 +34,8 @@ class PreconditionerComputationType(enum.Enum): ADAM = enum.auto() EIGEN_ROOT_INV = enum.auto() EIGENDECOMPOSED_ROOT_INV = enum.auto() + KL_EIGEN_ROOT_INV = enum.auto() + KL_EIGENDECOMPOSED_ROOT_INV = enum.auto() COUPLED_NEWTON_ROOT_INV = enum.auto() COUPLED_HIGHER_ORDER_ROOT_INV = enum.auto() EIGH_EIGENVALUE_CORRECTION = enum.auto() diff --git a/distributed_shampoo/examples/trainer_utils.py b/distributed_shampoo/examples/trainer_utils.py index 3f0f5d4..defdbad 100644 --- a/distributed_shampoo/examples/trainer_utils.py +++ b/distributed_shampoo/examples/trainer_utils.py @@ -31,9 +31,11 @@ DefaultSpectralDescentPreconditionerConfig, DistributedConfig, DistributedShampoo, + EigendecomposedKLShampooPreconditionerConfig, EigendecomposedShampooPreconditionerConfig, PreconditionerConfig, RMSpropPreconditionerConfig, + RootInvKLShampooPreconditionerConfig, RootInvShampooPreconditionerConfig, SGDPreconditionerConfig, ) @@ -119,6 +121,18 @@ def instantiate_preconditioner_config( == PreconditionerComputationType.EIGENDECOMPOSED_ROOT_INV ): return EigendecomposedShampooPreconditionerConfig() + elif ( + preconditioner_computation_type + == PreconditionerComputationType.KL_EIGEN_ROOT_INV + ): + return RootInvKLShampooPreconditionerConfig( + amortized_computation_config=EigenConfig() + ) + elif ( + preconditioner_computation_type + == PreconditionerComputationType.KL_EIGENDECOMPOSED_ROOT_INV + ): + return EigendecomposedKLShampooPreconditionerConfig() elif ( preconditioner_computation_type == PreconditionerComputationType.COUPLED_NEWTON_ROOT_INV diff --git a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py index a5b6d97..6d1feab 100644 --- a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py +++ b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py @@ -1270,6 +1270,25 @@ def compress_preconditioner_list( self._local_preconditioned_dims_selector_list, local_grad_selector ) + @profile_decorator + def _compute_outer_product_list( + self, + grad: Tensor, + order: int, + preconditioned_dims_selector: tuple[bool, ...], + kronecker_factors: _ShampooKroneckerFactorsUnwrappedType, + ) -> tuple[Tensor, ...]: + # Construct outer product list for updating Kronecker factors. + return tuple( + torch.tensordot( + grad, + grad, + # Contracts across all dimensions except for k. + dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] + ) + for k in compress_list(range(order), preconditioned_dims_selector) + ) + @profile_decorator def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: # NOTE: Unlike AdagradPreconditionerList, we will loop through each gradient individually. @@ -1286,15 +1305,8 @@ def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: if not kronecker_factors.factor_matrices: continue - # Construct outer product list for updating Kronecker factors. - outer_product_list = tuple( - torch.tensordot( - grad, - grad, - # Contracts across all dimensions except for k. - dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] - ) - for k in compress_list(range(order), preconditioned_dims_selector) + outer_product_list = self._compute_outer_product_list( + grad, order, preconditioned_dims_selector, kronecker_factors ) if self._beta2 != 1.0: @@ -1673,7 +1685,6 @@ def update_preconditioners( kronecker_factors.corrected_eigenvalues.mul_(self._beta2) # NOTE: The case when self._weighting_factor == 1.0 is not well tested and might not be stable. - print(f"{self._weighting_factor=}") kronecker_factors.corrected_eigenvalues.addcmul_( grad, grad, value=self._weighting_factor ) @@ -1707,3 +1718,104 @@ def _compute_preconditioned_gradient( preconditioner_list=kronecker_factors.factor_matrices_eigenvectors, dims=([0], [1]), ) + + +class RootInvKLShampooPreconditionerList(RootInvShampooPreconditionerList): + """Root inverse KL-Shampoo preconditioners for list of parameters.""" + + @profile_decorator + def _compute_outer_product_list( + self, + grad: Tensor, + order: int, + preconditioned_dims_selector: tuple[bool, ...], + kronecker_factors: RootInvShampooKroneckerFactorsUnwrapped, + ) -> tuple[Tensor, ...]: + # Construct outer product list for updating Kronecker factors. + outer_product_list = [] + for idx_of_k, k in enumerate( + compress_list(range(order), preconditioned_dims_selector) + ): + # KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products. + local_preconditioned_dims_selector = list(preconditioned_dims_selector) + local_preconditioned_dims_selector[k] = False + preconditioned_grad = self._precondition_grad( + grad=grad, + preconditioned_dims_selector=tuple(local_preconditioned_dims_selector), + preconditioner_list=tuple( + inv_factor_matrix + for idx, inv_factor_matrix in enumerate( + kronecker_factors.inv_factor_matrices + ) + if idx != idx_of_k + ), + ) + outer_product_list.append( + torch.tensordot( + preconditioned_grad, + preconditioned_grad, + # Contracts across all dimensions except for k. + dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] + ) + ) + return tuple(outer_product_list) + + +class EigendecomposedKLShampooPreconditionerList( + EigendecomposedShampooPreconditionerList +): + """Eigendecomposed KL-Shampoo preconditioners for list of parameters.""" + + @profile_decorator + def _compute_outer_product_list( + self, + grad: Tensor, + order: int, + preconditioned_dims_selector: tuple[bool, ...], + kronecker_factors: EigendecomposedShampooKroneckerFactorsUnwrapped, + ) -> tuple[Tensor, ...]: + # TODO: remove assertion when rank_deficient_stability_config is generalized to MatrixFunctionConfig + assert isinstance( + self._preconditioner_config.amortized_computation_config, + EigendecompositionConfig, + ) + rank_deficient_stability_config = self._preconditioner_config.amortized_computation_config.rank_deficient_stability_config + + # Construct outer product list for updating Kronecker factors. + outer_product_list = [] + for idx_of_k, k in enumerate( + compress_list(range(order), preconditioned_dims_selector) + ): + # KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products. + local_preconditioned_dims_selector = list(preconditioned_dims_selector) + local_preconditioned_dims_selector[k] = False + preconditioned_grad = self._precondition_grad( + grad=grad, + preconditioned_dims_selector=tuple(local_preconditioned_dims_selector), + preconditioner_list=tuple( + matrix_inverse_root_from_eigendecomposition( + L=eigenvalues, + Q=eigenvectors, + root=Fraction(root), + epsilon=self._epsilon, + rank_deficient_stability_config=rank_deficient_stability_config, + ) + for idx, (eigenvalues, eigenvectors, root) in enumerate( + zip( + kronecker_factors.factor_matrices_eigenvalues, + kronecker_factors.factor_matrices_eigenvectors, + kronecker_factors.roots, + strict=True, + ) + ) + if idx != idx_of_k + ), + ) + outer_product = torch.tensordot( + preconditioned_grad, + preconditioned_grad, + # Contracts across all dimensions except for k. + dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] + ) + outer_product_list.append(outer_product) + return tuple(outer_product_list) diff --git a/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py index bf3e14b..3928477 100644 --- a/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py @@ -29,8 +29,10 @@ from distributed_shampoo.preconditioner.preconditioner_list import PreconditionerList from distributed_shampoo.preconditioner.shampoo_preconditioner_list import ( BaseShampooPreconditionerList, + EigendecomposedKLShampooPreconditionerList, EigendecomposedShampooPreconditionerList, EigenvalueCorrectedShampooPreconditionerList, + RootInvKLShampooPreconditionerList, RootInvShampooPreconditionerList, ) from distributed_shampoo.preconditioner.tests.preconditioner_list_test_utils import ( @@ -40,9 +42,11 @@ AmortizedPreconditionerConfig, DefaultShampooConfig, DefaultSOAPConfig, + EigendecomposedKLShampooPreconditionerConfig, EigendecomposedShampooPreconditionerConfig, EigenvalueCorrectedShampooPreconditionerConfig, PreconditionerValueError, + RootInvKLShampooPreconditionerConfig, RootInvShampooPreconditionerConfig, ShampooPreconditionerConfig, ) @@ -1342,3 +1346,138 @@ def test_inverse_exponent_override(self) -> None: masked_grad_lists=[masked_grad_list1, masked_grad_list2], masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, ) + + +class RootInvKLShampooPreconditionerListTest(RootInvShampooPreconditionerListTest): + @property + def _default_preconditioner_config(self) -> RootInvShampooPreconditionerConfig: + return replace( + RootInvKLShampooPreconditionerConfig(), + factor_matrix_dtype=torch.float64, + ) + + @property + def _preconditioner_list_factory(self) -> Callable[..., PreconditionerList]: + return RootInvKLShampooPreconditionerList + + @unittest.skip( + "RootInvKLShampooPreconditionerList does not support adaptive computation frequency." + ) + def test_adaptive_amortized_computation_frequency(self) -> None: ... + + +class EigendecomposedKLShampooPreconditionerListTest( + EigendecomposedShampooPreconditionerListTest +): + @property + def _default_preconditioner_config( # type: ignore[override] + self, + ) -> EigendecomposedKLShampooPreconditionerConfig: + return EigendecomposedKLShampooPreconditionerConfig( + amortized_computation_config=QREigendecompositionConfig(), + factor_matrix_dtype=torch.float64, + factor_matrix_eigenvectors_dtype=torch.float64, + factor_matrix_eigenvalues_dtype=torch.float64, + ) + + @property + def _preconditioner_list_factory(self) -> Callable[..., PreconditionerList]: + return EigendecomposedKLShampooPreconditionerList + + def test_update_preconditioners_and_precondition_with_epsilon(self) -> None: + """ + We provide examples where we deliberately choose a large epsilon. This is to ensure that + the matrix inverse computation behaves as expected. Below we update the preconditioners twice + for 4 different blocks and check if the preconditioned gradients are as expected. When + performing the inverse computation, epsilon is chosen to be 80 in (L + epsilon * I) and + (R + epsilon * I). + + G_{ij}: at the i-th step, the gradient for the j-th block. For example, + G12 is the gradient for the second block, at the first step. + + For KL-Shampoo, we will correct the scale of the 2D gradients such that the Kronecker factors + match the ones of regular Shampoo with a scale correction s. + + epsilon = 80.0 + s = epsilon^{1/4} + + Gradients for block 1: (no right preconditioner) + (1) 1D Tensor of Size 2 + G11 = [1, 0]^T + G21 = [0, 1]^T + + L = G11 * G11^T + G21 * G21^T = [[1, 0], [0, 1]] + P = (L + epsilon * I)^{-1/2} G21 = [[81, 0], [0, 81]]^{-1/2} [0, 1]^T + = [0, 1/9]^T. + + Gradients for block 2: (both left and right preconditioner) + (2) Tensor of Size 2 x 2 + G12 = s * [[1, 0], [0, 1]] / sqrt(2) + G22 = s * [[1, 0], [0, 1]] / sqrt(2) + + L = G12 * G12^T + G22 * G22^T = [[1, 0], [0, 1]] + R = G12^T * G12 + G22^T * G22 = [[1, 0], [0, 1]] + P = (L + epsilon * I)^{-1/4} G22 (R + epsilon * I)^{-1/4} + = [[1/3, 0], [0, 1/3]] * G22 * [[1/3, 0], [0, 1/3]] = s * I / (9 * sqrt(2)) + + Gradients for block 3: (both left and right preconditioner) + (3) Tensor of Size 1 x 2 + G13 = s * [[1, 0]] + G23 = s * [[0, 1]] + + L = G13 * G13^T + G23 * G23^T = I + R = G13^T * G13 + G23^T * G23 = 2 + P = (L + epsilon * I)^{-1/4} G22 (R + epsilon * I)^{-1/4} + = [[1/3, 0], [0, 1/3]] * G22 * (80 + 2)^{-1/4} = + = s * [[0.0, 1.0/3.0 * 82.0 ** (-1/4)]] + + Gradients for block 4: (no preconditioner) + (4) Tensor of Size 0 + G14 = 1 + G24 = 1 + + No preconditioner is applied. Expected gradient is 1. + + """ + + epsilon = 80.0 + scale_correction = epsilon**0.25 + + # Blocked gradients at the first step: masked_grad_list1 = (G11, G12, G13, G14) + masked_grad_list1 = ( + torch.tensor([1.0, 0.0]), + scale_correction * torch.eye(2) / math.sqrt(2), + scale_correction * torch.tensor([[1.0, 0.0]]), + torch.tensor(1.0), + ) + + # Blocked gradients at the second step: masked_grad_list2 = (G21, G22, G23, G24) + masked_grad_list2 = ( + torch.tensor([0.0, 1.0]), + scale_correction * torch.eye(2) / math.sqrt(2), + scale_correction * torch.tensor([[0.0, 1.0]]), + torch.tensor(1.0), + ) + + # Manually apply the preconditioners to the gradients at the second step (masked_grad_list2) with epsilon. + # The result is stored in masked_expected_preconditioned_grad_list. + masked_expected_preconditioned_grad_list = ( + torch.tensor([0.0, 1.0 / 9.0]), + scale_correction * torch.eye(2) / (9 * math.sqrt(2)), + scale_correction * torch.tensor([[0.0, 1.0 / 3.0 * 82.0 ** (-1 / 4)]]), + torch.tensor(1.0), + ) + + # Apply preconditioner to the last step (masked_grad_list2) with epsilon. The result should be the same as the expected preconditioned grad list. + self._verify_preconditioner_updates( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=1.0, + weighting_factor=1.0, + use_bias_correction=True, + epsilon=epsilon, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=tuple( + masked_expected_preconditioned_grad_list + ), + ) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 20bdfe2..6c1786b 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -513,6 +513,104 @@ def __post_init__(self) -> None: ) +@dataclass(kw_only=True) +class RootInvKLShampooPreconditionerConfig(RootInvShampooPreconditionerConfig): + """Configuration for KL-Shampoo preconditioner computation with caching of the root inverse factor matrices. + + Attributes: + amortized_computation_config (RootInvConfig): Configuration for the inverse-root computation. (Default: DefaultEigenConfig) + num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3) + factor_matrix_dtype (torch.dtype): Data type for factor matrix. (Default: torch.float32) + inverse_exponent_override (dict[int, dict[int, float] | float]): The inverse_exponent_override attribute is a dictionary that allows for customizing the inverse exponent used in the KL-Shampoo preconditioner computation. + The keys of the dictionary represent the order of the tensor, and the values are either dictionaries with dimension indices as keys and override values as values, or a single float value for all dimensions. All unspecified dimensions use a default exponent of 1/(2*max(o,1)), where o is the order of the tensor. (Default: {}) + + As an example, suppose inverse_exponent_override={2: 0.2, 3: {0: 0.0, 1: 0.25}}. In this case, all 1-D tensors will use the default exponent of 0.5 for preconditioning the first (and only) dimension. All 2-D tensors will be preconditioned with an exponent of 0.2 on all dimensions. All 3-D tensors will have the first dimension be preconditioned with an exponent of 0.5, the second dimension not preconditioned, and the third dimension preconditioned with the default exponent 0.1667. + A visualization of this example can be seen below: + 1-D: + +-------x-------+ + | + | + (^0.5), the default inverse exponent 1/(2*1) since inverse_exponent_override[1] is not specified + 2-D: + +-----------+ + | | + | | + | |-----(^0.2), as specified by inverse_exponent_override[2]=0.2 + | | + | | + +-----------+ + | + | + (^0.2), as specified by inverse_exponent_override[2]=0.2 + 3-D: + +---------------+ + / /| + / / | + +---------------+ | + | | | + | | -|---(^0.25), as specified by inverse_exponent_override[3][1]=0.25 + | | + + | | / + | |/\ + +---------------+ \ + | (^0.1667), the default inverse exponent 1/(2*3) since inverse_exponent_override[3][2] is not specified + | + no preconditioning since inverse_exponent_override[3][0]=0.0 + + + """ + + +@dataclass(kw_only=True) +class EigendecomposedKLShampooPreconditionerConfig( + EigendecomposedShampooPreconditionerConfig +): + """Configuration for KL-Shampoo preconditioner computation with caching of the eigendecomposed factor matrices. + + Attributes: + amortized_computation_config (EigendecompositionConfig): Configuration for the eigendecomposition computation. (Default: DefaultEigendecompositionConfig) + num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3) + factor_matrix_dtype (torch.dtype): Data type for factor matrix. (Default: torch.float32) + inverse_exponent_override (dict[int, dict[int, float] | float]): The inverse_exponent_override attribute is a dictionary that allows for customizing the inverse exponent used in the KL-Shampoo preconditioner computation. + The keys of the dictionary represent the order of the tensor, and the values are either dictionaries with dimension indices as keys and override values as values, or a single float value for all dimensions. All unspecified dimensions use a default exponent of 1/(2*max(o,1)), where o is the order of the tensor. (Default: {}) + + As an example, suppose inverse_exponent_override={2: 0.2, 3: {0: 0.0, 1: 0.25}}. In this case, all 1-D tensors will use the default exponent of 0.5 for preconditioning the first (and only) dimension. All 2-D tensors will be preconditioned with an exponent of 0.2 on all dimensions. All 3-D tensors will have the first dimension be preconditioned with an exponent of 0.5, the second dimension not preconditioned, and the third dimension preconditioned with the default exponent 0.1667. + A visualization of this example can be seen below: + 1-D: + +-------x-------+ + | + | + (^0.5), the default inverse exponent 1/(2*1) since inverse_exponent_override[1] is not specified + 2-D: + +-----------+ + | | + | | + | |-----(^0.2), as specified by inverse_exponent_override[2]=0.2 + | | + | | + +-----------+ + | + | + (^0.2), as specified by inverse_exponent_override[2]=0.2 + 3-D: + +---------------+ + / /| + / / | + +---------------+ | + | | | + | | -|---(^0.25), as specified by inverse_exponent_override[3][1]=0.25 + | | + + | | / + | |/\ + +---------------+ \ + | (^0.1667), the default inverse exponent 1/(2*3) since inverse_exponent_override[3][2] is not specified + | + no preconditioning since inverse_exponent_override[3][0]=0.0 + + + """ + + @dataclass(kw_only=True) class SpectralDescentPreconditionerConfig(PreconditionerConfig): """Configuration for spectral descent computation in DistributedShampoo. diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index dca5d0a..56be68f 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -34,9 +34,11 @@ DefaultSingleDeviceDistributedConfig, DefaultSpectralDescentPreconditionerConfig, DistributedConfig, + EigendecomposedKLShampooPreconditionerConfig, EigendecomposedShampooPreconditionerConfig, EigenvalueCorrectedShampooPreconditionerConfig, PreconditionerConfig, + RootInvKLShampooPreconditionerConfig, RootInvShampooPreconditionerConfig, ShampooPT2CompileConfig, SignDescentPreconditionerConfig, @@ -1096,6 +1098,18 @@ def _ref_state_dict(self) -> dict[str, Any]: } +class RootInvKLShampooStateDictTest(ShampooStateDictTest): + @property + def _preconditioner_config(self) -> RootInvKLShampooPreconditionerConfig: + return RootInvKLShampooPreconditionerConfig() + + +class EigendecomposedKLShampooStateDictTest(EigendecomposedShampooStateDictTest): + @property + def _preconditioner_config(self) -> EigendecomposedKLShampooPreconditionerConfig: + return EigendecomposedKLShampooPreconditionerConfig() + + class SignDescentStateDictTest(AbstractTest.NoPreconditionerStateDictTestBase): @property def _preconditioner_config(self) -> SignDescentPreconditionerConfig: