Add decomposition tensor#34
Conversation
|
Isn't a mode better for decompositions? This is a more complicated version of the thing https://github.com/pytorch/pytorch/blob/25c6ebd12c094ca8b02e11cc12cf18102c55acfa/test/test_decomp.py#L377-L436 ; it both runs the decomp and the original |
|
Could there also be cases where a more fine-grained approach is preferred? For example if I have a subclass wrapping a backend, I only want to decompose when the computation involves a subclassed tensors to avoid the perf hit from decomposing the rest of the computation. |
|
|
||
| return tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) | ||
|
|
||
| # 3) Version using inheritance |
There was a problem hiding this comment.
I'm generally against implementing this kind of extra functionality with inheritance. Better to make sure there is some sort of subtyping relation if you're going to use inheritance.
| def wrapper(cls, func, types, args=(), kwargs=None): | ||
| if func in skip_list: | ||
| # Functions that the layers below are able to handle | ||
| return f(cls, func, types, args, kwargs) |
There was a problem hiding this comment.
how come unwrapping isn't needed in this version?
There was a problem hiding this comment.
Ahh f is the __torch_dispatch__ function not the aten op, so the unwrapping will still happen there. Maybe I should rename it to something better so that is clearer...
Exploring some possible UX for using decompositions with subclassing
cc @ezyang @albanD