diff --git a/objectives.py b/objectives.py index 852d6eb..2d757e7 100644 --- a/objectives.py +++ b/objectives.py @@ -71,8 +71,9 @@ def loss(self, H1, H2): if self.use_all_singular_values: # all singular values are used to calculate the correlation - tmp = torch.matmul(Tval.t(), Tval) - corr = torch.trace(torch.sqrt(tmp)) + tmp = torch.diag(torch.matmul(Tval.t(), Tval)) + tmp = torch.where(tmp>eps, tmp, (torch.ones(tmp.shape).double()*eps).to(self.device)) + corr = torch.sum(torch.sqrt(tmp)) # assert torch.isnan(corr).item() == 0 else: # just the top self.outdim_size singular values are used