When all singular values are used, the order of operations is currently sqrt(trace(T'T)). Note that this is not the case when using topK.
Equation (10) in Andrew's original paper is slightly ambiguous because of its notation, but his description (and the derivation of the gradients) suggests that the correct order is trace(sqrt(T'T)).
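To make the difference between the two orders concrete, here is a minimal NumPy sketch (not the repo's PyTorch code). It assumes that sqrt in trace(sqrt(T'T)) means the matrix square root, in which case the result is the sum of the singular values of T, whereas sqrt(trace(T'T)) is the Frobenius norm of T. The function names and the example matrix are hypothetical.

```python
import numpy as np

def sqrt_of_trace(T):
    # sqrt(trace(T'T)): square root of the sum of squared singular values,
    # i.e. the Frobenius norm of T.
    return float(np.sqrt(np.trace(T.T @ T)))

def trace_of_sqrt(T):
    # trace(sqrt(T'T)) with a matrix square root (assumption): the sum of
    # the singular values of T, i.e. the trace (nuclear) norm.
    return float(np.linalg.svd(T, compute_uv=False).sum())

# Hypothetical T whose singular values are 0.9, 0.5 and 0.1.
T = np.diag([0.9, 0.5, 0.1])

print(sqrt_of_trace(T))  # sqrt(0.81 + 0.25 + 0.01)
print(trace_of_sqrt(T))  # 0.9 + 0.5 + 0.1
```

The two values agree only when T has at most one non-zero singular value, so the choice of order changes the objective in general.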
I've been working on a fix. The main problem is that T'T here seems to have negative off-diagonal entries, which in PyTorch result in NaN gradients. I think TensorFlow possibly doesn't produce NaNs for the gradient of sqrt(0), which would be what allows TensorFlow implementations to work out of the box with trace(sqrt(T'T)). Note: it shouldn't actually matter what the backend produces in either case, because the off-diagonals should be thrown away by the trace operation.
Would a short-term fix be to push everything through the topK route?
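As a rough illustration of what an eigenvalue-based route could look like, here is a NumPy sketch. It is an assumption about the general approach, not a copy of the repo's topK code: it takes the eigenvalues of T'T, drops those at or below a small threshold, optionally keeps the k largest, and sums their square roots. The function name, the `k` argument and the `eps` default are all hypothetical.

```python
import numpy as np

def trace_norm_via_eigh(T, k=None, eps=1e-9):
    # Eigenvalues of the symmetric matrix T'T; these are the squared
    # singular values of T, returned in ascending order.
    eigvals = np.linalg.eigvalsh(T.T @ T)
    # Drop zero or slightly negative values so sqrt is only taken on
    # strictly positive numbers (eps is an assumed threshold).
    eigvals = eigvals[eigvals > eps]
    if k is not None:
        # Keep only the k largest eigenvalues.
        eigvals = np.sort(eigvals)[::-1][:k]
    return float(np.sqrt(eigvals).sum())

# Hypothetical T with negative off-diagonal entries in T'T.
T = np.array([[0.8, -0.3],
              [0.2, 0.5]])
print(trace_norm_via_eigh(T))       # all singular values
print(trace_norm_via_eigh(T, k=1))  # largest singular value only
```

With `k=None` this matches the sum of singular values from an SVD, and no square root is ever applied to an off-diagonal entry of T'T.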