Review updated until commit 493c4ad
PR Reviewer Guide

Here are some key observations to aid the review process:

- 🧪 PR contains tests
- 🔒 No security concerns identified
- ⚡ Recommended focus area for review: missing error handling

!test
Greptile Summary

Implements broadcast-based allgather for DIDx→Stream resharding by introducing a new `CommunicationType::StreamBroadcast`.

Confidence Score: 3/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant Lowering as Host Lowering
    participant Conv as convertSingleOpToCommunication
    participant GetInfo as getCommunicationInfo
    participant Lower as lowerToStreamBroadcast
    participant Post as postSingleCommunication

    Note over Lowering: Detects DIDx→Stream resharding
    Lowering->>Conv: convertSingleOpToCommunication(expr, device_id, loop_index)
    Conv->>GetInfo: getCommunicationInfo(expr)
    Note over GetInfo: Checks producer DIDx vs consumer Stream
    GetInfo-->>Conv: CommunicationType::StreamBroadcast
    Conv->>Lower: lowerToStreamBroadcast(in, out, backend, comms, root=loop_index)
    Note over Lower: Creates Communication with root=loop_index
    Lower-->>Conv: Communication expr
    Conv-->>Lowering: StreamBroadcast communication
    Note over Lowering: Skips sharding validation for StreamBroadcast
    Lowering->>Post: Execute at runtime
    Note over Post: Routes to postBroadcast (same as Broadcast)
```
Last reviewed commit: 493c4ad

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

!test
Additional Comments (1)

Consider extending validation to require …

!test
```cpp
std::vector<Expr*> convertSingleOpToCommunication(
    Expr* c,
    DeviceIdxType my_device_idx,
    Val* host_loop_index = nullptr,
    const CommunicatorBackend backend = CommunicatorBackend::kNccl);
```
StreamBroadcast not handled
convertSingleOpToCommunication now defaults host_loop_index to nullptr, but it can return CommunicationType::StreamBroadcast from getCommunicationInfo and then hard-errors if host_loop_index == nullptr (lower_to_communication.cpp:657-663). This makes the existing ConvertOpToCommunication pass (which calls it at top-level with nullptr) crash for any fusion that produces this DIDx→Stream resharding outside the host-loop lowering path.
**wujingyue** left a comment:

It's great to see this work functionally!
```diff
       type == CommunicationType::SendRecv ||
-      type == CommunicationType::AllToAll) {
+      type == CommunicationType::AllToAll ||
+      type == CommunicationType::StreamBroadcast) {
```
I understood the motivation, but can this be consolidated into the same Broadcast?

I kept it separate so I don't need to check for the Stream parallel type in `lowerToBroadcast` when deciding the root. Posting the communication uses a common function.

I also wanted to first integrate SendRecv-based decomposition and then reconsider the design based on what is needed for both these comms.
```diff
       "Destination allocation should be sharded on stream after "
       "shardAllocationAsLoop: ",
-      destination);
+      destination->domain()->toString(0, /*loop_only=*/false));
```
I guess `destination` is still worth printing in addition to the domain?

`destination` prints the loop domain. I added the name above to be printed in addition to the complete domain.
```diff
   TensorView* in = communication->in();
   TensorView* out = communication->out();
-  if (haveDifferentShardings(
+  if (communication->type() != CommunicationType::StreamBroadcast &&
```
While I understand the motivation and that the tests pass, I'm thinking about how to make this cleaner.

Is it possible to frame this as an optimization? For example, if `in` can be sharded on Stream in the same way as `communication`, insert a `shardByStream`.

Yeah, I do think this should be merged into `shardByStream` or some other logic.

For now, I kept it simple since I am not sure how it will look with a Collective Permute representation (a composite Communication, P2P comms corresponding to SendRecv, etc.), so I took the verbose approach as an interim step.

Let me see what I can do in this PR itself.
```diff
   // This ignores device dimensions on reduction axis.
-  auto producer_pt_to_did =
+  auto producer_pt_to_id =
       mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain());
-  auto consumer_pt_to_did =
+  auto consumer_pt_to_id =
```

Suggested change:

```diff
-  auto producer_pt_to_id =
+  const std::unordered_map<ParallelType, IterDomain*>& producer_pt_to_id =
```
```cpp
std::vector<Expr*> convertSingleOpToCommunication(
    Expr* e,
    DeviceIdxType my_device_idx,
    Val* host_loop_index,
```

Suggested change:

```diff
-    Val* host_loop_index,
+    Val* root,
```

Some communications (e.g. broadcast, reduce, gather, and scatter) are rooted. So far, we've been deciding the root according to device meshes. However, this use makes a case for passing in the root from the lowering process.
!test
```diff
           input_tensor,
           output_tensor);
     case CommunicationType::Broadcast:
+    case CommunicationType::StreamBroadcast:
```

StreamBroadcast reuses broadcast logic, but csrc/host_ir/evaluator.cpp:334-336 and :405-407 still reject it for the CUDA backend multicast path. If the CUDA backend is intended to work with StreamBroadcast, add it to those checks; otherwise this will fail at runtime when using CommunicatorBackend::kCuda.
```diff
           input_tensor,
           output_tensor);
     case CommunicationType::Broadcast:
+    case CommunicationType::StreamBroadcast:
```

StreamBroadcast shares broadcast logic, but csrc/multidevice/cuda_p2p.cpp:645-667 and :689-704 don't handle it in their switch statements. This will cause runtime errors with the CUDA backend. Either add StreamBroadcast cases (treating them like Broadcast) or ensure the CUDA backend is never used with this communication type.
The broadcast version is very slow, so I am not comparing timings until we integrate this with multicast.