Conversation
| # Gather will move "b" and "c" to the front for t and index respectively | ||
| # so we must force the order in order to compare to the original | ||
| # torch.gather. | ||
| ntensor = ntorch.gather(t, "b", index, "c")._force_order(("a", "c")) |
There was a problem hiding this comment.
This isn't a good unit test. It shouldn't call any _ functions.
There was a problem hiding this comment.
High level question: Since we don't assume any ordering, is the right approach to try all permutations of the output ntensor and pass if any of them succeed (equal base)? Or should wer try to keep the underlying order the same as torch.* (although this may be unclear for ntorch.gather since broadcasting isn't defined in torch.gather).
There was a problem hiding this comment.
I think the ntorch.equal function will now do this automatically. But either way isn't the function you want just .transpose? More importantly does this test prove to me that your change works?
There was a problem hiding this comment.
Great, thanks. I agree that the unit test doesn't test anything, but I wanted to ask about how to compare first. I can write a better test.
|
Any updates? |
Adds broadcasting support for gather by adding dimensions (unsqueezing through _force_order using an overlapping order) and expanding.