Ensure PSD-safe factorization in constructor of MultivariateNormal #2297
SebastianAment wants to merge 1 commit into cornellius-gp:main
Balandat
left a comment
Thanks! An alternative to making everything into a LinearOperator would be, upon receiving a dense tensor, to compute the Cholesky decomposition with jitter and pass that as the scale_tril to the torch distribution, rather than just storing the tensor and passing it along. The downside is of course that this would do a lot of compute upon construction of the object, potentially unnecessarily. Another hack could be to mock some of the torch distribution code so that it uses the PSD-safe Cholesky decomposition internally (though that seems very hacky and potentially problematic).
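The first alternative in the comment above could look roughly like the sketch below. It is only an illustration, not the code in this PR: `cholesky_with_jitter` is a hypothetical helper, and the jitter schedule (1e-6, growing tenfold per retry) is an assumption.

```python
import torch
from torch.distributions import MultivariateNormal


def cholesky_with_jitter(cov, jitter=1e-6, max_tries=4):
    """Hypothetical helper: retry Cholesky with growing diagonal jitter."""
    try:
        return torch.linalg.cholesky(cov)
    except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
        pass
    eye = torch.eye(cov.shape[-1], dtype=cov.dtype, device=cov.device)
    for i in range(max_tries):
        try:
            return torch.linalg.cholesky(cov + jitter * 10**i * eye)
        except RuntimeError:
            continue
    raise RuntimeError("Cholesky failed even after adding diagonal jitter")


# Rank-one covariance: positive semi-definite, but not positive definite.
cov = torch.ones(2, 2, dtype=torch.float64)
loc = torch.zeros(2, dtype=torch.float64)

# The factorization happens eagerly, at construction time.
scale_tril = cholesky_with_jitter(cov)
mvn = MultivariateNormal(loc, scale_tril=scale_tril)
```

The cost noted in the comment is visible here: the factorization runs when the object is built, even if the caller never samples or evaluates a log-probability.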
# will fail if the covariance matrix is semi-definite, whereas DenseLinearOperator ends up
# calling _psd_safe_cholesky, which factorizes semi-definite matrices by adding to the diagonal.
if isinstance(covariance_matrix, Tensor):
    self._islazy = False  # to allow _unbroadcasted_scale_tril setter
It seems odd to have _islazy set to True if the covariance matrix is indeed a LinearOperator. I guess the "lazy" nomenclature is a bit outdated anyway with the move to LinearOperator.
event_shape = self.loc.shape[-1:]

# TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
Do you mean changing the torch code to validate LinearOperator inputs? That might be somewhat challenging to do if we want to use LinearOperators there explicitly. What would work is to make changes in pure torch that would make it easier to use LinearOperator objects by means of the __torch_function__ interface we define in LinearOperator.
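As a rough illustration of the `__torch_function__` mechanism mentioned in the comment above, the sketch below shows a tensor-like wrapper intercepting a torch call. `JitteredMatrix` is a hypothetical class invented for this example; it is not the LinearOperator implementation, and the fixed jitter value is an assumption.

```python
import torch


class JitteredMatrix:
    """Hypothetical tensor-like wrapper, not part of linear_operator."""

    def __init__(self, tensor, jitter=1e-6):
        self.tensor = tensor
        self.jitter = jitter

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap any JitteredMatrix arguments to plain dense tensors.
        dense = [a.tensor if isinstance(a, cls) else a for a in args]
        if func is torch.linalg.cholesky and isinstance(args[0], cls):
            # Route Cholesky through a diagonal-jitter variant.
            w = args[0]
            eye = torch.eye(w.tensor.shape[-1], dtype=w.tensor.dtype)
            dense[0] = w.tensor + w.jitter * eye
        return func(*dense, **kwargs)


# Rank-one matrix: plain torch.linalg.cholesky would raise on the raw tensor.
psd = JitteredMatrix(torch.ones(2, 2, dtype=torch.float64))
chol = torch.linalg.cholesky(psd)  # dispatches to __torch_function__
```

The point of the design is that the torch call site stays unchanged; the wrapper type decides how the operation is carried out.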
There was a recent BoTorch issue that was caused by a positive semi-definite matrix being passed to MultivariateNormal as a Tensor, which causes the constructor to fail because PyTorch's constructor calls cholesky on the tensor. This commit upstreams the corresponding BoTorch PR to ensure that all covariance matrices are LinearOperator types, thereby triggering _psd_safe_cholesky whenever cholesky is called.
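The failure described above can be reproduced with plain PyTorch. This is a minimal sketch using a hand-picked rank-one covariance; the matrix and the 1e-6 diagonal shift are illustrative assumptions, not values taken from the BoTorch issue.

```python
import torch
from torch.distributions import MultivariateNormal

# Rank-one covariance: positive semi-definite, but not positive definite.
cov = torch.ones(2, 2, dtype=torch.float64)
loc = torch.zeros(2, dtype=torch.float64)

failures = {}
for validate in (True, False):
    try:
        MultivariateNormal(loc, covariance_matrix=cov, validate_args=validate)
    except (ValueError, RuntimeError) as err:
        # With validation on, the error typically comes from the argument
        # check; with it off, it comes from the Cholesky call itself.
        failures[validate] = err

# Shifting the diagonal slightly makes the same matrix factorizable.
shifted = cov + 1e-6 * torch.eye(2, dtype=torch.float64)
mvn = MultivariateNormal(loc, covariance_matrix=shifted)
```

Both constructor calls on the unshifted matrix raise, which is why the dense Tensor path needs a PSD-safe factorization.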