From fe9010e00512fe552896b0c947ea080afab594bb Mon Sep 17 00:00:00 2001 From: Armin Arjmand Date: Thu, 18 Jun 2020 13:55:11 +0430 Subject: [PATCH 1/2] trace sqrt correction --- objectives.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/objectives.py b/objectives.py index 852d6eb..c47417b 100644 --- a/objectives.py +++ b/objectives.py @@ -71,7 +71,9 @@ def loss(self, H1, H2): if self.use_all_singular_values: # all singular values are used to calculate the correlation - tmp = torch.matmul(Tval.t(), Tval) + tmp = torch.diag(torch.matmul(Tval.t(), Tval)) + tmp = torch.where(tmp>eps, tmp, (torch.ones(tmp.shape).double()*eps).to(self.device)) + tmp = torch.sum(torch.sqrt(tmp)) corr = torch.trace(torch.sqrt(tmp)) # assert torch.isnan(corr).item() == 0 else: From 92b9ad8c6063eedb3c2a5cf8d7f6e7d698e7a305 Mon Sep 17 00:00:00 2001 From: Armin Arjmand Date: Thu, 18 Jun 2020 13:57:52 +0430 Subject: [PATCH 2/2] trace sqrt correction --- objectives.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/objectives.py b/objectives.py index c47417b..2d757e7 100644 --- a/objectives.py +++ b/objectives.py @@ -73,8 +73,7 @@ def loss(self, H1, H2): # all singular values are used to calculate the correlation tmp = torch.diag(torch.matmul(Tval.t(), Tval)) tmp = torch.where(tmp>eps, tmp, (torch.ones(tmp.shape).double()*eps).to(self.device)) - tmp = torch.sum(torch.sqrt(tmp)) - corr = torch.trace(torch.sqrt(tmp)) + corr = torch.sum(torch.sqrt(tmp)) # assert torch.isnan(corr).item() == 0 else: # just the top self.outdim_size singular values are used