Broadcast-based allgather in host for-loop #5925
base: main
@@ -167,6 +167,32 @@ void lowerToBroadcast(
       backend));
 }
 
+void lowerToStreamBroadcast(
+    TensorView* input_tv,
+    TensorView* output_tv,
+    const CommunicatorBackend backend,
+    std::vector<Expr*>& comms,
+    Val* root) {
+  const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
+  const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
+  NVF_ERROR_EQ(
+      sender_mesh,
+      receiver_mesh,
+      "StreamBroadcast sender and receiver meshes must be the same. Given ",
+      sender_mesh,
+      " and ",
+      receiver_mesh);
+  Team team = receiver_mesh.vector();
+  comms.push_back(IrBuilder::create<Communication>(
+      CommunicationType::StreamBroadcast,
+      output_tv,
+      input_tv,
+      team,
+      root,
+      c10d::ReduceOp::RedOpType::UNUSED,
+      backend));
Comment on lines +185 to +193

Contributor:
StreamBroadcast root is a raw loop index, not a device ID. For a mesh such as `{4, 5, 6, 7}`, the loop index will not match any device. The root should be the mesh device at position `loop_index`, i.e. `receiver_mesh.at(loop_index)`.

Collaborator (Author):
Communication can accept a Val as root or a DeviceIdx.
+}
+
 // Adds several SendRecv communications to the vector 'comms'
 // For now, we assume that this function is called only if
 // the input and output have the same sharding. Later we could support more
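For context, the overall scheme is an allgather realized as a host for-loop of broadcasts: device team[i] owns shard i, and iteration i broadcasts that shard to the whole team. A minimal stand-alone sketch (plain C++, not nvFuser code; the broadcast callback is a hypothetical stand-in for posting one broadcast communication):

#include <cstdint>
#include <functional>
#include <vector>

// Allgather built from per-iteration broadcasts: device team[i] owns shard i,
// and iteration i of the host loop broadcasts that shard to the whole team.
void allgatherViaBroadcasts(
    const std::vector<int64_t>& team,  // absolute device IDs, e.g. {4, 5, 6, 7}
    const std::function<void(int64_t /*root*/, int64_t /*shard*/)>& broadcast) {
  for (int64_t i = 0; i < static_cast<int64_t>(team.size()); ++i) {
    // The root of iteration i is the device ID team[i], not the loop index i.
    broadcast(/*root=*/team.at(i), /*shard=*/i);
  }
}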
@@ -371,14 +397,17 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
   const auto c2p_map = pairwise_map.mapConsumerToProducer();
 
   // This ignores device dimensions on reduction axis.
-  auto producer_pt_to_did =
+  const std::unordered_map<ParallelType, IterDomain*>& producer_pt_to_id =
       mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain());
-  auto consumer_pt_to_did =
+  const std::unordered_map<ParallelType, IterDomain*>& consumer_pt_to_id =
       mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain());
 
+  IterDomain* c_stream_id =
+      getOrDefault(consumer_pt_to_id, ParallelType::Stream);
+
   for (ParallelType pt : kParallelTypeDIDs) {
-    IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt);
-    IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt);
+    IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt);
+    IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt);
 
     if (p_loop_did == nullptr && c_loop_did == nullptr) {
       // Not sharded on this parallel type
@@ -392,6 +421,25 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
     if (e->isA<LoadStoreOp>()) {
       if (p_loop_did && !c_loop_did) {
         IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
+        // Check if we are going from DIDx -> Stream, which is a ring allgather.
+        // This can be executed as a broadcast or send recvs, which is decided
+        // by the presence of a swizzle in the stream id definition.
+        // TODO: Lower to SendRecv if swizzle is present.
+        if (c_stream_id != nullptr) {
+          IterDomain* c_stream_logical_id =
+              getLogicalFromLoopId(consumer, c_stream_id);
+          if (c_stream_logical_id == p2c_map.at(p_logical_id)) {
+            NVF_CHECK(
+                same_mesh,
+                "Broadcast based allgather in stream parallel requires same "
+                "mesh.");
+            fill_communication_info(
+                CommunicationType::StreamBroadcast,
+                p_logical_id,
+                c_stream_logical_id);
+            continue;
+          }
+        }
         CommunicationType type = same_mesh ? CommunicationType::Allgather
                                            : CommunicationType::Gather;
         fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id));
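To make the DIDx -> Stream pattern concrete, here is a sketch of the producer/consumer sharding this branch matches. Names follow common nvFuser test conventions (makeContigTensor, set, DeviceMesh::createForNumDevices); the scaffolding is assumed for illustration and is not taken from this PR:

Fusion fusion;
FusionGuard fg(&fusion);

constexpr int64_t d = 4;                // number of devices in the mesh
TensorView* in = makeContigTensor(2);   // logical shape [d, m]
TensorView* out = set(in);              // the LoadStoreOp being lowered
fusion.addInput(in);
fusion.addOutput(out);

const DeviceMesh mesh = DeviceMesh::createForNumDevices(d);
in->setDeviceMesh(mesh);
out->setDeviceMesh(mesh);
in->axis(0)->parallelize(ParallelType::DIDx);     // producer: axis 0 sharded across devices
out->axis(0)->parallelize(ParallelType::Stream);  // consumer: same logical axis, stream-parallel
// Same mesh, and the producer's DIDx axis maps to the consumer's Stream axis
// => classified as StreamBroadcast (one broadcast per host-loop iteration).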
@@ -479,7 +527,8 @@ Layout getCommunicationLayout(
       type == CommunicationType::Allreduce ||
       type == CommunicationType::Broadcast ||
       type == CommunicationType::SendRecv ||
-      type == CommunicationType::AllToAll) {
+      type == CommunicationType::AllToAll ||
+      type == CommunicationType::StreamBroadcast) {
Collaborator:
I understand the motivation, but can this be consolidated into the same Broadcast?

Collaborator (Author):
I kept it separate so I don't need to check for the Stream ParallelType in
     return layout;
   }
@@ -537,6 +586,7 @@ bool isCommunicationLayoutCompliant(Expr* expr) {
 std::vector<Expr*> convertSingleOpToCommunication(
     Expr* e,
     DeviceIdxType my_device_idx,
+    Val* root,
     const CommunicatorBackend backend) {
   FusionGuard fg(e->fusion());
@@ -606,6 +656,12 @@ std::vector<Expr*> convertSingleOpToCommunication(
     case CommunicationType::AllToAll:
       lowerToAllToAll(input_tv, output_tv, backend, comms);
       break;
+    case CommunicationType::StreamBroadcast:
+      NVF_ERROR(
+          root != nullptr,
+          "StreamBroadcast requires a root value passed in through lowering");
+      lowerToStreamBroadcast(input_tv, output_tv, backend, comms, root);
+      break;
   }
 
   return comms;
@@ -145,7 +145,10 @@ Expr* cloneWithNewOperands(
     return e;
   }
 
-  return e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes());
+  auto* new_e =
+      e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes());
+  e->container()->removeExpr(e);
+  return new_e;
 }
 
 void lowerSegment(
@@ -178,15 +181,17 @@ void lowerSegment(
       // TODO: `replacement_map` should be associated with the scope so
       // ShardByStream across segments in the same for-loop can be reused.
       std::unordered_map<Val*, Val*> replacement_map;
-      for (Expr* c : convertSingleOpToCommunication(e, device_id)) {
+      Val* root = loop_nest.empty() ? nullptr : innermost.loop->index();
+      for (Expr* c : convertSingleOpToCommunication(e, device_id, root)) {
        NVF_ERROR(
            c->isA<Communication>(),
            "Exprs in a Communication group should be Communication: ",
            c);
        auto* communication = c->as<Communication>();
        TensorView* in = communication->in();
        TensorView* out = communication->out();
-        if (haveDifferentShardings(
+        if (communication->type() != CommunicationType::StreamBroadcast &&
Collaborator:
While I understand the motivation and that the tests pass, I'm thinking about how to make this cleaner. Is it possible to frame this as an optimization? For example, if

Collaborator (Author):
Yeah, I do think this should be merged into shardByStream or some other logic. Let me see what I can do in this PR itself.
+            haveDifferentShardings(
                in,
                DomainType::kAllocation,
                out,
@@ -89,8 +89,10 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) {
           destination, ParallelType::Stream, DomainType::kAllocation) !=
           nullptr,
       "Destination allocation should be sharded on stream after "
-      "shardAllocationAsLoop: ",
-      destination);
+      "shardAllocationAsLoop. ",
+      destination->name(),
+      ":",
+      destination->domain()->toString(0, /*loop_only=*/false));
Collaborator:
I guess
 
   // Refine the contiguity flags so `out` aliases `in`. This is done similar
   // to AliasFinder::handle(const SliceOp*). We scan through the allocation
@@ -57,6 +57,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) {
     case CommunicationType::AllToAll:
       os << "AllToAll";
       break;
+    case CommunicationType::StreamBroadcast:
+      os << "StreamBroadcast";
+      break;
   }
   return os;
 }
@@ -149,6 +152,7 @@ bool hasRoot(CommunicationType type) {
     case CommunicationType::Reduce:
     case CommunicationType::Broadcast:
     case CommunicationType::SendRecv:
+    case CommunicationType::StreamBroadcast:
      return true;
    case CommunicationType::Allgather:
    case CommunicationType::Allreduce:
@@ -171,6 +175,7 @@ bool isReduction(CommunicationType type) {
     case CommunicationType::Broadcast:
     case CommunicationType::SendRecv:
     case CommunicationType::AllToAll:
+    case CommunicationType::StreamBroadcast:
      return false;
    default:
      NVF_THROW("unrecognized CommunicationType: ", type);
@@ -853,6 +858,7 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
           input_tensor,
           output_tensor);
     case CommunicationType::Broadcast:
+    case CommunicationType::StreamBroadcast:
       return postBroadcast(
           communication,
           my_device_index,
Contributor:
`team` contains absolute device IDs (e.g. {4, 5, 6, 7}), but `root` is a loop-index `Val*` that evaluates to 0, 1, 2, ... at runtime. In `postBroadcast` (communication.cpp:474), the code checks `if (my_device_index == root_index)` and calls `getRootRelativeIndex(root_index)`, which does `std::find(team.begin(), team.end(), root_index)`. Both expect `root_index` to be an actual device ID from the team, not a loop iteration index. For a mesh like {4, 5, 6, 7}, loop index 0 won't match any device and the assertion at communication.cpp:255 will fire.

Root should be the device ID at the loop index position: `receiver_mesh.at(loop_index)`. Since `loop_index` is evaluated at runtime, you'll need to compute this device ID lookup at evaluation time (e.g., via an IR expression that indexes into the mesh).
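To illustrate the failure mode described above, a small stand-alone C++ sketch (a simplified stand-in for the team lookup the comment describes, not nvFuser code): with a team of absolute device IDs, a raw loop index is not found in the team, whereas the device at position loop_index is.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <vector>

// Simplified stand-in for the std::find-based root lookup described above:
// the root must be an actual device ID that appears in the team.
int64_t rootRelativeIndex(const std::vector<int64_t>& team, int64_t root) {
  auto it = std::find(team.begin(), team.end(), root);
  assert(it != team.end() && "root is not a member of the team");
  return std::distance(team.begin(), it);
}

int main() {
  const std::vector<int64_t> team = {4, 5, 6, 7};  // absolute device IDs
  const int64_t loop_index = 0;                    // raw host-loop index
  // Passing loop_index directly would trip the assert: 0 is not in {4, 5, 6, 7}.
  const int64_t root = team.at(loop_index);        // device ID at position loop_index
  assert(rootRelativeIndex(team, root) == loop_index);
  return 0;
}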