diff --git a/distributed_shampoo/README.md b/distributed_shampoo/README.md index 772706e..8a3799c 100644 --- a/distributed_shampoo/README.md +++ b/distributed_shampoo/README.md @@ -51,6 +51,7 @@ Key distinctives of this implementation include: - Option to (approximately) correct the eigenvalues/run Adam in the eigenbasis of Shampoo's preconditioner (SOAP) [2,6,7]. - Option to use an adaptive preconditioner update frequency when symmetric eigendecompositions or the QR algorithm is used [8]. - Spectral descent via reduced SVD or Newton-Schulz iteration for 2D gradients, or gradients that have been reshaped to 2D [9,10]. This can be used to implement Muon [11], see [Example 6](#example-6-muon). +- KL-Shampoo (without per-factor matrix eigenvalue correction) [12]. ## Requirements @@ -784,3 +785,4 @@ If you use PyTorch Distributed Shampoo in your work, please use the following Bi 9. [Preconditioned Spectral Descent for Deep Learning](https://papers.nips.cc/paper_files/paper/2015/hash/f50a6c02a3fc5a3a5d4d9391f05f3efc-Abstract.html). David E. Carlson, Edo Collins, Ya-Ping Hsieh, Lawrence Carin, Volkan Cevher. NeurIPS, 2015. 10. [Old Optimizer, New Norm: An Anthology](https://arxiv.org/abs/2409.20325). Jeremy Bernstein, Laker Newhouse. Tech report, 2024. 11. [Muon: An optimizer for hidden layers in neural networks](https://kellerjordan.github.io/posts/muon/). Keller Jordan, Yuchen Jin, Vlado Boza, Jiacheng You, Franz Cesista, Laker Newhouse, Jeremy Bernstein. Blog post, 2024. +12. [Understanding and Improving Shampoo and SOAP via Kullback-Leibler Minimization](https://arxiv.org/abs/2509.03378). Wu Lin, Scott C. Lowe, Felix Dangel, Runa Eschenhagen, Zikun Xu, Roger B. Grosse. Tech report, 2025. diff --git a/distributed_shampoo/__init__.py b/distributed_shampoo/__init__.py index 6c6bcb4..41988c8 100644 --- a/distributed_shampoo/__init__.py +++ b/distributed_shampoo/__init__.py @@ -43,6 +43,7 @@ DefaultSOAPConfig, DefaultSpectralDescentPreconditionerConfig, DistributedConfig, + EigendecomposedKLShampooPreconditionerConfig, EigendecomposedShampooPreconditionerConfig, EigenvalueCorrectedShampooPreconditionerConfig, FSDPDistributedConfig, @@ -52,6 +53,7 @@ HybridShardDistributedConfig, PreconditionerConfig, RMSpropPreconditionerConfig, + RootInvKLShampooPreconditionerConfig, RootInvShampooPreconditionerConfig, SGDPreconditionerConfig, ShampooPreconditionerConfig, @@ -83,7 +85,9 @@ "ShampooPreconditionerConfig", # Abstract base class (based on `AmortizedPreconditionerConfig`). "RootInvShampooPreconditionerConfig", # Based on `ShampooPreconditionerConfig`. "DefaultShampooConfig", # Default `RootInvShampooPreconditionerConfig` using `EigenConfig`. + "RootInvKLShampooPreconditionerConfig", # Based on `RootInvShampooPreconditionerConfig`. "EigendecomposedShampooPreconditionerConfig", # Based on `ShampooPreconditionerConfig`. + "EigendecomposedKLShampooPreconditionerConfig", # Based on `EigendecomposedShampooPreconditionerConfig`. "EigenvalueCorrectedShampooPreconditionerConfig", # Based on `AmortizedPreconditionerConfig`. "DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigendecompositionConfig`. "DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QREigendecompositionConfig`. 
diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 7e21d91..c31ec91 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -48,8 +48,10 @@ SGDPreconditionerList, ) from distributed_shampoo.preconditioner.shampoo_preconditioner_list import ( + EigendecomposedKLShampooPreconditionerList, EigendecomposedShampooPreconditionerList, EigenvalueCorrectedShampooPreconditionerList, + RootInvKLShampooPreconditionerList, RootInvShampooPreconditionerList, ) from distributed_shampoo.preconditioner.sign_descent_preconditioner_list import ( @@ -72,6 +74,7 @@ DISTRIBUTED_CONFIG, DistributedConfig, DISTRIBUTOR, + EigendecomposedKLShampooPreconditionerConfig, EigendecomposedShampooPreconditionerConfig, EigenvalueCorrectedShampooPreconditionerConfig, EPSILON, @@ -98,6 +101,7 @@ PreconditionerConfig, PREVIOUS_GRAD_SELECTOR, RMSpropPreconditionerConfig, + RootInvKLShampooPreconditionerConfig, RootInvShampooPreconditionerConfig, SGDPreconditionerConfig, SHAMPOO_PRECONDITIONER_LIST, @@ -638,6 +642,8 @@ def _preconditioner_config_to_list_cls( RootInvShampooPreconditionerConfig() | EigendecomposedShampooPreconditionerConfig() | EigenvalueCorrectedShampooPreconditionerConfig() + | RootInvKLShampooPreconditionerConfig() + | EigendecomposedKLShampooPreconditionerConfig() ): preconditioner_config_to_list_cls: dict[ type[PreconditionerConfig], Callable[..., PreconditionerList] @@ -645,6 +651,8 @@ def _preconditioner_config_to_list_cls( RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList, EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList, EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList, + RootInvKLShampooPreconditionerConfig: RootInvKLShampooPreconditionerList, + EigendecomposedKLShampooPreconditionerConfig: EigendecomposedKLShampooPreconditionerList, } return preconditioner_config_to_list_cls[type(preconditioner_config)]( block_list=state_lists[DISTRIBUTOR].local_blocked_params, diff --git a/distributed_shampoo/examples/argument_parser.py b/distributed_shampoo/examples/argument_parser.py index 2d90928..3b6b45b 100644 --- a/distributed_shampoo/examples/argument_parser.py +++ b/distributed_shampoo/examples/argument_parser.py @@ -34,6 +34,8 @@ class PreconditionerComputationType(enum.Enum): ADAM = enum.auto() EIGEN_ROOT_INV = enum.auto() EIGENDECOMPOSED_ROOT_INV = enum.auto() + KL_EIGEN_ROOT_INV = enum.auto() + KL_EIGENDECOMPOSED_ROOT_INV = enum.auto() COUPLED_NEWTON_ROOT_INV = enum.auto() COUPLED_HIGHER_ORDER_ROOT_INV = enum.auto() EIGH_EIGENVALUE_CORRECTION = enum.auto() diff --git a/distributed_shampoo/examples/trainer_utils.py b/distributed_shampoo/examples/trainer_utils.py index 3f0f5d4..defdbad 100644 --- a/distributed_shampoo/examples/trainer_utils.py +++ b/distributed_shampoo/examples/trainer_utils.py @@ -31,9 +31,11 @@ DefaultSpectralDescentPreconditionerConfig, DistributedConfig, DistributedShampoo, + EigendecomposedKLShampooPreconditionerConfig, EigendecomposedShampooPreconditionerConfig, PreconditionerConfig, RMSpropPreconditionerConfig, + RootInvKLShampooPreconditionerConfig, RootInvShampooPreconditionerConfig, SGDPreconditionerConfig, ) @@ -119,6 +121,18 @@ def instantiate_preconditioner_config( == PreconditionerComputationType.EIGENDECOMPOSED_ROOT_INV ): return EigendecomposedShampooPreconditionerConfig() + elif ( + preconditioner_computation_type + == PreconditionerComputationType.KL_EIGEN_ROOT_INV + ): + 
return RootInvKLShampooPreconditionerConfig( + amortized_computation_config=EigenConfig() + ) + elif ( + preconditioner_computation_type + == PreconditionerComputationType.KL_EIGENDECOMPOSED_ROOT_INV + ): + return EigendecomposedKLShampooPreconditionerConfig() elif ( preconditioner_computation_type == PreconditionerComputationType.COUPLED_NEWTON_ROOT_INV diff --git a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py index a5b6d97..6d1feab 100644 --- a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py +++ b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py @@ -1270,6 +1270,25 @@ def compress_preconditioner_list( self._local_preconditioned_dims_selector_list, local_grad_selector ) + @profile_decorator + def _compute_outer_product_list( + self, + grad: Tensor, + order: int, + preconditioned_dims_selector: tuple[bool, ...], + kronecker_factors: _ShampooKroneckerFactorsUnwrappedType, + ) -> tuple[Tensor, ...]: + # Construct outer product list for updating Kronecker factors. + return tuple( + torch.tensordot( + grad, + grad, + # Contracts across all dimensions except for k. + dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] + ) + for k in compress_list(range(order), preconditioned_dims_selector) + ) + @profile_decorator def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: # NOTE: Unlike AdagradPreconditionerList, we will loop through each gradient individually. @@ -1286,15 +1305,8 @@ def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: if not kronecker_factors.factor_matrices: continue - # Construct outer product list for updating Kronecker factors. - outer_product_list = tuple( - torch.tensordot( - grad, - grad, - # Contracts across all dimensions except for k. - dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] - ) - for k in compress_list(range(order), preconditioned_dims_selector) + outer_product_list = self._compute_outer_product_list( + grad, order, preconditioned_dims_selector, kronecker_factors ) if self._beta2 != 1.0: @@ -1673,7 +1685,6 @@ def update_preconditioners( kronecker_factors.corrected_eigenvalues.mul_(self._beta2) # NOTE: The case when self._weighting_factor == 1.0 is not well tested and might not be stable. - print(f"{self._weighting_factor=}") kronecker_factors.corrected_eigenvalues.addcmul_( grad, grad, value=self._weighting_factor ) @@ -1707,3 +1718,104 @@ def _compute_preconditioned_gradient( preconditioner_list=kronecker_factors.factor_matrices_eigenvectors, dims=([0], [1]), ) + + +class RootInvKLShampooPreconditionerList(RootInvShampooPreconditionerList): + """Root inverse KL-Shampoo preconditioners for list of parameters.""" + + @profile_decorator + def _compute_outer_product_list( + self, + grad: Tensor, + order: int, + preconditioned_dims_selector: tuple[bool, ...], + kronecker_factors: RootInvShampooKroneckerFactorsUnwrapped, + ) -> tuple[Tensor, ...]: + # Construct outer product list for updating Kronecker factors. + outer_product_list = [] + for idx_of_k, k in enumerate( + compress_list(range(order), preconditioned_dims_selector) + ): + # KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products. 
+ local_preconditioned_dims_selector = list(preconditioned_dims_selector) + local_preconditioned_dims_selector[k] = False + preconditioned_grad = self._precondition_grad( + grad=grad, + preconditioned_dims_selector=tuple(local_preconditioned_dims_selector), + preconditioner_list=tuple( + inv_factor_matrix + for idx, inv_factor_matrix in enumerate( + kronecker_factors.inv_factor_matrices + ) + if idx != idx_of_k + ), + ) + outer_product_list.append( + torch.tensordot( + preconditioned_grad, + preconditioned_grad, + # Contracts across all dimensions except for k. + dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] + ) + ) + return tuple(outer_product_list) + + +class EigendecomposedKLShampooPreconditionerList( + EigendecomposedShampooPreconditionerList +): + """Eigendecomposed KL-Shampoo preconditioners for list of parameters.""" + + @profile_decorator + def _compute_outer_product_list( + self, + grad: Tensor, + order: int, + preconditioned_dims_selector: tuple[bool, ...], + kronecker_factors: EigendecomposedShampooKroneckerFactorsUnwrapped, + ) -> tuple[Tensor, ...]: + # TODO: remove assertion when rank_deficient_stability_config is generalized to MatrixFunctionConfig + assert isinstance( + self._preconditioner_config.amortized_computation_config, + EigendecompositionConfig, + ) + rank_deficient_stability_config = self._preconditioner_config.amortized_computation_config.rank_deficient_stability_config + + # Construct outer product list for updating Kronecker factors. + outer_product_list = [] + for idx_of_k, k in enumerate( + compress_list(range(order), preconditioned_dims_selector) + ): + # KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products. + local_preconditioned_dims_selector = list(preconditioned_dims_selector) + local_preconditioned_dims_selector[k] = False + preconditioned_grad = self._precondition_grad( + grad=grad, + preconditioned_dims_selector=tuple(local_preconditioned_dims_selector), + preconditioner_list=tuple( + matrix_inverse_root_from_eigendecomposition( + L=eigenvalues, + Q=eigenvectors, + root=Fraction(root), + epsilon=self._epsilon, + rank_deficient_stability_config=rank_deficient_stability_config, + ) + for idx, (eigenvalues, eigenvectors, root) in enumerate( + zip( + kronecker_factors.factor_matrices_eigenvalues, + kronecker_factors.factor_matrices_eigenvectors, + kronecker_factors.roots, + strict=True, + ) + ) + if idx != idx_of_k + ), + ) + outer_product = torch.tensordot( + preconditioned_grad, + preconditioned_grad, + # Contracts across all dimensions except for k. 
+ dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] + ) + outer_product_list.append(outer_product) + return tuple(outer_product_list) diff --git a/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py index bf3e14b..3928477 100644 --- a/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py @@ -29,8 +29,10 @@ from distributed_shampoo.preconditioner.preconditioner_list import PreconditionerList from distributed_shampoo.preconditioner.shampoo_preconditioner_list import ( BaseShampooPreconditionerList, + EigendecomposedKLShampooPreconditionerList, EigendecomposedShampooPreconditionerList, EigenvalueCorrectedShampooPreconditionerList, + RootInvKLShampooPreconditionerList, RootInvShampooPreconditionerList, ) from distributed_shampoo.preconditioner.tests.preconditioner_list_test_utils import ( @@ -40,9 +42,11 @@ AmortizedPreconditionerConfig, DefaultShampooConfig, DefaultSOAPConfig, + EigendecomposedKLShampooPreconditionerConfig, EigendecomposedShampooPreconditionerConfig, EigenvalueCorrectedShampooPreconditionerConfig, PreconditionerValueError, + RootInvKLShampooPreconditionerConfig, RootInvShampooPreconditionerConfig, ShampooPreconditionerConfig, ) @@ -1342,3 +1346,138 @@ def test_inverse_exponent_override(self) -> None: masked_grad_lists=[masked_grad_list1, masked_grad_list2], masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, ) + + +class RootInvKLShampooPreconditionerListTest(RootInvShampooPreconditionerListTest): + @property + def _default_preconditioner_config(self) -> RootInvShampooPreconditionerConfig: + return replace( + RootInvKLShampooPreconditionerConfig(), + factor_matrix_dtype=torch.float64, + ) + + @property + def _preconditioner_list_factory(self) -> Callable[..., PreconditionerList]: + return RootInvKLShampooPreconditionerList + + @unittest.skip( + "RootInvKLShampooPreconditionerList does not support adaptive computation frequency." + ) + def test_adaptive_amortized_computation_frequency(self) -> None: ... + + +class EigendecomposedKLShampooPreconditionerListTest( + EigendecomposedShampooPreconditionerListTest +): + @property + def _default_preconditioner_config( # type: ignore[override] + self, + ) -> EigendecomposedKLShampooPreconditionerConfig: + return EigendecomposedKLShampooPreconditionerConfig( + amortized_computation_config=QREigendecompositionConfig(), + factor_matrix_dtype=torch.float64, + factor_matrix_eigenvectors_dtype=torch.float64, + factor_matrix_eigenvalues_dtype=torch.float64, + ) + + @property + def _preconditioner_list_factory(self) -> Callable[..., PreconditionerList]: + return EigendecomposedKLShampooPreconditionerList + + def test_update_preconditioners_and_precondition_with_epsilon(self) -> None: + """ + We provide examples where we deliberately choose a large epsilon. This is to ensure that + the matrix inverse computation behaves as expected. Below we update the preconditioners twice + for 4 different blocks and check if the preconditioned gradients are as expected. When + performing the inverse computation, epsilon is chosen to be 80 in (L + epsilon * I) and + (R + epsilon * I). + + G_{ij}: at the i-th step, the gradient for the j-th block. For example, + G12 is the gradient for the second block, at the first step. 
+ + For KL-Shampoo, we will correct the scale of the 2D gradients such that the Kronecker factors + match the ones of regular Shampoo with a scale correction s. + + epsilon = 80.0 + s = epsilon^{1/4} + + Gradients for block 1: (no right preconditioner) + (1) 1D Tensor of Size 2 + G11 = [1, 0]^T + G21 = [0, 1]^T + + L = G11 * G11^T + G21 * G21^T = [[1, 0], [0, 1]] + P = (L + epsilon * I)^{-1/2} G21 = [[81, 0], [0, 81]]^{-1/2} [0, 1]^T + = [0, 1/9]^T. + + Gradients for block 2: (both left and right preconditioner) + (2) Tensor of Size 2 x 2 + G12 = s * [[1, 0], [0, 1]] / sqrt(2) + G22 = s * [[1, 0], [0, 1]] / sqrt(2) + + L = G12 * G12^T + G22 * G22^T = [[1, 0], [0, 1]] + R = G12^T * G12 + G22^T * G22 = [[1, 0], [0, 1]] + P = (L + epsilon * I)^{-1/4} G22 (R + epsilon * I)^{-1/4} + = [[1/3, 0], [0, 1/3]] * G22 * [[1/3, 0], [0, 1/3]] = s * I / (9 * sqrt(2)) + + Gradients for block 3: (both left and right preconditioner) + (3) Tensor of Size 1 x 2 + G13 = s * [[1, 0]] + G23 = s * [[0, 1]] + + L = G13 * G13^T + G23 * G23^T = 2 + R = G13^T * G13 + G23^T * G23 = [[1, 0], [0, 1]] + P = (L + epsilon * I)^{-1/4} G23 (R + epsilon * I)^{-1/4} + = (80 + 2)^{-1/4} * G23 * [[1/3, 0], [0, 1/3]] + = s * [[0.0, 1.0/3.0 * 82.0 ** (-1/4)]] + + Gradients for block 4: (no preconditioner) + (4) Tensor of Size 0 + G14 = 1 + G24 = 1 + + No preconditioner is applied. Expected gradient is 1. + + """ + + epsilon = 80.0 + scale_correction = epsilon**0.25 + + # Blocked gradients at the first step: masked_grad_list1 = (G11, G12, G13, G14) + masked_grad_list1 = ( + torch.tensor([1.0, 0.0]), + scale_correction * torch.eye(2) / math.sqrt(2), + scale_correction * torch.tensor([[1.0, 0.0]]), + torch.tensor(1.0), + ) + + # Blocked gradients at the second step: masked_grad_list2 = (G21, G22, G23, G24) + masked_grad_list2 = ( + torch.tensor([0.0, 1.0]), + scale_correction * torch.eye(2) / math.sqrt(2), + scale_correction * torch.tensor([[0.0, 1.0]]), + torch.tensor(1.0), + ) + + # Manually apply the preconditioners to the gradients at the second step (masked_grad_list2) with epsilon. + # The result is stored in masked_expected_preconditioned_grad_list. + masked_expected_preconditioned_grad_list = ( + torch.tensor([0.0, 1.0 / 9.0]), + scale_correction * torch.eye(2) / (9 * math.sqrt(2)), + scale_correction * torch.tensor([[0.0, 1.0 / 3.0 * 82.0 ** (-1 / 4)]]), + torch.tensor(1.0), + ) + + # Apply preconditioner to the last step (masked_grad_list2) with epsilon. The result should be the same as the expected preconditioned grad list. + self._verify_preconditioner_updates( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=1.0, + weighting_factor=1.0, + use_bias_correction=True, + epsilon=epsilon, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=tuple( + masked_expected_preconditioned_grad_list + ), + ) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 20bdfe2..6c1786b 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -513,6 +513,104 @@ def __post_init__(self) -> None: ) + +@dataclass(kw_only=True) +class RootInvKLShampooPreconditionerConfig(RootInvShampooPreconditionerConfig): + """Configuration for KL-Shampoo preconditioner computation with caching of the root inverse factor matrices. + + Attributes: + amortized_computation_config (RootInvConfig): Configuration for the inverse-root computation.
(Default: DefaultEigenConfig) + num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3) + factor_matrix_dtype (torch.dtype): Data type for factor matrix. (Default: torch.float32) + inverse_exponent_override (dict[int, dict[int, float] | float]): The inverse_exponent_override attribute is a dictionary that allows for customizing the inverse exponent used in the KL-Shampoo preconditioner computation. + The keys of the dictionary represent the order of the tensor, and the values are either dictionaries with dimension indices as keys and override values as values, or a single float value for all dimensions. All unspecified dimensions use a default exponent of 1/(2*max(o,1)), where o is the order of the tensor. (Default: {}) + + As an example, suppose inverse_exponent_override={2: 0.2, 3: {0: 0.0, 1: 0.25}}. In this case, all 1-D tensors will use the default exponent of 0.5 for preconditioning the first (and only) dimension. All 2-D tensors will be preconditioned with an exponent of 0.2 on all dimensions. All 3-D tensors will have the first dimension not preconditioned, the second dimension preconditioned with an exponent of 0.25, and the third dimension preconditioned with the default exponent 0.1667. + A visualization of this example can be seen below: + 1-D: + +-------x-------+ + | + | + (^0.5), the default inverse exponent 1/(2*1) since inverse_exponent_override[1] is not specified + 2-D: + +-----------+ + | | + | | + | |-----(^0.2), as specified by inverse_exponent_override[2]=0.2 + | | + | | + +-----------+ + | + | + (^0.2), as specified by inverse_exponent_override[2]=0.2 + 3-D: + +---------------+ + / /| + / / | + +---------------+ | + | | | + | | -|---(^0.25), as specified by inverse_exponent_override[3][1]=0.25 + | | + + | | / + | |/\ + +---------------+ \ + | (^0.1667), the default inverse exponent 1/(2*3) since inverse_exponent_override[3][2] is not specified + | + no preconditioning since inverse_exponent_override[3][0]=0.0 + + + """ + + +@dataclass(kw_only=True) +class EigendecomposedKLShampooPreconditionerConfig( + EigendecomposedShampooPreconditionerConfig +): + """Configuration for KL-Shampoo preconditioner computation with caching of the eigendecomposed factor matrices. + + Attributes: + amortized_computation_config (EigendecompositionConfig): Configuration for the eigendecomposition computation. (Default: DefaultEigendecompositionConfig) + num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3) + factor_matrix_dtype (torch.dtype): Data type for factor matrix. (Default: torch.float32) + inverse_exponent_override (dict[int, dict[int, float] | float]): The inverse_exponent_override attribute is a dictionary that allows for customizing the inverse exponent used in the KL-Shampoo preconditioner computation. + The keys of the dictionary represent the order of the tensor, and the values are either dictionaries with dimension indices as keys and override values as values, or a single float value for all dimensions. All unspecified dimensions use a default exponent of 1/(2*max(o,1)), where o is the order of the tensor. (Default: {}) + + As an example, suppose inverse_exponent_override={2: 0.2, 3: {0: 0.0, 1: 0.25}}. In this case, all 1-D tensors will use the default exponent of 0.5 for preconditioning the first (and only) dimension. All 2-D tensors will be preconditioned with an exponent of 0.2 on all dimensions.
All 3-D tensors will have the first dimension not preconditioned, the second dimension preconditioned with an exponent of 0.25, and the third dimension preconditioned with the default exponent 0.1667. + A visualization of this example can be seen below: + 1-D: + +-------x-------+ + | + | + (^0.5), the default inverse exponent 1/(2*1) since inverse_exponent_override[1] is not specified + 2-D: + +-----------+ + | | + | | + | |-----(^0.2), as specified by inverse_exponent_override[2]=0.2 + | | + | | + +-----------+ + | + | + (^0.2), as specified by inverse_exponent_override[2]=0.2 + 3-D: + +---------------+ + / /| + / / | + +---------------+ | + | | | + | | -|---(^0.25), as specified by inverse_exponent_override[3][1]=0.25 + | | + + | | / + | |/\ + +---------------+ \ + | (^0.1667), the default inverse exponent 1/(2*3) since inverse_exponent_override[3][2] is not specified + | + no preconditioning since inverse_exponent_override[3][0]=0.0 + + + """ + + @dataclass(kw_only=True) class SpectralDescentPreconditionerConfig(PreconditionerConfig): """Configuration for spectral descent computation in DistributedShampoo. diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index dca5d0a..56be68f 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -34,9 +34,11 @@ DefaultSingleDeviceDistributedConfig, DefaultSpectralDescentPreconditionerConfig, DistributedConfig, + EigendecomposedKLShampooPreconditionerConfig, EigendecomposedShampooPreconditionerConfig, EigenvalueCorrectedShampooPreconditionerConfig, PreconditionerConfig, + RootInvKLShampooPreconditionerConfig, RootInvShampooPreconditionerConfig, ShampooPT2CompileConfig, SignDescentPreconditionerConfig, @@ -1096,6 +1098,18 @@ def _ref_state_dict(self) -> dict[str, Any]: } + +class RootInvKLShampooStateDictTest(ShampooStateDictTest): + @property + def _preconditioner_config(self) -> RootInvKLShampooPreconditionerConfig: + return RootInvKLShampooPreconditionerConfig() + + +class EigendecomposedKLShampooStateDictTest(EigendecomposedShampooStateDictTest): + @property + def _preconditioner_config(self) -> EigendecomposedKLShampooPreconditionerConfig: + return EigendecomposedKLShampooPreconditionerConfig() + + class SignDescentStateDictTest(AbstractTest.NoPreconditionerStateDictTestBase): @property def _preconditioner_config(self) -> SignDescentPreconditionerConfig:
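Reviewer note (illustrative, not part of the patch): the core change above is the `_compute_outer_product_list` override, which accumulates each Kronecker factor from the gradient preconditioned, along the contracted dimensions, by the cached inverse roots of the *other* factors. Below is a minimal sketch of that update for a 2-D gradient with the default inverse root of 4; the helper names (`_damped_inverse_root`, `kl_shampoo_outer_products`) and the explicit `beta2` accumulation are assumptions made for exposition, since the actual accumulation is handled by the unchanged `_update_factor_matrices`.

```python
import torch


def _damped_inverse_root(factor: torch.Tensor, root: int, epsilon: float) -> torch.Tensor:
    """Compute (factor + epsilon * I)^(-1/root) via a symmetric eigendecomposition."""
    eigenvalues, eigenvectors = torch.linalg.eigh(factor)
    inv_root_eigenvalues = (eigenvalues.clamp_min(0.0) + epsilon) ** (-1.0 / root)
    # Q diag(lambda^(-1/root)) Q^T
    return (eigenvectors * inv_root_eigenvalues) @ eigenvectors.T


def kl_shampoo_outer_products(
    grad: torch.Tensor,            # 2-D gradient of shape (m, n)
    left_inv_root: torch.Tensor,   # cached (L + eps I)^(-1/4), shape (m, m)
    right_inv_root: torch.Tensor,  # cached (R + eps I)^(-1/4), shape (n, n)
) -> tuple[torch.Tensor, torch.Tensor]:
    # Outer product feeding the left factor L: precondition grad along the
    # contracted dimension (columns) with the right factor's inverse root.
    grad_for_left = grad @ right_inv_root          # grad @ (R + eps I)^(-1/4)
    outer_left = grad_for_left @ grad_for_left.T   # grad @ (R + eps I)^(-1/2) @ grad.T
    # Outer product feeding the right factor R: precondition grad along the
    # contracted dimension (rows) with the left factor's inverse root.
    grad_for_right = left_inv_root @ grad
    outer_right = grad_for_right.T @ grad_for_right
    return outer_left, outer_right


# One illustrative accumulation step (beta2 weighting shown for exposition only).
m, n = 8, 4
left_factor, right_factor = torch.zeros(m, m), torch.zeros(n, n)
grad = torch.randn(m, n)
beta2, epsilon = 0.999, 1e-12
outer_left, outer_right = kl_shampoo_outer_products(
    grad,
    left_inv_root=_damped_inverse_root(left_factor, root=4, epsilon=epsilon),
    right_inv_root=_damped_inverse_root(right_factor, root=4, epsilon=epsilon),
)
left_factor = beta2 * left_factor + (1.0 - beta2) * outer_left
right_factor = beta2 * right_factor + (1.0 - beta2) * outer_right
```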
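Since the new configs are exported from `distributed_shampoo/__init__.py`, enabling KL-Shampoo from user code should only require passing one of them as `preconditioner_config`. A minimal usage sketch, assuming the `DistributedShampoo` constructor arguments already shown in the repository's README examples (the hyperparameter values are placeholders, not recommendations):

```python
import torch

from distributed_shampoo import (
    AdamGraftingConfig,
    DistributedShampoo,
    RootInvKLShampooPreconditionerConfig,
)

model = torch.nn.Linear(16, 8)

optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-5,
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    use_decoupled_weight_decay=True,
    grafting_config=AdamGraftingConfig(beta2=0.999, epsilon=1e-12),
    # Swap in EigendecomposedKLShampooPreconditionerConfig() to cache
    # eigendecompositions of the factor matrices instead of their root inverses.
    preconditioner_config=RootInvKLShampooPreconditionerConfig(),
)
```

In the bundled examples, the same configs are reachable through the new `KL_EIGEN_ROOT_INV` and `KL_EIGENDECOMPOSED_ROOT_INV` values of `PreconditionerComputationType`.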