This request comes from #5890. Currently, nvFuser uses too much memory because it allgathers one of the einsum's operands in full before the compute.
A better (though not necessarily optimal) approach is to stream-parallelize the allgather and the reducescatter:
This way, each GPU only has to store O(b * s/dy * s/dx * c).
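The memory win can be illustrated with a single-process NumPy sketch. The sizes, the plain-GEMM framing, and the shard layout below are illustrative assumptions, not the exact einsum from #5890: the point is only that streaming consumes one shard per step instead of materializing the full allgathered operand.

```python
import numpy as np

D = 4                       # deviceDim.x: number of devices
m, k, n = 8, 5, 6           # toy GEMM sizes; m is sharded across devices

rng = np.random.default_rng(0)
A_shards = [rng.standard_normal((m // D, k)) for _ in range(D)]  # one shard per device
B = rng.standard_normal((k, n))

# Un-streamed: the allgather materializes all of A on every device at once,
# so peak residency is the full O(m * k) operand.
A_full = np.concatenate(A_shards)
reference = A_full @ B

# Streamed: D steps; step t only needs the single shard "received" in that
# step, so peak residency drops to one (m/D)-row chunk plus the output tile.
# On real hardware each step's communication overlaps the previous step's
# compute; the serial loop here only models the data dependence.
out = np.empty((m, n))
for t in range(D):
    out[t * (m // D):(t + 1) * (m // D)] = A_shards[t] @ B

assert np.allclose(out, reference)
```

The reducescatter side is symmetric: each step's partial output tile can be sent off as soon as it is produced, so it also never has to be resident all at once.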
In nvFuser, this can be represented as:
Note that … is a `Swizzle1D` similar to `(streamIdx + deviceIdx.x) % deviceDim.x`.
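A quick way to see why this swizzle is the right stream-to-peer mapping: for any fixed stream step, the expression sends the `deviceDim.x` devices to `deviceDim.x` distinct peers, so no two devices contend for the same peer in the same step. A small self-contained check (variable names mirror the expression above; `D` stands in for `deviceDim.x`):

```python
D = 4  # deviceDim.x

# Peer targeted by each stream on each device, per the swizzle
# (streamIdx + deviceIdx.x) % deviceDim.x.
for device in range(D):
    peers = [(stream + device) % D for stream in range(D)]
    print(f"device {device}: peers per stream = {peers}")

# In every stream step, the D devices target a full permutation of peers,
# i.e. the schedule is contention-free.
for stream in range(D):
    assert sorted((stream + device) % D for device in range(D)) == list(range(D))
```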