
Reduce memory usage in distributed triangle updates #5942

@wujingyue

Description


This request comes from #5890. Currently, nvFuser uses too much memory by allgathering one of the einsum's operands.


A better (though not necessarily optimal) approach is to stream-parallelize the allgather and the reduce-scatter:

[image]

This way, each GPU only has to store `O(b * s/dy * s/dx * c)`.
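
As a rough, single-process sketch of that schedule (NumPy only; the shapes, the matmul-style contraction, and all variable names below are illustrative, not the actual nvFuser lowering): the operand that would otherwise be allgathered stays as per-device shards, and at stream iteration `t` device `r` consumes only the shard owned by device `(t + r) % d`, so one remote shard is resident per device at a time. The reduce-scatter side is omitted here; in the streamed plan each per-iteration output tile would be handed to it instead of being kept resident, which is what keeps the footprint at one `s/dy`-sized slice.

```python
# Minimal sketch: stream-parallelize the allgather so that each "device"
# holds one shard of the gathered operand at a time, in the swizzled order.
import numpy as np

d = 4                    # devices on the streamed mesh axis (dy in the issue)
m, n_per, k = 8, 8, 16   # local rows, rows per shard, contraction extent (illustrative)

rng = np.random.default_rng(0)
a_local = [rng.standard_normal((m, k)) for _ in range(d)]       # stays resident on device r
b_shards = [rng.standard_normal((n_per, k)) for _ in range(d)]  # shard owned by device r

# Naive plan: allgather all of b_shards onto every device (d * n_per * k floats
# resident per device), then one big matmul per device.
b_full = np.concatenate(b_shards, axis=0)
reference = [a_local[r] @ b_full.T for r in range(d)]

# Streamed plan: one gathered shard in flight per device per stream iteration.
# (The full output is kept here only for the correctness check; in the real
# schedule each tile would be reduce-scattered away before the next iteration.)
out = [np.empty((m, d * n_per)) for _ in range(d)]
for t in range(d):                      # stream iterations
    for r in range(d):                  # "devices", simulated sequentially here
        p = (t + r) % d                 # swizzled peer: (streamIdx + deviceIdx.x) % deviceDim.x
        incoming = b_shards[p]          # the only remote shard resident on device r right now
        out[r][:, p * n_per:(p + 1) * n_per] = a_local[r] @ incoming.T

for r in range(d):
    assert np.allclose(out[r], reference[r])
print("streamed schedule matches the allgather plan, "
      f"with {n_per * k} gathered floats resident per device instead of {d * n_per * k}")
```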

In nvFuser, this can be represented as

[image]

Note that

```
dy
|
s
```

is a Swizzle1D similar to `(streamIdx + deviceIdx.x) % deviceDim.x`.
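
For concreteness, a tiny sketch of that mapping (plain Python, illustrative names only, not nvFuser API): at stream iteration `t`, device `r` reads the shard owned by device `(t + r) % D`, so every stream step touches each shard exactly once and the exchange stays a balanced ring shift with no contended source.

```python
def swizzled_shard(stream_idx: int, device_idx: int, device_dim: int) -> int:
    """Which peer's shard device `device_idx` consumes at stream iteration `stream_idx`."""
    return (stream_idx + device_idx) % device_dim

D = 4
for t in range(D):
    row = [swizzled_shard(t, r, D) for r in range(D)]
    # Each row is a permutation of 0..D-1: at every stream step all shards
    # are in use exactly once, so no device waits on a hot shard.
    assert sorted(row) == list(range(D))
    print(f"streamIdx={t}: device r reads shard {row}")
```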
