
CollectivePermute based allgather in host for-loop #5963

Draft
Priya2698 wants to merge 18 commits into main from pm/collective_permute

Conversation

@Priya2698
Collaborator

No description provided.

@Priya2698
Collaborator Author

!test

@github-actions

github-actions bot commented Feb 12, 2026

Description

  • Add new CollectivePermute communication primitive for peer-to-peer data exchange

  • Implement StreamBroadcast for decomposing allgather as broadcast in host for-loops (a conceptual sketch follows this list)

  • Add support for swizzle1d operations in allocation domain traversal

  • Include test coverage for collective permute and column parallel linear forward
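For reviewers new to the decomposition, the following is a minimal plain-Python simulation of the pattern described above: an allgather over d devices is unrolled into a host for-loop of d - 1 peer-to-peer exchanges, where each iteration sends the local shard to one peer and receives a shard from another. The names (d, rank, shards) are illustrative and this is not nvfuser code; the exact output ordering in the real lowering is controlled by the swizzle1d on the output tensor.

# Conceptual simulation of allgather as a host for-loop of peer-to-peer
# exchanges. Purely illustrative; not the PR's implementation.
d = 4                                        # number of devices in the mesh
shards = [f"shard{r}" for r in range(d)]     # each rank owns one shard
gathered = [[shards[r]] for r in range(d)]   # per-rank output buffers

for i in range(1, d):                        # d - 1 host-loop iterations
    for rank in range(d):
        send_peer = (rank + i) % d           # peer this rank sends its shard to
        recv_peer = (rank - i) % d           # peer this rank receives a shard from
        # In the real lowering each iteration issues one send and one recv
        # (the ncclDevKernel_SendRecv kernels counted in the test below);
        # here the copy is done directly.
        gathered[rank].append(shards[recv_peer])

print(gathered[0])  # rank 0 ends up with all d shards (order set by the swizzle)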

Changes walkthrough

Relevant files

Enhancement (7 files)

  • communication.h (+72/-1): Define CollectivePermute class and new communication types
  • communication.cpp (+112/-0): Implement CollectivePermute and add postCollectivePermute function
  • evaluator.cpp (+127/-0): Add CollectivePermute handler and improve for-loop index invalidation
  • lower_to_communication.cpp (+97/-5): Add lowering support for CollectivePermute and StreamBroadcast
  • lowering.cpp (+20/-12): Modify lowering to handle new communication types and sharding logic
  • utils.cpp (+19/-0): Add dispatchSwizzle1D utility for computing peer indices
  • tensor_metadata.cpp (+26/-0): Add Swizzle1D handling in allocation domain traversal

Tests (2 files)

  • test_communication.py (+29/-0): Add test for collective permute functionality
  • test_overlap.py (+114/-0): Add column parallel linear forward test with broadcast-based overlapping

Additional files (8 files)

  • dispatch.h (+1/-0)
  • evaluator.h (+1/-0)
  • ir.cpp (+7/-3)
  • lower_to_communication.h (+6/-0)
  • ops.cpp (+4/-2)
  • convert_op_to_communication.cpp (+4/-1)
  • utils.h (+6/-0)
  • ir.cpp (+21/-0)

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Potential division by zero

In the dispatchSwizzle1D function, the code builds team_size_val from mesh.size(pt) and uses it as the modulus in two mod expressions. While team_size should generally be > 0, validating that before creating the Val and using it in the calculations would prevent a potential modulo-by-zero error (a small numeric check of these expressions follows the snippet below).

  int64_t team_size = mesh.size(pt);
  at::Tensor md_index = mesh.multiDimensionalIndexOf(device_id);
  auto pt_axis = mesh.parallelTypeToAxis(pt);
  int64_t team_index = md_index[pt_axis].item<int64_t>();
  Val* team_size_val = IrBuilder::create<Val>(team_size, DataType::Index);
  Val* team_index_val = IrBuilder::create<Val>(team_index, DataType::Index);
  return std::make_pair(
      mod(add(host_loop_index, team_index_val), team_size_val),
      mod(add(team_size_val, sub(team_index_val, host_loop_index)),
          team_size_val));
}
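As a quick numeric check of the two expressions above, the plain-Python sketch below reproduces them with example values (team_size, team_index, and the loop range are illustrative, not taken from the PR) and shows where a team_size > 0 guard would sit:

# Plain-Python mirror of the two mod expressions in dispatchSwizzle1D.
# Example values only; the real code builds these as Index Vals.
def peer_indices(host_loop_index, team_index, team_size):
    assert team_size > 0, "guard suggested above: avoid modulo by zero"
    first = (host_loop_index + team_index) % team_size
    second = (team_size + (team_index - host_loop_index)) % team_size
    return first, second

team_size = 4
for team_index in range(team_size):
    for i in range(1, team_size):            # host-loop iterations
        print(team_index, i, peer_indices(i, team_index, team_size))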
Backend compatibility check

The CollectivePermute handler explicitly checks for NCCL backend and throws an error for other backends. This should be documented as a limitation, and consideration should be given to whether this is a fundamental limitation of the algorithm or an implementation constraint that could be relaxed.

NVF_CHECK_EQ(backend_type, CommunicatorBackend::kNccl);
Test coverage completeness

The test_collective_permute function checks the gathered output against the reference and counts the generated NCCL send/recv kernels, but it does not exercise edge cases such as a single-device mesh or varying shard sizes. Additional test cases should be added to ensure the collective permute implementation works correctly across different device counts and tensor sizes (a sketch of one such extension follows the quoted test below).

def test_collective_permute(multidevice_test):
    d = multidevice_test.size
    mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))

    with FusionDefinition() as fd:
        inp_tv = fd.define_tensor((d * 3,), contiguity=True, dtype=DataType.Float)
        out_tv = fd.ops.set(inp_tv)
        fd.add_output(out_tv)

        inp_tv.set_device_mesh(mesh)
        inp_tv.outer_split(0, d)
        inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x)

        out_tv.set_device_mesh(mesh)
        out_tv.outer_split(0, d)
        out_tv.swizzle1d(0, nvfuser.ParallelType.mesh_x)
        out_tv.axis(0).parallelize(nvfuser.ParallelType.stream)

    inp_ref = torch.randn(d * 3)
    inp = multidevice_test.shard_tensor(inp_ref, inp_tv)
    with torch.profiler.profile() as prof:
        (out,) = fd.execute([inp], _enable_options=["host_ir_lowering"])
    torch.testing.assert_close(out.cpu(), inp_ref)
    collective_permute_events = [
        event for event in prof.events() if "ncclDevKernel_SendRecv" in event.name
    ]
    assert len(collective_permute_events) == (d - 1)
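
One way to extend the coverage along these lines would be a parametrized variant of the quoted test that varies the per-device shard size. This is only a sketch: it reuses the fixtures and API calls shown above, while the import lines, test name, and parameter values are assumptions and have not been run.

# Sketch of broader coverage; reuses the calls from the quoted test.
# Imports are assumed to match the existing test module; shard sizes are illustrative.
import pytest
import torch
import nvfuser
from nvfuser import FusionDefinition, DataType

@pytest.mark.parametrize("shard_size", [1, 3, 128])
def test_collective_permute_sizes(multidevice_test, shard_size):
    d = multidevice_test.size
    mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))

    with FusionDefinition() as fd:
        inp_tv = fd.define_tensor((d * shard_size,), contiguity=True, dtype=DataType.Float)
        out_tv = fd.ops.set(inp_tv)
        fd.add_output(out_tv)

        inp_tv.set_device_mesh(mesh)
        inp_tv.outer_split(0, d)
        inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x)

        out_tv.set_device_mesh(mesh)
        out_tv.outer_split(0, d)
        out_tv.swizzle1d(0, nvfuser.ParallelType.mesh_x)
        out_tv.axis(0).parallelize(nvfuser.ParallelType.stream)

    inp_ref = torch.randn(d * shard_size)
    inp = multidevice_test.shard_tensor(inp_ref, inp_tv)
    (out,) = fd.execute([inp], _enable_options=["host_ir_lowering"])
    torch.testing.assert_close(out.cpu(), inp_ref)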

Test failures

  • (Medium, 3) Missing NCCL broadcast event in tests.python.multidevice.test_overlap::test_column_parallel_linear_forward (platforms: A100, GB200, H100)

  • (Medium, 2) Large numerical mismatches in the multidevice column-parallel linear forward test, tests.python.multidevice.test_overlap::test_column_parallel_linear_forward (platforms: A100 distributed, H100 distributed)
