
Broadcast-based allgather in host for-loop#5925

Open
Priya2698 wants to merge 12 commits into main from pm/stream_broadcast

Conversation

@Priya2698 (Collaborator) commented Feb 6, 2026

[Screenshot attached: 2026-02-09]

The broadcast version is very slow, so I am not comparing timings until we integrate this with multicast.

github-actions bot commented Feb 6, 2026

Review updated until commit 493c4ad

Description

  • Introduce StreamBroadcast communication type for broadcast-based allgather in host for-loops

  • Add lowerToStreamBroadcast function to create StreamBroadcast communications with loop index as root

  • Implement ring allgather detection from DIDx -> Stream parallel types

  • Add comprehensive test coverage for column-parallel linear forward with StreamBroadcast

  • Update communication infrastructure to handle new StreamBroadcast type alongside existing Broadcast

Changes walkthrough

Relevant files

Enhancement

lower_to_communication.cpp: Implement StreamBroadcast communication creation and detection
csrc/host_ir/lower_to_communication.cpp
  • Add lowerToStreamBroadcast function for creating StreamBroadcast communications
  • Implement ring allgather detection logic from DIDx to Stream parallel types
  • Update getCommunicationInfo to handle StreamBroadcast case with mesh validation
  • Add StreamBroadcast to layout compliance checks and conversion logic
  +60/-5

lowering.cpp: Update lowering to use loop index as StreamBroadcast root
csrc/host_ir/lowering.cpp
  • Pass innermost loop index as root parameter to convertSingleOpToCommunication
  • Add condition to skip sharding checks for StreamBroadcast communications
  +4/-2

ops.cpp: Enhance error messaging in stream sharding
csrc/host_ir/ops.cpp
  • Improve error message formatting in shardByStream function
  +4/-2

convert_op_to_communication.cpp: Update communication conversion pass interface
csrc/host_ir/pass/convert_op_to_communication.cpp
  • Update convertSingleOpToCommunication call to pass root=nullptr in the pass implementation
  +4/-1

communication.cpp: Integrate StreamBroadcast into communication infrastructure
csrc/multidevice/communication.cpp
  • Add StreamBroadcast case to operator<< overload for debugging
  • Include StreamBroadcast in hasRoot and isReduction functions
  • Update postSingleCommunication to handle StreamBroadcast using broadcast logic
  +6/-0

lower_to_communication.h: Update communication conversion interface with root parameter
csrc/host_ir/lower_to_communication.h
  • Update convertSingleOpToCommunication signature to include optional root parameter
  • Add documentation explaining root parameter usage for StreamBroadcast
  +6/-0

communication.h: Add StreamBroadcast to communication type enum
csrc/multidevice/communication.h
  • Add StreamBroadcast to CommunicationType enum
  • Update documentation to explain StreamBroadcast differences from Broadcast
  +6/-1

Tests

test_overlap.py: Add comprehensive tests for StreamBroadcast functionality
tests/python/multidevice/test_overlap.py
  • Add column_parallel_linear_forward function demonstrating StreamBroadcast usage
  • Add test_column_parallel_linear_forward to verify functionality with profiler validation
  • Add benchmark test for performance measurement of the StreamBroadcast approach
  +114/-0

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    🔒 No security concerns identified
    ⚡ Recommended focus areas for review
    Missing error handling

In the getCommunicationInfo function around lines 430-440, when checking for StreamBroadcast (DIDx -> Stream transformation), if the same_mesh check fails, the code continues without setting any communication info. This could lead to silent failures where no communication is generated. Consider adding explicit error handling or fallback logic.

        NVF_CHECK(
            same_mesh,
            "Broadcast based allgather in stream parallel requires same "
            "mesh.");
        fill_communication_info(
            CommunicationType::StreamBroadcast,
            p_logical_id,
            c_stream_logical_id);
        continue;
      }
    }
    Incomplete TODO comment

    There's a TODO comment on line 425 mentioning "Lower to SendRecv if swizzle is present" but no implementation or further context. This suggests incomplete functionality that should be addressed or tracked properly.

    // TODO: Lower to SendRecv if swizzle is present.

Priya2698 marked this pull request as ready for review February 9, 2026 21:10
Priya2698 requested a review from wujingyue February 9, 2026 21:11

@Priya2698 (Collaborator, Author):

    !test

greptile-apps bot commented Feb 9, 2026

    Greptile Overview

    Greptile Summary

    Implements broadcast-based allgather for DIDx→Stream resharding by introducing a new StreamBroadcast communication type. This allows host for-loop parallelization of allgather operations by decomposing them into broadcasts where the loop index serves as the root.
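
To make the decomposition concrete, below is a minimal single-process C++ sketch of allgather expressed as a host loop of broadcasts. The buffers and the kNumDevices/kChunk constants are illustrative stand-ins, not the nvFuser or NCCL code paths this PR touches; only the role of the loop index as the broadcast root mirrors StreamBroadcast.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      constexpr int64_t kNumDevices = 4; // size of the DIDx mesh (illustrative)
      constexpr int64_t kChunk = 2;      // elements each device owns

      // Each "device" starts with only its own shard.
      std::vector<std::vector<int64_t>> shard(
          kNumDevices, std::vector<int64_t>(kChunk));
      for (int64_t d = 0; d < kNumDevices; ++d) {
        for (int64_t i = 0; i < kChunk; ++i) {
          shard[d][i] = d * kChunk + i;
        }
      }

      // Allgather destination on every device.
      std::vector<std::vector<int64_t>> gathered(
          kNumDevices, std::vector<int64_t>(kNumDevices * kChunk));

      // The host for-loop: iteration `root` broadcasts rank `root`'s shard to
      // all devices. This is the role the loop index plays for StreamBroadcast.
      for (int64_t root = 0; root < kNumDevices; ++root) {
        for (int64_t d = 0; d < kNumDevices; ++d) {
          std::copy(
              shard[root].begin(),
              shard[root].end(),
              gathered[d].begin() + root * kChunk);
        }
      }

      // Every device now holds the full tensor: 0 1 2 3 4 5 6 7
      for (int64_t v : gathered[0]) {
        std::cout << v << ' ';
      }
      std::cout << '\n';
      return 0;
    }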

    Key changes:

    • Added StreamBroadcast to CommunicationType enum and wired it through the communication stack
    • Detection logic in getCommunicationInfo identifies DIDx→Stream resharding patterns
    • lowerToStreamBroadcast function creates broadcast operations using the host loop index as root
    • Runtime execution reuses existing postBroadcast infrastructure
    • Test coverage added for column-parallel linear forward pass

    Issues found:

    • CUDA backend multicast path in evaluator.cpp and cuda_p2p.cpp doesn't handle StreamBroadcast, which will cause runtime failures if CUDA backend is used

    Confidence Score: 3/5

    • Safe to merge for NCCL backend, but will fail at runtime if CUDA backend is used with StreamBroadcast
    • The implementation correctly adds StreamBroadcast type throughout most of the communication stack and test coverage is provided. However, the CUDA backend multicast path (evaluator.cpp and cuda_p2p.cpp) doesn't handle StreamBroadcast, which will cause runtime failures. Since the PR mentions waiting for multicast integration and focuses on NCCL backend for now, this may be intentional, but it creates a latent bug that should be addressed before CUDA backend is used.
    • csrc/host_ir/evaluator.cpp and csrc/multidevice/cuda_p2p.cpp need StreamBroadcast support added before CUDA backend can be used

    Important Files Changed

    Filename Overview
    csrc/host_ir/lower_to_communication.cpp Adds lowerToStreamBroadcast function and StreamBroadcast detection logic in getCommunicationInfo to handle DIDx→Stream resharding as broadcast-based allgather
    csrc/host_ir/lowering.cpp Passes loop index as root to convertSingleOpToCommunication and skips sharding validation for StreamBroadcast communications
    csrc/multidevice/communication.cpp Adds StreamBroadcast case to operator<<, hasRoot, isReduction, and postSingleCommunication (reuses broadcast logic)
    tests/python/multidevice/test_overlap.py Adds column-parallel linear forward test and benchmark to verify broadcast-based allgather with StreamBroadcast works correctly

    Sequence Diagram

    sequenceDiagram
        participant Lowering as Host Lowering
        participant Conv as convertSingleOpToCommunication
        participant GetInfo as getCommunicationInfo
        participant Lower as lowerToStreamBroadcast
        participant Post as postSingleCommunication
        
        Note over Lowering: Detects DIDx→Stream resharding
        Lowering->>Conv: convertSingleOpToCommunication(expr, device_id, loop_index)
        Conv->>GetInfo: getCommunicationInfo(expr)
        Note over GetInfo: Checks producer DIDx vs consumer Stream
        GetInfo-->>Conv: CommunicationType::StreamBroadcast
        Conv->>Lower: lowerToStreamBroadcast(in, out, backend, comms, root=loop_index)
        Note over Lower: Creates Communication with root=loop_index
        Lower-->>Conv: Communication expr
        Conv-->>Lowering: StreamBroadcast communication
        Note over Lowering: Skips sharding validation for StreamBroadcast
        Lowering->>Post: Execute at runtime
        Note over Post: Routes to postBroadcast (same as Broadcast)
    

    Last reviewed commit: 493c4ad

greptile-apps bot left a comment

8 files reviewed, 2 comments

@Priya2698 (Collaborator, Author):

    !test

greptile-apps bot left a comment

8 files reviewed, 3 comments

greptile-apps bot commented Feb 9, 2026

    Additional Comments (1)

    csrc/multidevice/communication.cpp
    Root validation rejects non-const

    Communication::validate only enforces the root/type contract when root() is a const integral scalar. For StreamBroadcast, root is the host loop index (non-const), so hasRoot(type()) is never validated and invalid roots (e.g., non-integral or negative-at-runtime) can slip through. This can lead to runtime failures when postBroadcast interprets the root.

    Consider extending validation to require root() be Index dtype for StreamBroadcast/rooted collectives even when not constant, and/or add runtime checks where the root is consumed.
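
A standalone sketch of such a check might look like the following; validateRoot is a hypothetical helper that uses plain C++ exceptions in place of NVF_CHECK and is not code from this PR.

    #include <cstdint>
    #include <stdexcept>
    #include <string>

    // Hypothetical helper: validate a root that is only known at runtime
    // (e.g. a bound host loop index), since the const-scalar check in
    // Communication::validate never fires for non-const roots.
    void validateRoot(int64_t root_value, int64_t team_size) {
      if (root_value < 0 || root_value >= team_size) {
        throw std::runtime_error(
            "StreamBroadcast root out of range: " + std::to_string(root_value) +
            " (team size " + std::to_string(team_size) + ")");
      }
    }

    int main() {
      validateRoot(/*root_value=*/2, /*team_size=*/4); // ok
      // validateRoot(5, 4); // would throw: root out of range
      return 0;
    }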

@Priya2698 (Collaborator, Author):

    !test

greptile-apps bot left a comment

8 files reviewed, 2 comments

    Comment on lines 54 to 58
    std::vector<Expr*> convertSingleOpToCommunication(
    Expr* c,
    DeviceIdxType my_device_idx,
    Val* host_loop_index = nullptr,
    const CommunicatorBackend backend = CommunicatorBackend::kNccl);
greptile-apps bot:

    StreamBroadcast not handled

    convertSingleOpToCommunication now defaults host_loop_index to nullptr, but it can return CommunicationType::StreamBroadcast from getCommunicationInfo and then hard-errors if host_loop_index == nullptr (lower_to_communication.cpp:657-663). This makes the existing ConvertOpToCommunication pass (which calls it at top-level with nullptr) crash for any fusion that produces this DIDx→Stream resharding outside the host-loop lowering path.

greptile-apps bot commented Feb 9, 2026

    Additional Comments (1)

    csrc/multidevice/communication.cpp
    Non-constant root accepted

    Communication::validate only checks root/type consistency when root() is a const integral scalar (communication.cpp:238-246). For StreamBroadcast, the root is intentionally a non-const Val* (host loop index), so this validation becomes a no-op: invalid roots (e.g., negative at runtime, wrong dtype) won’t be rejected here but later code assumes a valid rank/root. If StreamBroadcast relies on runtime root, it still needs a type/dtype/range validation path for non-const roots (at least DataType::Index and non-negative).

@wujingyue (Collaborator) left a comment

    It's great to see this work functionally!

      type == CommunicationType::SendRecv ||
-     type == CommunicationType::AllToAll) {
+     type == CommunicationType::AllToAll ||
+     type == CommunicationType::StreamBroadcast) {
@wujingyue (Collaborator):

    I understood the motivation but can this be consolidated into the same Broadcast?

@Priya2698 (Collaborator, Author):

    I kept it separate so I don't need to check for the StreamParallel Type in lowerToBroadcast when deciding the root. Posting the communication uses a common function.
    I also wanted to first integrate SendRecv based decomposition and then reconsider the design based on what is needed for both these comms.

    "Destination allocation should be sharded on stream after "
    "shardAllocationAsLoop: ",
    destination);
    destination->domain()->toString(0, /*loop_only=*/false));
@wujingyue (Collaborator):

    I guess destination is still worth printing in addition to the domain?

@Priya2698 (Collaborator, Author):

destination prints only the loop domain. I added the name above so it is printed in addition to the complete domain.

      TensorView* in = communication->in();
      TensorView* out = communication->out();
-     if (haveDifferentShardings(
+     if (communication->type() != CommunicationType::StreamBroadcast &&
@wujingyue (Collaborator):

    While I understood the motivation and that the tests pass, I'm thinking how to make this cleaner.

    Is it possible to frame this as an optimization? For example, if in can be sharded on Stream in the same way as communication, insert a shardByStream.

@Priya2698 (Collaborator, Author):

Yeah, I do think this should be merged into shardByStream or some other logic.
For now, I kept it simple since I am not sure how it will look with the Collective Permute representation (a composite Communication, P2P comms corresponding to SendRecv, etc.), so I took the verbose approach as an interim step.

    Let me see what I can do in this PR itself.


      // This ignores device dimensions on reduction axis.
-     auto producer_pt_to_did =
+     auto producer_pt_to_id =
@wujingyue (Collaborator):

Suggested change:
-     auto producer_pt_to_id =
+     const std::unordered_map<ParallelType, IterDomain*>& producer_pt_to_id =

      auto producer_pt_to_id =
          mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain());
-     auto consumer_pt_to_did =
+     auto consumer_pt_to_id =
@wujingyue (Collaborator):

    ditto

    std::vector<Expr*> convertSingleOpToCommunication(
    Expr* e,
    DeviceIdxType my_device_idx,
    Val* host_loop_index,
@wujingyue (Collaborator):

Suggested change:
-     Val* host_loop_index,
+     Val* root,

    Some communications (e.g. broadcast, reduce, gather, and scatter) are rooted. So far, we've been deciding the root according to device meshes. However, this use makes a case for passing in the root from the lowering process.

greptile-apps bot left a comment

8 files reviewed, no comments

@Priya2698 (Collaborator, Author):

    !test

greptile-apps bot left a comment

8 files reviewed, 2 comments

          input_tensor,
          output_tensor);
      case CommunicationType::Broadcast:
+     case CommunicationType::StreamBroadcast:
greptile-apps bot:

    StreamBroadcast reuses broadcast logic but csrc/host_ir/evaluator.cpp:334-336 and :405-407 still reject it for the CUDA backend multicast path. If CUDA backend is intended to work with StreamBroadcast, add it to those checks; otherwise this will fail at runtime when using CommunicatorBackend::kCuda.

          input_tensor,
          output_tensor);
      case CommunicationType::Broadcast:
+     case CommunicationType::StreamBroadcast:
greptile-apps bot:

    StreamBroadcast shares broadcast logic but csrc/multidevice/cuda_p2p.cpp:645-667 and :689-704 don't handle it in their switch statements. This will cause runtime errors with CUDA backend. Either add StreamBroadcast cases (treating them like Broadcast) or ensure CUDA backend is never used with this communication type.
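
The fix pattern suggested above, routing StreamBroadcast through the existing Broadcast path, can be sketched in isolation as follows; the enum and dispatcher here are simplified stand-ins for the actual switch statements in evaluator.cpp and cuda_p2p.cpp, not code from this PR.

    #include <iostream>

    // Simplified stand-ins for the real types; only the case fall-through
    // pattern is the point here.
    enum class CommunicationType { Broadcast, StreamBroadcast, AllToAll };

    void postSingle(CommunicationType type) {
      switch (type) {
        case CommunicationType::Broadcast:
        case CommunicationType::StreamBroadcast: // reuse the Broadcast path
          std::cout << "postBroadcast\n";
          break;
        default:
          std::cout << "unhandled communication type\n";
          break;
      }
    }

    int main() {
      postSingle(CommunicationType::StreamBroadcast); // prints: postBroadcast
      return 0;
    }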
