From b2c6b0b9feacb99fdcd9289e1a3b2aac84a96bd2 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Thu, 5 Feb 2026 17:32:08 -0800 Subject: [PATCH 01/15] broadcast based allgather --- csrc/host_ir/lower_to_communication.cpp | 60 +++++++++++++++++-- csrc/host_ir/lowering.cpp | 7 ++- csrc/host_ir/ops.cpp | 2 +- csrc/multidevice/communication.cpp | 20 ++++--- csrc/multidevice/communication.h | 3 +- .../python/multidevice/test_communication.py | 35 +++++++++++ 6 files changed, 112 insertions(+), 15 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index d91fd4eda60..db45fa010d9 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -166,6 +166,32 @@ void lowerToBroadcast( backend)); } +void lowerToStreamBroadcast( + TensorView* input_tv, + TensorView* output_tv, + const CommunicatorBackend backend, + std::vector& comms) { + const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); + const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); + NVF_ERROR_EQ( + sender_mesh, + receiver_mesh, + "StreamBroadcast sender and receiver meshes must be the same. Given ", + sender_mesh, + " and ", + receiver_mesh); + Team team = receiver_mesh.vector(); + comms.push_back(IrBuilder::create( + CommunicationType::StreamBroadcast, + output_tv, + input_tv, + team, + /*root=*/-1, // This will be replaced by HostIrLowering with the for-loop + // index + c10d::ReduceOp::RedOpType::UNUSED, + backend)); +} + // Adds several SendRecv communications to the vector 'comms' // For now, we assume that this function is called only if // the input and output have the same sharding. Later we could support more @@ -370,14 +396,16 @@ CommunicationInfo getCommunicationInfo(Expr* e) { const auto c2p_map = pairwise_map.mapConsumerToProducer(); // This ignores device dimensions on reduction axis. - auto producer_pt_to_did = + auto producer_pt_to_id = mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain()); - auto consumer_pt_to_did = + auto consumer_pt_to_id = mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain()); for (ParallelType pt : kParallelTypeDIDs) { - IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt); - IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt); + IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt); + IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt); + IterDomain* c_stream_id = + getOrDefault(consumer_pt_to_id, ParallelType::Stream); if (p_loop_did == nullptr && c_loop_did == nullptr) { // Not sharded on this parallel type @@ -391,6 +419,24 @@ CommunicationInfo getCommunicationInfo(Expr* e) { if (e->isA()) { if (p_loop_did && !c_loop_did) { IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); + // Check if we are going from DIDx -> Stream, which is a ring allgather. + // This can be executed as a broadcast or send recvs, which is decided + // by the presence of a swizzle in the stream id definition. + if (c_stream_id != nullptr) { + IterDomain* c_stream_logical_id = + getLogicalFromLoopId(consumer, c_stream_id); + if (c_stream_logical_id == p2c_map.at(p_logical_id)) { + NVF_CHECK( + same_mesh, + "Broadcast based allgather in stream parallel requires same " + "mesh.") + fill_communication_info( + CommunicationType::StreamBroadcast, + p_logical_id, + c_stream_logical_id); + continue; + } + } CommunicationType type = same_mesh ? 
CommunicationType::Allgather : CommunicationType::Gather; fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id)); @@ -478,7 +524,8 @@ Layout getCommunicationLayout( type == CommunicationType::Allreduce || type == CommunicationType::Broadcast || type == CommunicationType::SendRecv || - type == CommunicationType::AllToAll) { + type == CommunicationType::AllToAll || + type == CommunicationType::StreamBroadcast) { return layout; } @@ -605,6 +652,9 @@ std::vector convertSingleOpToCommunication( case CommunicationType::AllToAll: lowerToAllToAll(input_tv, output_tv, backend, comms); break; + case CommunicationType::StreamBroadcast: + lowerToStreamBroadcast(input_tv, output_tv, backend, comms); + break; } return comms; diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 3924d5658aa..63010a768c2 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -186,7 +186,8 @@ void lowerSegment( auto* communication = c->as(); TensorView* in = communication->in(); TensorView* out = communication->out(); - if (haveDifferentShardings( + if (communication->type() != CommunicationType::StreamBroadcast && + haveDifferentShardings( in, DomainType::kAllocation, out, @@ -219,6 +220,10 @@ void lowerSegment( innermost_scope.pushBack(allocate); } + if (communication->type() == CommunicationType::StreamBroadcast) { + replacement_map[communication->root()] = innermost.loop->index(); + } + Expr* new_c = cloneWithNewOperands(c, replacement_map); innermost_scope.pushBack(new_c); diff --git a/csrc/host_ir/ops.cpp b/csrc/host_ir/ops.cpp index 05fd42e2764..f0c555bf28b 100644 --- a/csrc/host_ir/ops.cpp +++ b/csrc/host_ir/ops.cpp @@ -90,7 +90,7 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) { nullptr, "Destination allocation should be sharded on stream after " "shardAllocationAsLoop: ", - destination); + destination->domain()->toString(0, /*loop_only=*/false)); // Refine the contiguity flags so `out` aliases `in`. This is done similar // to AliasFinder::handle(const SliceOp*). 
We scan through the allocation diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 6778da9da71..7a4a5ac66a9 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -57,6 +57,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) { case CommunicationType::AllToAll: os << "AllToAll"; break; + case CommunicationType::StreamBroadcast: + os << "StreamBroadcast"; + break; } return os; } @@ -149,6 +152,7 @@ bool hasRoot(CommunicationType type) { case CommunicationType::Reduce: case CommunicationType::Broadcast: case CommunicationType::SendRecv: + case CommunicationType::StreamBroadcast: return true; case CommunicationType::Allgather: case CommunicationType::Allreduce: @@ -171,6 +175,7 @@ bool isReduction(CommunicationType type) { case CommunicationType::Broadcast: case CommunicationType::SendRecv: case CommunicationType::AllToAll: + case CommunicationType::StreamBroadcast: return false; default: NVF_THROW("unrecognized CommunicationType: ", type); @@ -231,13 +236,13 @@ Communication::Communication( void Communication::validate() { if (root()->isConstScalar() && root()->isIntegralScalar()) { - auto root_val = root()->evaluate().as(); - NVF_ERROR( - hasRoot(type()) == (root_val >= 0), - "Root ", - root_val, - " is not expected by CommunicationType ", - type()); + // auto root_val = root()->evaluate().as(); + // NVF_ERROR( + // hasRoot(type()) == (root_val >= 0), + // "Root ", + // root_val, + // " is not expected by CommunicationType ", + // type()); } NVF_ERROR(isReduction(type()) == (reduceOp() != RedOpType::UNUSED)); } @@ -716,6 +721,7 @@ c10::intrusive_ptr postSingleCommunication( input_tensor, output_tensor); case CommunicationType::Broadcast: + case CommunicationType::StreamBroadcast: return postBroadcast( communication, my_device_index, diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 1a7f1a1cc4c..24101deeeb4 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -33,7 +33,8 @@ enum class CommunicationType { ReduceScatter, Broadcast, SendRecv, - AllToAll + AllToAll, + StreamBroadcast, }; std::ostream& operator<<(std::ostream& os, const CommunicationType& type); diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index 833ab511ff3..19f32509012 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -171,3 +171,38 @@ def test_alltoall(multidevice_test, inp_axis, out_axis): inp = multidevice_test.shard_tensor(in_ref, inp_tv) (out,) = fd.execute([inp]) torch.testing.assert_close(out, multidevice_test.shard_tensor(out_ref, out_tv)) + + +@pytest.mark.mpi +def test_broadcast_based_allgather(multidevice_test): + d = multidevice_test.size + + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) + + with FusionDefinition() as fd: + inp_tv = fd.define_tensor((d * 3,), contiguity=True, dtype=DataType.Half) + out_tv = fd.ops.set(inp_tv) + fd.add_output(out_tv) + + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) + inp_tv.set_device_mesh(mesh) + inp_tv.outer_split(0, d) + inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x) + + out_tv.set_device_mesh(mesh) + out_tv.outer_split(0, d) + out_tv.axis(0).parallelize(nvfuser.ParallelType.stream) + + unsharded_inp = torch.randn(d * 3, dtype=torch.float16) + inp = multidevice_test.shard_tensor(unsharded_inp, inp_tv) + with torch.profiler.profile(record_shapes=True) as 
profile: + (out,) = fd.execute( + [inp], + _enable_options=["host_ir_lowering"], + _disable_options=["infer_contiguity"], + ) + broadcast_events = [ + event for event in profile.events() if "ncclDevKernel_Broadcast" in event.name + ] + torch.testing.assert_close(out.cpu(), unsharded_inp) + print(broadcast_events, len(broadcast_events)) From 9fed72f3ac5fe4ab7744e87f50e958b04bc55df2 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Thu, 5 Feb 2026 18:10:15 -0800 Subject: [PATCH 02/15] allow host loop index as a variable --- csrc/host_ir/lower_to_communication.cpp | 13 +++++++++---- csrc/host_ir/lower_to_communication.h | 1 + csrc/host_ir/lowering.cpp | 7 ++----- csrc/host_ir/pass/convert_op_to_communication.cpp | 5 ++++- csrc/multidevice/communication.cpp | 14 +++++++------- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index db45fa010d9..960113941c6 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -170,7 +170,8 @@ void lowerToStreamBroadcast( TensorView* input_tv, TensorView* output_tv, const CommunicatorBackend backend, - std::vector& comms) { + std::vector& comms, + Val* root) { const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); NVF_ERROR_EQ( @@ -186,8 +187,7 @@ void lowerToStreamBroadcast( output_tv, input_tv, team, - /*root=*/-1, // This will be replaced by HostIrLowering with the for-loop - // index + root, c10d::ReduceOp::RedOpType::UNUSED, backend)); } @@ -583,6 +583,7 @@ bool isCommunicationLayoutCompliant(Expr* expr) { std::vector convertSingleOpToCommunication( Expr* e, DeviceIdxType my_device_idx, + Val* host_loop_index, const CommunicatorBackend backend) { FusionGuard fg(e->fusion()); @@ -653,7 +654,11 @@ std::vector convertSingleOpToCommunication( lowerToAllToAll(input_tv, output_tv, backend, comms); break; case CommunicationType::StreamBroadcast: - lowerToStreamBroadcast(input_tv, output_tv, backend, comms); + NVF_ERROR( + host_loop_index != nullptr, + "StreamBroadcast requires a host loop index"); + lowerToStreamBroadcast( + input_tv, output_tv, backend, comms, /*root=*/host_loop_index); break; } diff --git a/csrc/host_ir/lower_to_communication.h b/csrc/host_ir/lower_to_communication.h index 65a9182e605..4b1c657e0e5 100644 --- a/csrc/host_ir/lower_to_communication.h +++ b/csrc/host_ir/lower_to_communication.h @@ -54,6 +54,7 @@ Layout getCommunicationLayout( std::vector convertSingleOpToCommunication( Expr* c, DeviceIdxType my_device_idx, + Val* host_loop_index = nullptr, const CommunicatorBackend backend = CommunicatorBackend::kNccl); } // namespace nvfuser diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 63010a768c2..01f271ea636 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -178,7 +178,8 @@ void lowerSegment( // TODO: `replacement_map` should be associated with the scope so // ShardByStream across segments in the same for-loop can be reused. 
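        // With the loop index forwarded as the broadcast root, the lowered
        // host IR for a broadcast-based allgather is expected to look roughly
        // like the following sketch (T0/T1 and the loop index i are
        // placeholders, not names taken from an actual dump):
        //
        //   FOR i from 0 to d:
        //     T1_shard = ShardByStream(T1, stream_index=i)
        //     Communication(StreamBroadcast, root=i, input=T0, output=T1_shard)
        //     Wait(...)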
std::unordered_map replacement_map; - for (Expr* c : convertSingleOpToCommunication(e, device_id)) { + for (Expr* c : convertSingleOpToCommunication( + e, device_id, innermost.loop->index())) { NVF_ERROR( c->isA(), "Exprs in a Communication group should be Communication: ", @@ -220,10 +221,6 @@ void lowerSegment( innermost_scope.pushBack(allocate); } - if (communication->type() == CommunicationType::StreamBroadcast) { - replacement_map[communication->root()] = innermost.loop->index(); - } - Expr* new_c = cloneWithNewOperands(c, replacement_map); innermost_scope.pushBack(new_c); diff --git a/csrc/host_ir/pass/convert_op_to_communication.cpp b/csrc/host_ir/pass/convert_op_to_communication.cpp index 4ce4c59b0ce..07305fb5e09 100644 --- a/csrc/host_ir/pass/convert_op_to_communication.cpp +++ b/csrc/host_ir/pass/convert_op_to_communication.cpp @@ -35,7 +35,10 @@ void ConvertOpToCommunication::passImplementation(Fusion* fusion) { return new_top_level_exprs.push_back(top_level_expr); } for (auto* expr : nvfuser::convertSingleOpToCommunication( - top_level_expr, my_device_index, params_.communicator_backend)) { + top_level_expr, + my_device_index, + /*host_loop_index=*/nullptr, + params_.communicator_backend)) { // Allocate the recv buffers of communications if (expr->isA()) { auto* communication = expr->as(); diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 7a4a5ac66a9..21bc4994628 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -236,13 +236,13 @@ Communication::Communication( void Communication::validate() { if (root()->isConstScalar() && root()->isIntegralScalar()) { - // auto root_val = root()->evaluate().as(); - // NVF_ERROR( - // hasRoot(type()) == (root_val >= 0), - // "Root ", - // root_val, - // " is not expected by CommunicationType ", - // type()); + auto root_val = root()->evaluate().as(); + NVF_ERROR( + hasRoot(type()) == (root_val >= 0), + "Root ", + root_val, + " is not expected by CommunicationType ", + type()); } NVF_ERROR(isReduction(type()) == (reduceOp() != RedOpType::UNUSED)); } From 2825e4f4379ecd4186735b6681154c7cfd53256f Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Fri, 6 Feb 2026 13:32:32 -0800 Subject: [PATCH 03/15] benchmark bcast + matmul decomposition --- .../python/multidevice/test_communication.py | 35 ----- tests/python/multidevice/test_overlap.py | 126 ++++++++++++++++++ 2 files changed, 126 insertions(+), 35 deletions(-) diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index 19f32509012..833ab511ff3 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -171,38 +171,3 @@ def test_alltoall(multidevice_test, inp_axis, out_axis): inp = multidevice_test.shard_tensor(in_ref, inp_tv) (out,) = fd.execute([inp]) torch.testing.assert_close(out, multidevice_test.shard_tensor(out_ref, out_tv)) - - -@pytest.mark.mpi -def test_broadcast_based_allgather(multidevice_test): - d = multidevice_test.size - - mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) - - with FusionDefinition() as fd: - inp_tv = fd.define_tensor((d * 3,), contiguity=True, dtype=DataType.Half) - out_tv = fd.ops.set(inp_tv) - fd.add_output(out_tv) - - mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) - inp_tv.set_device_mesh(mesh) - inp_tv.outer_split(0, d) - inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x) - - out_tv.set_device_mesh(mesh) - out_tv.outer_split(0, d) 
-    out_tv.axis(0).parallelize(nvfuser.ParallelType.stream)
-
-    unsharded_inp = torch.randn(d * 3, dtype=torch.float16)
-    inp = multidevice_test.shard_tensor(unsharded_inp, inp_tv)
-    with torch.profiler.profile(record_shapes=True) as profile:
-        (out,) = fd.execute(
-            [inp],
-            _enable_options=["host_ir_lowering"],
-            _disable_options=["infer_contiguity"],
-        )
-    broadcast_events = [
-        event for event in profile.events() if "ncclDevKernel_Broadcast" in event.name
-    ]
-    torch.testing.assert_close(out.cpu(), unsharded_inp)
-    print(broadcast_events, len(broadcast_events))
diff --git a/tests/python/multidevice/test_overlap.py
index c6453e513c5..4fb82bf43fc 100644
--- a/tests/python/multidevice/test_overlap.py
+++ b/tests/python/multidevice/test_overlap.py
@@ -417,6 +417,132 @@ def test_column_parallel_linear_forward_reference_benchmark(
     benchmark.pedantic(benchmark_fn, rounds=5)
 
 
+@pytest.mark.mpi
+def column_parallel_linear_forward(h: int, d: int):
+    with FusionDefinition() as fd:
+        inp_tv = fd.define_tensor((-1, h), contiguity=True, dtype=DataType.BFloat16)
+        weight_tv = fd.define_tensor(
+            (4 * h, h), contiguity=True, dtype=DataType.BFloat16
+        )
+        ag_out = fd.ops.set(inp_tv)
+        out_tv = fd.ops.linear(ag_out, weight_tv)
+        fd.add_output(out_tv)
+
+    mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
+
+    for tv in [inp_tv, weight_tv]:
+        tv.set_device_mesh(mesh)
+        tv.outer_split(0, d)
+        tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x)
+
+    ag_out.set_device_mesh(mesh)
+    ag_out.outer_split(0, d)
+    ag_out.axis(0).parallelize(nvfuser.ParallelType.stream)
+
+    # Fusion IR before segmentation will look like this:
+    #  [t, h]
+    #   /\.
+    #  d
+    #  (deviceIdx.x)
+    #   |
+    #   | set (lowered to StreamBroadcast. This decomposition is done manually in the definition above. It will later be done by preseg)
+    #   |
+    #  [t, h]       [4h, h]
+    #   /\           /\.
+    #  s            d
+    #  (streamIdx)
+    #   |
+    #   | linear
+    #   |
+    #  [t, 4h, r{h}]
+    #   /\  /\.
+    #  s*   d
+
+    return fd
+
+
+@pytest.mark.mpi
+def test_column_parallel_linear_forward(multidevice_test):
+    # This is a port of CollectiveBasedOverlapTest.ColumnAndSequenceParallelLinear_Forward.
+    # The difference is we are using broadcast based overlapping instead of send/recv.
+    h, t = 2, 24
+    d = multidevice_test.size
+    if (h * 4) % d != 0:
+        pytest.skip(
+            f"Column-parallel linear requires {h * 4} to be divisible by world size {d}."
+        )
+    if t % d != 0:
+        pytest.skip(
+            f"Column-parallel linear requires {t} to be divisible by world size {d}."
+ ) + + fd = column_parallel_linear_forward(h, d) + + inp_ref = torch.testing.make_tensor(t, h, dtype=torch.int32, device="cpu").to( + torch.bfloat16 + ) + weight_ref = torch.testing.make_tensor( + 4 * h, h, dtype=torch.int32, device="cpu" + ).to(torch.bfloat16) + + inp = multidevice_test.shard_tensor(inp_ref, fd.fusion.inputs()[0]) + weight = multidevice_test.shard_tensor(weight_ref, fd.fusion.inputs()[1]) + + out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight) + + with torch.profiler.profile(record_shapes=True) as prof: + (out,) = fd.execute( + [inp, weight], + _enable_options=["host_ir_lowering"], + _disable_options=["infer_contiguity"], + ) + torch.testing.assert_close(out, out_ref) + with torch.profiler.profile(record_shapes=True) as profile: + (out,) = fd.execute( + [inp], + _enable_options=["host_ir_lowering"], + _disable_options=["infer_contiguity"], + ) + broadcast_events = [ + event for event in profile.events() if "ncclDevKernel_Broadcast" in event.name + ] + assert len(broadcast_events) == d + + +@pytest.mark.mpi +@pytest.mark.benchmark +def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark): + # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward. + h, t = 8192, 8192 + d = multidevice_test.size + if (4 * h) % d != 0: + pytest.skip( + f"Column-parallel linear requires {4 * h} to be divisible by world size {d}." + ) + if t % d != 0: + pytest.skip( + f"Column-parallel linear requires {t} to be divisible by world size {d}." + ) + + fd = column_parallel_linear_forward(h, d) + + inp_ref = torch.randn(t, h, dtype=torch.bfloat16, device="cpu") + weight_ref = torch.randn(4 * h, h, dtype=torch.bfloat16, device="cpu") + + inp = multidevice_test.shard_tensor(inp_ref, fd.fusion.inputs()[0]) + weight = multidevice_test.shard_tensor(weight_ref, fd.fusion.inputs()[1]) + + warmup_fn, benchmark_fn = get_benchmark_fns( + lambda: fd.execute( + [inp, weight], + _enable_options=["host_ir_lowering"], + _disable_options=["infer_contiguity"], + ) + ) + warmup_fn() + benchmark.pedantic(benchmark_fn, rounds=5) + + @pytest.mark.mpi @pytest.mark.parametrize("backend_type", [CommunicatorBackend.nccl]) @pytest.mark.parametrize("s", [1, 8]) From 9df94e839d281678dd4b1370622002cbbd3d6d54 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Fri, 6 Feb 2026 13:34:57 -0800 Subject: [PATCH 04/15] remove disable option --- tests/python/multidevice/test_overlap.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index 4fb82bf43fc..bd5b105cb43 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -497,14 +497,8 @@ def test_column_parallel_linear_forward(multidevice_test): _disable_options=["infer_contiguity"], ) torch.testing.assert_close(out, out_ref) - with torch.profiler.profile(record_shapes=True) as profile: - (out,) = fd.execute( - [inp], - _enable_options=["host_ir_lowering"], - _disable_options=["infer_contiguity"], - ) broadcast_events = [ - event for event in profile.events() if "ncclDevKernel_Broadcast" in event.name + event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name ] assert len(broadcast_events) == d @@ -536,7 +530,6 @@ def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark): lambda: fd.execute( [inp, weight], _enable_options=["host_ir_lowering"], - _disable_options=["infer_contiguity"], ) ) warmup_fn() From 
5ee83650e39ab26df42ac86b340cef544f31ce45 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Fri, 6 Feb 2026 15:27:49 -0800 Subject: [PATCH 05/15] comment --- csrc/host_ir/lower_to_communication.cpp | 1 + csrc/multidevice/communication.h | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 960113941c6..34594a02613 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -422,6 +422,7 @@ CommunicationInfo getCommunicationInfo(Expr* e) { // Check if we are going from DIDx -> Stream, which is a ring allgather. // This can be executed as a broadcast or send recvs, which is decided // by the presence of a swizzle in the stream id definition. + // TODO: Lower to SendRecv if swizzle is present. if (c_stream_id != nullptr) { IterDomain* c_stream_logical_id = getLogicalFromLoopId(consumer, c_stream_id); diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 24101deeeb4..5b08de1db8a 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -190,6 +190,10 @@ class P2PCommunication : public Expr { // - the root has one src buffer, and no or one dst buffer // - non-roots have no src buffer and one dst buffer // - all buffers have the same size +// (*) StreamBroadcast +// Shares the same postBroadcast logic with Broadcast. The difference is the +// root is the for-loop index. I kept it separate from Broadcast I do not need +// to inspect the tensorviews later if we have to distinguish the two. // (*) Gather // Copies each device's source buffer to the root's respective src // buffer. The order of the sender devices matches the order of the From 5d604e12e837a50a0ed6a4a9fdffad47cce0d6ce Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Mon, 9 Feb 2026 13:07:58 -0800 Subject: [PATCH 06/15] comment --- csrc/multidevice/communication.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 5b08de1db8a..ab864c3a44a 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -192,8 +192,8 @@ class P2PCommunication : public Expr { // - all buffers have the same size // (*) StreamBroadcast // Shares the same postBroadcast logic with Broadcast. The difference is the -// root is the for-loop index. I kept it separate from Broadcast I do not need -// to inspect the tensorviews later if we have to distinguish the two. +// root is the for-loop index. I kept it separate from Broadcast so I don't need +// to inspect the tensorviews later to distinguish the two. // (*) Gather // Copies each device's source buffer to the root's respective src // buffer. 
The order of the sender devices matches the order of the From 901f809ea6ccd7a10566388c180f07d8db805633 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Mon, 9 Feb 2026 13:10:25 -0800 Subject: [PATCH 07/15] rm disable option --- tests/python/multidevice/test_overlap.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index bd5b105cb43..e5d287ec61f 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -491,11 +491,7 @@ def test_column_parallel_linear_forward(multidevice_test): out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight) with torch.profiler.profile(record_shapes=True) as prof: - (out,) = fd.execute( - [inp, weight], - _enable_options=["host_ir_lowering"], - _disable_options=["infer_contiguity"], - ) + (out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"]) torch.testing.assert_close(out, out_ref) broadcast_events = [ event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name From d21deedb653db55e2c4f7a58918371bcdf91f750 Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Mon, 9 Feb 2026 13:35:40 -0800 Subject: [PATCH 08/15] Update csrc/host_ir/lower_to_communication.cpp Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- csrc/host_ir/lower_to_communication.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 34594a02613..fd0353decc0 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -424,14 +424,10 @@ CommunicationInfo getCommunicationInfo(Expr* e) { // by the presence of a swizzle in the stream id definition. // TODO: Lower to SendRecv if swizzle is present. if (c_stream_id != nullptr) { - IterDomain* c_stream_logical_id = - getLogicalFromLoopId(consumer, c_stream_id); - if (c_stream_logical_id == p2c_map.at(p_logical_id)) { NVF_CHECK( same_mesh, "Broadcast based allgather in stream parallel requires same " - "mesh.") - fill_communication_info( + "mesh."); CommunicationType::StreamBroadcast, p_logical_id, c_stream_logical_id); From 6e862cb705a6f4764375efcc44fa16e93782378a Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Mon, 9 Feb 2026 15:14:29 -0800 Subject: [PATCH 09/15] fix broken greptile update --- csrc/host_ir/lower_to_communication.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index fd0353decc0..c963163caef 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -424,10 +424,14 @@ CommunicationInfo getCommunicationInfo(Expr* e) { // by the presence of a swizzle in the stream id definition. // TODO: Lower to SendRecv if swizzle is present. 
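      // For illustration, the two loop-domain patterns this branch tells
      // apart (a sketch; d is the mesh size, m the shard size):
      //   producer: [iDIDx{d}, i{m}]                  -- allocation sharded on DIDx
      //   consumer: [iStream{d}, i{m}]                -- plain outer split
      //     => broadcast-based ring allgather (StreamBroadcast)
      //   consumer: [iStream{d} via Swizzle1D, i{m}]  -- swizzled stream axis
      //     => send/recv based exchange (the TODO above)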
     if (c_stream_id != nullptr) {
+      IterDomain* c_stream_logical_id =
+          getLogicalFromLoopId(consumer, c_stream_id);
+      if (c_stream_logical_id == p2c_map.at(p_logical_id)) {
         NVF_CHECK(
             same_mesh,
             "Broadcast based allgather in stream parallel requires same "
             "mesh.");
+        fill_communication_info(
             CommunicationType::StreamBroadcast,
             p_logical_id,
             c_stream_logical_id);

From 33ff8fcf5101d3697403b3b9f3ea3e995e166251 Mon Sep 17 00:00:00 2001
From: Priya Mishra <26priya11@gmail.com>
Date: Mon, 9 Feb 2026 15:17:38 -0800
Subject: [PATCH 10/15] remove pytest marker from fusion function

---
 tests/python/multidevice/test_overlap.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/python/multidevice/test_overlap.py
index e5d287ec61f..24f6234edcc 100644
--- a/tests/python/multidevice/test_overlap.py
+++ b/tests/python/multidevice/test_overlap.py
@@ -417,7 +417,6 @@ def test_column_parallel_linear_forward_reference_benchmark(
     benchmark.pedantic(benchmark_fn, rounds=5)
 
 
-@pytest.mark.mpi
 def column_parallel_linear_forward(h: int, d: int):

From be4f66b9dd7a7bb63ce478a0bc4a508c26ae9a17 Mon Sep 17 00:00:00 2001
From: Priya Mishra <26priya11@gmail.com>
Date: Tue, 10 Feb 2026 14:12:09 -0800
Subject: [PATCH 11/15] review

---
 csrc/host_ir/lower_to_communication.cpp           | 13 ++++++-------
 csrc/host_ir/lower_to_communication.h             |  7 ++++++-
 csrc/host_ir/ops.cpp                              |  4 +++-
 csrc/host_ir/pass/convert_op_to_communication.cpp |  2 +-
 4 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/csrc/host_ir/lower_to_communication.cpp
index c963163caef..6bd394f8b7f 100644
--- a/csrc/host_ir/lower_to_communication.cpp
+++ b/csrc/host_ir/lower_to_communication.cpp
@@ -396,9 +396,9 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
   const auto c2p_map = pairwise_map.mapConsumerToProducer();
 
   // This ignores device dimensions on reduction axis.
-  auto producer_pt_to_id =
+  const std::unordered_map<ParallelType, IterDomain*>& producer_pt_to_id =
       mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain());
-  auto consumer_pt_to_id =
+  const std::unordered_map<ParallelType, IterDomain*>& consumer_pt_to_id =
       mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain());
 
   for (ParallelType pt : kParallelTypeDIDs) {
@@ -584,7 +584,7 @@ bool isCommunicationLayoutCompliant(Expr* expr) {
 std::vector<Expr*> convertSingleOpToCommunication(
     Expr* e,
     DeviceIdxType my_device_idx,
-    Val* host_loop_index,
+    Val* root,
     const CommunicatorBackend backend) {
   FusionGuard fg(e->fusion());
 
@@ -656,10 +656,9 @@ std::vector<Expr*> convertSingleOpToCommunication(
       break;
     case CommunicationType::StreamBroadcast:
       NVF_ERROR(
-          host_loop_index != nullptr,
-          "StreamBroadcast requires a host loop index");
-      lowerToStreamBroadcast(
-          input_tv, output_tv, backend, comms, /*root=*/host_loop_index);
+          root != nullptr,
+          "StreamBroadcast requires a root value passed in through lowering");
+      lowerToStreamBroadcast(input_tv, output_tv, backend, comms, root);
       break;
   }
 
diff --git a/csrc/host_ir/lower_to_communication.h
index 4b1c657e0e5..8c789377478 100644
--- a/csrc/host_ir/lower_to_communication.h
+++ b/csrc/host_ir/lower_to_communication.h
@@ -51,10 +51,15 @@ Layout getCommunicationLayout(
     const CommunicationType type,
     IterDomain* sharded_id);
 
+// Creates a communication expr corresponding to the given
+// resharding expr. In most cases, `root` is inferred based
+// on communication type. However, in some cases, e.g.,
+// decomposing allgather as broadcast in a host for-loop, `root`
+// may be passed in through lowering.
 std::vector<Expr*> convertSingleOpToCommunication(
     Expr* c,
     DeviceIdxType my_device_idx,
-    Val* host_loop_index = nullptr,
+    Val* root = nullptr,
     const CommunicatorBackend backend = CommunicatorBackend::kNccl);
 
 } // namespace nvfuser
diff --git a/csrc/host_ir/ops.cpp
index f0c555bf28b..8e6441c5d8a 100644
--- a/csrc/host_ir/ops.cpp
+++ b/csrc/host_ir/ops.cpp
@@ -89,7 +89,9 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) {
           destination, ParallelType::Stream, DomainType::kAllocation) !=
           nullptr,
       "Destination allocation should be sharded on stream after "
-      "shardAllocationAsLoop: ",
+      "shardAllocationAsLoop. ",
+      destination->name(),
+      ":",
       destination->domain()->toString(0, /*loop_only=*/false));
 
   // Refine the contiguity flags so `out` aliases `in`. This is done similar
diff --git a/csrc/host_ir/pass/convert_op_to_communication.cpp
index 07305fb5e09..c8402811230 100644
--- a/csrc/host_ir/pass/convert_op_to_communication.cpp
+++ b/csrc/host_ir/pass/convert_op_to_communication.cpp
@@ -37,7 +37,7 @@ void ConvertOpToCommunication::passImplementation(Fusion* fusion) {
     for (auto* expr : nvfuser::convertSingleOpToCommunication(
              top_level_expr,
             my_device_index,
-             /*host_loop_index=*/nullptr,
+             /*root=*/nullptr,
             params_.communicator_backend)) {
      // Allocate the recv buffers of communications
      if (expr->isA<Communication>()) {

From d1dd491e3329f48cc84db2184ab2a69ab7fe563a Mon Sep 17 00:00:00 2001
From: Priya Mishra <26priya11@gmail.com>
Date: Tue, 10 Feb 2026 16:48:40 -0800
Subject: [PATCH 12/15] test case and python bindings

---
 python/python_direct/ir.cpp                   | 21 +++++++++++++++
 .../python/multidevice/test_communication.py | 26 +++++++++++++++++++
 2 files changed, 47 insertions(+)

diff --git a/python/python_direct/ir.cpp
index 93c032e4ef6..c149663da38 100644
--- a/python/python_direct/ir.cpp
+++ b/python/python_direct/ir.cpp
@@ -501,6 +501,27 @@ Returns
 -------
 TensorView
     A TensorView with the swizzled axes in its loop domain.
 )")
+        .def(
+            "swizzle1d",
+            [](TensorView* self, int64_t x, ParallelType parallel_type) {
+              return self->swizzle1d(x, parallel_type);
+            },
+            py::return_value_policy::reference,
+            py::arg("x"),
+            py::arg("parallel_type"),
+            R"(
+Swizzle the specified axis with the device index corresponding to the given parallel type.
+Parameters
+----------
+x : int
+    The axis to swizzle.
+parallel_type : ParallelType
+    The device parallel type for the 1D swizzle.
+Returns
+-------
+TensorView
+    A TensorView with the swizzled axis in its loop domain.
)") .def( "rfactor", diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index 833ab511ff3..cd2c8f9e550 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -171,3 +171,29 @@ def test_alltoall(multidevice_test, inp_axis, out_axis): inp = multidevice_test.shard_tensor(in_ref, inp_tv) (out,) = fd.execute([inp]) torch.testing.assert_close(out, multidevice_test.shard_tensor(out_ref, out_tv)) + + +def test_collective_permute(multidevice_test): + d = multidevice_test.size + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) + + with FusionDefinition() as fd: + inp_tv = fd.define_tensor((d * 3,), contiguity=True, dtype=DataType.Float) + out_tv = fd.ops.set(inp_tv) + fd.add_output(out_tv) + + inp_tv.set_device_mesh(mesh) + inp_tv.outer_split(0, d) + inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x) + + out_tv.set_device_mesh(mesh) + out_tv.outer_split(0, d) + out_tv.swizzle1d(0, nvfuser.ParallelType.mesh_x) + out_tv.axis(0).parallelize(nvfuser.ParallelType.Stream) + + inp_ref = torch.randn(d * 3) + inp = multidevice_test.shard_tensor(inp_ref, inp_tv) + with torch.profiler.profile() as prof: + (out,) = fd.execute([inp], _enable_options=["host_ir_lowering"]) + print(prof.key_averages()) + torch.testing.assert_close(out.cpu(), inp_ref) From 363a0f458b458c27b7ab920c7a6a4c54c7e8baed Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Wed, 11 Feb 2026 14:05:17 -0800 Subject: [PATCH 13/15] add CollectivePermute wip --- csrc/dispatch.h | 1 + csrc/host_ir/evaluator.cpp | 32 ++++++ csrc/host_ir/evaluator.h | 1 + csrc/host_ir/lower_to_communication.cpp | 47 +++++++- csrc/host_ir/lowering.cpp | 24 +++-- csrc/multidevice/communication.cpp | 100 ++++++++++++++++++ csrc/multidevice/communication.h | 66 ++++++++++++ csrc/multidevice/utils.cpp | 19 ++++ csrc/multidevice/utils.h | 6 ++ .../python/multidevice/test_communication.py | 2 +- 10 files changed, 282 insertions(+), 16 deletions(-) diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 7c53c86d903..6ef541a8037 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -125,6 +125,7 @@ class Val; f(SdpaFwdOp); \ f(SdpaBwdOp); \ f(EmbeddingFwdOp); \ + f(CollectivePermute); \ f(Communication); \ f(P2PCommunication); #define DISPATCH_FOR_ALL_KIR_EXPRS(f) \ diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 2396767d5b0..e299bc9649d 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -310,6 +310,38 @@ void HostIrEvaluator::handle(ShareMemHandles* share_mem_handles) { ipc_handle_cache_.exchangeHandles(share_mem_handles->communications()); } +void HostIrEvaluator::handle(CollectivePermute* communication) { + NVF_ERROR( + communicator_ != nullptr && communicator_->is_available(), + "A valid communicator must be provided"); + + at::Tensor input_tensor = getKnownTensorOrUndefined(communication->input(0)); + at::Tensor output_tensor = + getKnownTensorOrUndefined(communication->output(0)); + +#ifndef NDEBUG + validateSizesAndStrides( + {input_tensor, output_tensor}, + {communication->in(), communication->out()}, + expr_evaluator_); +#endif + + CommunicatorBackend backend_type = communication->backend(); + // CollectivePermute is only supported with NCCL backend because + // UCC does not support coalescing. 
+ NVF_CHECK_EQ(backend_type, CommunicatorBackend::kNccl); + c10d::Backend* backend = + communicator_->getBackendForTeam(communication->team(), backend_type); + works_[communication] = postSingleCommunication( + communication, + communicator_->deviceId(), + backend, + input_tensor, + output_tensor, + expr_evaluator_.evaluate(communication->sendPeer()).as(), + expr_evaluator_.evaluate(communication->recvPeer()).as()); +} + void HostIrEvaluator::handle(Communication* communication) { NVF_ERROR( communicator_ != nullptr && communicator_->is_available(), diff --git a/csrc/host_ir/evaluator.h b/csrc/host_ir/evaluator.h index 22833156cab..9acbb03750a 100644 --- a/csrc/host_ir/evaluator.h +++ b/csrc/host_ir/evaluator.h @@ -96,6 +96,7 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch { void handle(Synchronize*) override; void handle(PostOnStream*) override; void handle(LaunchKernel*) override; + void handle(CollectivePermute*) override; void handle(Communication*) override; void handle(P2PCommunication*) override; void handle(Wait*) override; diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 6bd394f8b7f..a703bea6f91 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -345,6 +345,33 @@ void lowerToAllToAll( backend)); } +void lowerToCollectivePermute( + TensorView* input_tv, + TensorView* output_tv, + const CommunicatorBackend backend, + std::vector& comms, + Val* root, + DeviceIdxType my_device_idx) { + NVF_ERROR_EQ( + input_tv->getDeviceMesh(), + output_tv->getDeviceMesh(), + "CollectivePermute sender and receiver meshes must be the same. Given ", + input_tv->getDeviceMesh(), + " and ", + output_tv->getDeviceMesh()); + + IterDomain* stream_id = + getShardedIterDomain(output_tv, ParallelType::Stream, DomainType::kLoop); + Swizzle1D* swizzle = stream_id->definition()->as(); + ParallelType pt = swizzle->parallelType(); + + const auto& [send_peer, recv_peer] = + dispatchSwizzle1D(root, my_device_idx, pt, input_tv->getDeviceMesh()); + Team team = input_tv->getDeviceMesh().vector(); + comms.push_back(IrBuilder::create( + output_tv, input_tv, team, send_peer, recv_peer, backend)); +} + IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) { std::unordered_set logical_ids = getInputsInTargetDomain({loop_id}, tv->getLogicalDomain()); @@ -422,7 +449,6 @@ CommunicationInfo getCommunicationInfo(Expr* e) { // Check if we are going from DIDx -> Stream, which is a ring allgather. // This can be executed as a broadcast or send recvs, which is decided // by the presence of a swizzle in the stream id definition. - // TODO: Lower to SendRecv if swizzle is present. if (c_stream_id != nullptr) { IterDomain* c_stream_logical_id = getLogicalFromLoopId(consumer, c_stream_id); @@ -431,10 +457,11 @@ CommunicationInfo getCommunicationInfo(Expr* e) { same_mesh, "Broadcast based allgather in stream parallel requires same " "mesh."); - fill_communication_info( - CommunicationType::StreamBroadcast, - p_logical_id, - c_stream_logical_id); + auto* swizzle = dynamic_cast(c_stream_id->definition()); + CommunicationType type = swizzle != nullptr + ? 
CommunicationType::CollectivePermute
+          : CommunicationType::StreamBroadcast;
+      fill_communication_info(type, p_logical_id, c_stream_logical_id);
         continue;
       }
     }
@@ -525,6 +552,7 @@ Layout getCommunicationLayout(
       type == CommunicationType::Allreduce ||
       type == CommunicationType::Broadcast ||
       type == CommunicationType::SendRecv ||
+      type == CommunicationType::CollectivePermute ||
       type == CommunicationType::AllToAll ||
       type == CommunicationType::StreamBroadcast) {
     return layout;
@@ -660,6 +688,15 @@ std::vector<Expr*> convertSingleOpToCommunication(
           "StreamBroadcast requires a root value passed in through lowering");
       lowerToStreamBroadcast(input_tv, output_tv, backend, comms, root);
       break;
+    case CommunicationType::CollectivePermute:
+      // FIXME: Rename `root` to `host_loop_index`; CollectivePermute has no
+      // root. The send and recv peer indices are computed from the host loop
+      // index.
+      NVF_ERROR(
+          root != nullptr,
+          "CollectivePermute requires a root value passed in through lowering");
+      lowerToCollectivePermute(
+          input_tv, output_tv, backend, comms, root, my_device_idx);
+      break;
   }
 
   return comms;
diff --git a/csrc/host_ir/lowering.cpp
index 01f271ea636..0f36d4f65db 100644
--- a/csrc/host_ir/lowering.cpp
+++ b/csrc/host_ir/lowering.cpp
@@ -181,13 +181,19 @@ void lowerSegment(
         for (Expr* c : convertSingleOpToCommunication(
                  e, device_id, innermost.loop->index())) {
           NVF_ERROR(
-              c->isA<Communication>(),
-              "Exprs in a Communication group should be Communication: ",
+              c->isA<Communication>() || c->isA<CollectivePermute>(),
+              "Exprs in a Communication group should be Communication or "
+              "CollectivePermute: ",
               c);
-          auto* communication = c->as<Communication>();
-          TensorView* in = communication->in();
-          TensorView* out = communication->out();
-          if (communication->type() != CommunicationType::StreamBroadcast &&
+          TensorView* in = c->input(0)->as<TensorView>();
+          TensorView* out = c->output(0)->as<TensorView>();
+          bool can_shard_in = true;
+          if (c->isA<CollectivePermute>() ||
+              c->as<Communication>()->type() ==
+                  CommunicationType::StreamBroadcast) {
+            can_shard_in = false;
+          }
+          if (can_shard_in &&
               haveDifferentShardings(
                   in,
                   DomainType::kAllocation,
@@ -196,8 +202,7 @@ void lowerSegment(
                   {ParallelType::Stream})) {
             Val*& sharded_in = replacement_map[in];
             if (sharded_in == nullptr) {
-              sharded_in =
-                  hir::shardByStream(in, innermost.loop->index(), communication);
+              sharded_in = hir::shardByStream(in, innermost.loop->index(), c);
               innermost_scope.pushBack(sharded_in->definition());
             }
           }
@@ -213,8 +218,7 @@ void lowerSegment(
             innermost.parent_scope->insert(
                 innermost.parent_insertion_point, allocate);
             auto [i, inserted] = replacement_map.emplace(
-                out,
-                hir::shardByStream(out, innermost.loop->index(), communication));
+                out, hir::shardByStream(out, innermost.loop->index(), c));
             NVF_ERROR(inserted, "The input segmented fusion should be SSA.");
             innermost_scope.pushBack(i->second->definition());
           } else {
diff --git a/csrc/multidevice/communication.cpp
index 21bc4994628..5b432f8876d 100644
--- a/csrc/multidevice/communication.cpp
+++ b/csrc/multidevice/communication.cpp
@@ -60,6 +60,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) {
     case CommunicationType::StreamBroadcast:
       os << "StreamBroadcast";
       break;
+    case CommunicationType::CollectivePermute:
+      os << "CollectivePermute";
+      break;
   }
   return os;
 }
@@ -158,6 +161,7 @@ bool hasRoot(CommunicationType type) {
     case CommunicationType::Allreduce:
     case CommunicationType::ReduceScatter:
     case CommunicationType::AllToAll:
+    case CommunicationType::CollectivePermute:
      return false;
  }
std::unreachable(); @@ -176,6 +180,7 @@ bool isReduction(CommunicationType type) { case CommunicationType::SendRecv: case CommunicationType::AllToAll: case CommunicationType::StreamBroadcast: + case CommunicationType::CollectivePermute: return false; default: NVF_THROW("unrecognized CommunicationType: ", type); @@ -326,6 +331,47 @@ std::string P2PCommunication::toString(int indent_size) const { return toInlineString(indent_size) + "\n"; } +CollectivePermute::CollectivePermute( + IrBuilderPasskey passkey, + TensorView* out, + TensorView* in, + Team team, + Val* send_peer, + Val* recv_peer, + CommunicatorBackend backend) + : Expr(passkey) { + NVF_ERROR( + in->getDeviceMesh().size() > 0, + "The input mesh size must be greater than 0."); + NVF_ERROR( + out->getDeviceMesh().size() > 0, + "The output mesh size must be greater than 0."); + addInput(in); + addInput(send_peer); + addInput(recv_peer); + addOutput(out); + addDataAttribute(CommunicationType::CollectivePermute); + addDataAttribute(team); + addDataAttribute(backend); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(CollectivePermute) + +std::string CollectivePermute::toInlineString(const int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "CollectivePermute " << name() << " (" + << "team=(" << team() << ")" + << ", send_peer=" << sendPeer()->toInlineString() + << ", recv_peer=" << recvPeer()->toInlineString() + << ", input=" << in() << ", output=" << out() + << ", backend=" << backend() << ")"; + return ss.str(); +} + +std::string CollectivePermute::toString(int indent_size) const { + return toInlineString(indent_size) + "\n"; +} + namespace { c10::intrusive_ptr postBroadcast( Communication* communication, @@ -650,6 +696,28 @@ c10::intrusive_ptr postAllToAll( empty_split_sizes, /*options=*/{}); } + +c10::intrusive_ptr postCollectivePermute( + CollectivePermute* communication, + DeviceIdxType my_device_index, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor) { + backend->startCoalescing(); + std::vector send_tensors = {input_tensor}; + backend->send( + send_tensors, + send_peer_index, + /*tag=*/0); + std::vector recv_tensors = {output_tensor}; + backend->recv( + recv_tensors, + recv_peer_index, + /*tag=*/0); + return backend->endCoalescing(); +} } // namespace c10::intrusive_ptr postSingleCommunication( @@ -746,6 +814,38 @@ c10::intrusive_ptr postSingleCommunication( } } +c10::intrusive_ptr postSingleCommunication( + CollectivePermute* communication, + DeviceIdxType my_device_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index) { + const Team& team = communication->team(); + if (std::find(team.begin(), team.end(), my_device_index) == team.end()) { + return nullptr; + } + NVF_CHECK(backend != nullptr); + + if (isDebugDumpEnabled(DebugDumpOption::Communication)) { + debug() << "Posting " << communication->toInlineString() + << " with input_tensor " << input_tensor.sizes() + << " and output_tensor " << output_tensor.sizes() + << " send_peer=" << send_peer_index + << " recv_peer=" << recv_peer_index << std::endl; + } + + return postCollectivePermute( + communication, + my_device_index, + send_peer_index, + recv_peer_index, + backend, + input_tensor, + output_tensor); +} + namespace { c10::intrusive_ptr postSend( diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index ab864c3a44a..d1bfa6d6cb6 100644 --- 
a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -35,6 +35,7 @@ enum class CommunicationType { SendRecv, AllToAll, StreamBroadcast, + CollectivePermute, }; std::ostream& operator<<(std::ostream& os, const CommunicationType& type); @@ -130,6 +131,62 @@ class Communication : public Expr { void validate(); }; +// CollectivePermute: send to send_peer, recv from recv_peer. Separate from +// Communication (no root, no reduce op). Layout: inputs [in, send_peer, +// recv_peer], output [out], attributes [type, team, backend]. +class CollectivePermute : public Expr { + public: + using Expr::Expr; + + CollectivePermute( + IrBuilderPasskey passkey, + TensorView* out, + TensorView* in, + Team team, + Val* send_peer, + Val* recv_peer, + CommunicatorBackend backend = CommunicatorBackend::kNccl); + + CollectivePermute(const CollectivePermute& other) = delete; + CollectivePermute& operator=(const CollectivePermute& other) = delete; + CollectivePermute(CollectivePermute&& other) = delete; + CollectivePermute& operator=(CollectivePermute&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "CollectivePermute"; + } + + CommunicationType type() const { + return attribute(0); + } + + TensorView* in() const { + return input(0)->as(); + } + TensorView* out() const { + return output(0)->as(); + } + Val* sendPeer() const { + return input(1); + } + Val* recvPeer() const { + return input(2); + } + const Team& team() const { + return attribute(1); + } + int64_t team_size() const { + return static_cast(team().size()); + } + CommunicatorBackend backend() const { + return attribute(2); + } +}; + enum class P2PCommunicationType { SEND, RECV }; std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type); @@ -246,6 +303,15 @@ c10::intrusive_ptr postSingleCommunication( at::Tensor output_tensor, DeviceIdxType root_index = -1); +c10::intrusive_ptr postSingleCommunication( + CollectivePermute* communication, + DeviceIdxType my_device_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index); + c10::intrusive_ptr postSingleCommunication( P2PCommunication* communication, DeviceIdxType my_device_index, diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index beb7283c5a1..1b08456d753 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -14,8 +14,10 @@ #include #include "compute_at_map.h" +#include "ir/builder.h" #include "ir/internal_base_nodes.h" #include "ir/internal_nodes.h" +#include "ops/arith.h" #include "transform_replay.h" #include "type.h" @@ -355,4 +357,21 @@ int64_t getRFactorDeviceDimensionIndex(const TensorView* tv) { return rfactor_did_idx; } +std::pair dispatchSwizzle1D( + Val* host_loop_index, + DeviceIdxType device_id, + ParallelType pt, + const DeviceMesh& mesh) { + int64_t team_size = mesh.size(pt); + at::Tensor md_index = mesh.multiDimensionalIndexOf(device_id); + auto pt_axis = mesh.parallelTypeToAxis(pt); + int64_t team_index = md_index[pt_axis].item(); + Val* team_size_val = IrBuilder::create(team_size, DataType::Index); + Val* team_index_val = IrBuilder::create(team_index, DataType::Index); + return std::make_pair( + mod(add(host_loop_index, team_index_val), team_size_val), + mod(add(team_size_val, sub(team_index_val, host_loop_index)), + 
team_size_val)); +} + } // namespace nvfuser diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index e924e7fcc75..bad7730fee1 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -91,4 +91,10 @@ bool isValidDeviceSplit(Expr* expr); // See tests/python/test_multidevice.py/test_matmul_allreduce_loop_split int64_t getRFactorDeviceDimensionIndex(const TensorView* tv); +std::pair dispatchSwizzle1D( + Val* my_rank, + DeviceIdxType device_id, + ParallelType pt, + const DeviceMesh& mesh); + } // namespace nvfuser diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index cd2c8f9e550..2d4676823f3 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -189,7 +189,7 @@ def test_collective_permute(multidevice_test): out_tv.set_device_mesh(mesh) out_tv.outer_split(0, d) out_tv.swizzle1d(0, nvfuser.ParallelType.mesh_x) - out_tv.axis(0).parallelize(nvfuser.ParallelType.Stream) + out_tv.axis(0).parallelize(nvfuser.ParallelType.stream) inp_ref = torch.randn(d * 3) inp = multidevice_test.shard_tensor(inp_ref, inp_tv) From 4aa8346df7fce0c61a508af21bcd289705be41f3 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Wed, 11 Feb 2026 22:44:06 -0800 Subject: [PATCH 14/15] replay swizzle1d in logical to alloc traversal --- csrc/host_ir/ir.cpp | 10 +++-- csrc/host_ir/lower_to_communication.cpp | 2 +- csrc/host_ir/lowering.cpp | 45 ++++++++++++++++--- csrc/multidevice/communication.cpp | 8 +++- csrc/tensor_metadata.cpp | 26 +++++++++++ .../python/multidevice/test_communication.py | 11 ++++- 6 files changed, 88 insertions(+), 14 deletions(-) diff --git a/csrc/host_ir/ir.cpp b/csrc/host_ir/ir.cpp index 198601355fb..bc16a39931d 100644 --- a/csrc/host_ir/ir.cpp +++ b/csrc/host_ir/ir.cpp @@ -257,9 +257,13 @@ Wait::Wait(IrBuilderPasskey passkey, Expr* expr) this, "must be registered in a HostIrContainer"); NVF_ERROR( - (expr->isOneOf()), - expr, - " must be a Communication, a P2PCommunication, or a EndCoalescing"); + (expr->isOneOf< + Communication, + CollectivePermute, + P2PCommunication, + EndCoalescing>()), + "Got: ", + expr); } NVFUSER_DEFINE_CLONE_AND_CREATE(Wait) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index a703bea6f91..ae12472cce4 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -365,7 +365,7 @@ void lowerToCollectivePermute( Swizzle1D* swizzle = stream_id->definition()->as(); ParallelType pt = swizzle->parallelType(); - const auto& [send_peer, recv_peer] = + const auto& [recv_peer, send_peer] = dispatchSwizzle1D(root, my_device_idx, pt, input_tv->getDeviceMesh()); Team team = input_tv->getDeviceMesh().vector(); comms.push_back(IrBuilder::create( diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 0f36d4f65db..ea803151b15 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -12,8 +12,10 @@ #include "host_ir/ir.h" #include "host_ir/lower_to_communication.h" #include "host_ir/ops.h" +#include "ir/builder.h" #include "ir/iostream.h" #include "ir/utils.h" +#include "kernel_ir.h" #include "multidevice/propagation.h" #include "multidevice/resharding.h" #include "multidevice/utils.h" @@ -148,6 +150,24 @@ Expr* cloneWithNewOperands( return e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes()); } +// If all allocation domain extents of tv are constant, returns a new constant +// Val for the 
total size. Otherwise returns nullptr. Using a constant size +// makes the Allocate independent of the loop index so it is not invalidated +// when the index changes in the evaluator. +Val* getConstantAllocationSizeIfAvailable(TensorView* tv) { + const auto* domain = tv->domain(); + int64_t size = 1; + for (IterDomain* axis : + domain->maybeAllocation() | TensorDomain::kNoReductions) { + Val* extent = axis->extent(); + if (!extent->isConst()) { + return nullptr; + } + size *= extent->evaluate().as(); + } + return IrBuilder::create(size, DataType::Index); +} + void lowerSegment( const SegmentedGroup& group, const AliasInfoMap& aliases, @@ -207,9 +227,14 @@ void lowerSegment( } } - // Allocate the recv buffers of communications - auto* allocate = - IrBuilder::create(out, out->getMemoryType()); + // Allocate the recv buffers of communications. Use a constant size + // when all extents are constant so the Allocate is independent of the + // loop index and not invalidated when it changes. + Val* constant_size = getConstantAllocationSizeIfAvailable(out); + auto* allocate = constant_size != nullptr + ? IrBuilder::create( + out, out->getMemoryType(), constant_size) + : IrBuilder::create(out, out->getMemoryType()); if (getShardedIterDomain( out, ParallelType::Stream, DomainType::kLoop) != nullptr && getShardedIterDomain( @@ -314,8 +339,11 @@ void lowerSegment( if (getShardedIterDomain( out, ParallelType::Stream, DomainType::kAllocation) == nullptr) { - auto* allocate = - IrBuilder::create(out, out->getMemoryType()); + Val* constant_size = getConstantAllocationSizeIfAvailable(out); + auto* allocate = constant_size != nullptr + ? IrBuilder::create( + out, out->getMemoryType(), constant_size) + : IrBuilder::create(out, out->getMemoryType()); innermost.parent_scope->insert( innermost.parent_insertion_point, allocate); // Loop is stream parallelized but allocation is not. Therefore, @@ -351,8 +379,11 @@ void lowerSegment( " must not be an alias, got ", alias); - auto* allocate = - IrBuilder::create(out_tv, out_tv->getMemoryType()); + Val* constant_size = getConstantAllocationSizeIfAvailable(out_tv); + auto* allocate = constant_size != nullptr + ? 
IrBuilder::create( + out_tv, out_tv->getMemoryType(), constant_size) + : IrBuilder::create(out_tv, out_tv->getMemoryType()); innermost_scope.pushBack(allocate); } diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 33fa785157a..78521149cab 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -842,6 +842,11 @@ c10::intrusive_ptr postCollectivePermute( c10d::Backend* backend, at::Tensor input_tensor, at::Tensor output_tensor) { + if (my_device_index == send_peer_index && + my_device_index == recv_peer_index) { + doLocalCopy(output_tensor, input_tensor); + return nullptr; + } backend->startCoalescing(); std::vector send_tensors = {input_tensor}; backend->send( @@ -965,7 +970,8 @@ c10::intrusive_ptr postSingleCommunication( } NVF_CHECK(backend != nullptr); - if (isDebugDumpEnabled(DebugDumpOption::Communication)) { + if (isDebugDumpEnabled(DebugDumpOption::Communication) && + my_device_index == 0) { debug() << "Posting " << communication->toInlineString() << " with input_tensor " << input_tensor.sizes() << " and output_tensor " << output_tensor.sizes() diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp index 676e27e8805..e66b3b8e793 100644 --- a/csrc/tensor_metadata.cpp +++ b/csrc/tensor_metadata.cpp @@ -95,11 +95,24 @@ class ForwardTraverseFromLogicalToAlloc { .second); } + void handle(Swizzle1D* swizzle1d) { + // Swizzle1D does not affect allocation (same size/stride, just reindexing). + auto in = swizzle1d->in(); + auto out = swizzle1d->out(); + auto in_it = active_ids_.find(in); + auto [in_size, in_stride] = in_it->second; + NVF_ERROR(active_ids_.erase(in) == 1); + NVF_ERROR( + active_ids_.emplace(out, std::make_pair(in_size, in_stride)).second); + } + void handle(Expr* expr) { if (auto split = dynamic_cast(expr)) { handle(split); } else if (auto merge = dynamic_cast(expr)) { handle(merge); + } else if (auto swizzle1d = dynamic_cast(expr)) { + handle(swizzle1d); } else { NVF_THROW("Unsupported transormation in allocation domain"); } @@ -190,11 +203,24 @@ class BackwardTraverseFromLogicalToAlloc { .second); } + void handle(Swizzle1D* swizzle1d) { + // Swizzle1D does not affect allocation (same size/stride, just reindexing). 
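+    // For example (hypothetical extents): swizzling an axis of extent 4 as
+    // i -> (i + r) % 4, where r is this device's index along the swizzled
+    // parallel type, only reorders which loop iteration maps to which slice;
+    // each slice's size and stride in memory are unchanged, so the
+    // (size, stride) pair is forwarded from `out` to `in` as-is.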
+    auto in = swizzle1d->in();
+    auto out = swizzle1d->out();
+    auto out_it = active_ids_.find(out);
+    NVF_ERROR(out_it != active_ids_.end());
+    auto [out_size, out_stride] = out_it->second;
+    NVF_ERROR(active_ids_.erase(out) == 1);
+    NVF_ERROR(
+        active_ids_.emplace(in, std::make_pair(out_size, out_stride)).second);
+  }
+
   void handle(Expr* expr) {
     if (auto split = dynamic_cast<Split*>(expr)) {
       handle(split);
     } else if (auto merge = dynamic_cast<Merge*>(expr)) {
       handle(merge);
+    } else if (auto swizzle1d = dynamic_cast<Swizzle1D*>(expr)) {
+      handle(swizzle1d);
     } else {
       NVF_THROW("Unsupported transformation in allocation domain");
     }
diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py
index 2d4676823f3..fa5b88178a6 100644
--- a/tests/python/multidevice/test_communication.py
+++ b/tests/python/multidevice/test_communication.py
@@ -195,5 +195,12 @@ def test_collective_permute(multidevice_test):
     inp = multidevice_test.shard_tensor(inp_ref, inp_tv)
     with torch.profiler.profile() as prof:
         (out,) = fd.execute([inp], _enable_options=["host_ir_lowering"])
-    print(prof.key_averages())
-    torch.testing.assert_close(out.cpu(), inp_ref)
+    if multidevice_test.rank == 0:
+        print("\nOriginal input: ", inp_ref)
+
+        print("\nOutput: ", out)
+    # torch.testing.assert_close(out.cpu(), inp_ref)
+    # collective_permute_events = [
+    #     event for event in prof.events() if "ncclDevKernel_SendRecv" in event.name
+    # ]
+    # assert len(collective_permute_events) == (d - 1)

From 30cd1f25529470e812a5cebe80bbfe271ab200d5 Mon Sep 17 00:00:00 2001
From: Priya Mishra <26priya11@gmail.com>
Date: Thu, 12 Feb 2026 15:26:44 -0800
Subject: [PATCH 15/15] working allgather example

---
 csrc/host_ir/evaluator.cpp                    | 95 +++++++++++++++++++
 csrc/host_ir/lowering.cpp                     | 43 ++-------
 .../python/multidevice/test_communication.py  | 14 +--
 3 files changed, 107 insertions(+), 45 deletions(-)

diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp
index bedf802b41d..99851e21db4 100644
--- a/csrc/host_ir/evaluator.cpp
+++ b/csrc/host_ir/evaluator.cpp
@@ -564,6 +564,101 @@ void HostIrEvaluator::handle(hir::ForLoop* for_loop) {
   auto stop = expr_evaluator_.evaluate(for_loop->stop()).as<int64_t>();
 
   for (auto i = start; i < stop; i++) {
+    // This is not ideal. In lowering, we create the communication expr: the
+    // collective permute takes the output TensorView and the input Vals
+    // send_peer and recv_peer. While the definition of output_tv is not
+    // modified and remains `set`, output_tv is a use of those Vals. Even
+    // though we shardByStream, that use is not rewritten and therefore still
+    // depends on T1.
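+    //
+    // The dump below, taken from a single-device debug run, shows the
+    // resulting dependency chain: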
+    // Cloned e: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})
+    //   = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}),
+    //     cache_op=Streaming )
+
+    // c: CollectivePermute 77 (team=(0),
+    //     send_peer=( ( 1 + ( 0 - i140 ) ) % 1 ),
+    //     recv_peer=( ( i140 + 0 ) % 1 ),
+    //     input=T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}),
+    //     output=T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}),
+    //     backend=NCCL)

+    // %HostIrContainer {
+    //     (T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}))
+    //     -> (T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})) :
+    //   T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) =
+    //     ALLOCATE(buffer=T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}),
+    //       mem_type=global, size=3, zero_init=false, resets_to_zero=false)
+    //   Stream 0x281a6c60 = GetCurrentStream()
+    //   FOR i140 from 0 to 1:
+    //     SetCurrentStream(Stream i140)
+    //     Synchronize(Stream 0x281a6c60)
+    //     T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) =
+    //       ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}),
+    //         stream_index=i140)
+    //     CollectivePermute 82 (team=(0),
+    //         send_peer=( ( 1 + ( 0 - i140 ) ) % 1 ),
+    //         recv_peer=( ( i140 + 0 ) % 1 ),
+    //         input=T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}),
+    //         output=T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}),
+    //         backend=NCCL)
+    //     Wait(Communication 82)
+    //     SetCurrentStream(Stream 0x281a6c60)
+    //   FOR i140 from 0 to 1:
+    //     Synchronize(Stream i140)
+    // } // %HostIrContainer

+    // Invalidating index: i140
+    // allConsumerValsOf(i140)
+    // Visited val: i140
+    // Consumer of i140: i163  definition: i163 = 0 - i140;

+    // Visited val: i163
+    // Consumer of i163: i165  definition: i165 = 1 + i163;

+    // Visited val: i165
+    // Consumer of i165: i167  definition: i167 = i165 % 1;

+    // Visited val: i167
+    // Consumer of i167: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})
+    //   definition: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})
+    //     = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}),
+    //       cache_op=Streaming )

+    // Visited val: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})
+    // Consumer of i167: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0})
+    //   definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) =
+    //     ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}),
+    //       stream_index=i140)

+    // Visited val: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0})
+    // Consumer of i140: i169  definition: i169 = i140 + 0;

+    // Visited val: i169
+    // Consumer of i169: i171  definition: i171 = i169 % 1;

+    // Visited val: i171
+    // Consumer of i171: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})
+    //   definition: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})
+    //     = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}),
+    //       cache_op=Streaming )

+    // Consumer of i171: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0})
+    //   definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) =
+    //     ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}),
+    //       stream_index=i140)

+    // Consumer of i140: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0})
+    //   definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) =
+    //     ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}),
+    //       stream_index=i140)

+    // consumer_vals: 8
+    // Invalidating consumer: i169
+    // Invalidating consumer: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0})
+    // Invalidating consumer: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})
+    // Invalidating consumer: i167
+    // Invalidating consumer: i165
+    // Invalidating consumer: i163
+    // Invalidating consumer: i171
+    // Invalidating consumer: i140

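+    // Rebind the loop index on each iteration: invalidate the index and
+    // every scalar derived from it (e.g. the send/recv peer expressions) so
+    // they are recomputed with the new value. TensorView consumers are
+    // skipped so that tensors bound by Allocate stay bound across
+    // iterations.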
+    expr_evaluator_.invalidate(for_loop->index());
+    for (auto consumer : allConsumerValsOf(for_loop->index())) {
+      if (!consumer->isA<TensorView>()) {
+        expr_evaluator_.invalidate(consumer);
+      }
+    }
     expr_evaluator_.bind(for_loop->index(), i);
     for (Expr* e : for_loop->body().exprs()) {
       dispatch(e);
diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp
index ea803151b15..7c6a18865b1 100644
--- a/csrc/host_ir/lowering.cpp
+++ b/csrc/host_ir/lowering.cpp
@@ -150,24 +150,6 @@ Expr* cloneWithNewOperands(
   return e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes());
 }
 
-// If all allocation domain extents of tv are constant, returns a new constant
-// Val for the total size. Otherwise returns nullptr. Using a constant size
-// makes the Allocate independent of the loop index so it is not invalidated
-// when the index changes in the evaluator.
-Val* getConstantAllocationSizeIfAvailable(TensorView* tv) {
-  const auto* domain = tv->domain();
-  int64_t size = 1;
-  for (IterDomain* axis :
-       domain->maybeAllocation() | TensorDomain::kNoReductions) {
-    Val* extent = axis->extent();
-    if (!extent->isConst()) {
-      return nullptr;
-    }
-    size *= extent->evaluate().as<int64_t>();
-  }
-  return IrBuilder::create<Val>(size, DataType::Index);
-}
-
 void lowerSegment(
     const SegmentedGroup& group,
     const AliasInfoMap& aliases,
@@ -194,6 +176,7 @@ void lowerSegment(
   // If a value is already cloned, IrCloner::clone returns the cloned value
   // without cloning the value again.
   Expr* e = ir_cloner.clone(group.exprs().front());
+  debug() << "Cloned e: " << e << std::endl;
 
   // TODO: `replacement_map` should be associated with the scope so
   // ShardByStream across segments in the same for-loop can be reused.
@@ -227,14 +210,8 @@ void lowerSegment(
         }
       }
 
-      // Allocate the recv buffers of communications. Use a constant size
-      // when all extents are constant so the Allocate is independent of the
-      // loop index and not invalidated when it changes.
-      Val* constant_size = getConstantAllocationSizeIfAvailable(out);
-      auto* allocate = constant_size != nullptr
-          ? IrBuilder::create<kir::Allocate>(
-                out, out->getMemoryType(), constant_size)
-          : IrBuilder::create<kir::Allocate>(out, out->getMemoryType());
+      // Allocate the recv buffers of communications
+      auto* allocate =
+          IrBuilder::create<kir::Allocate>(out, out->getMemoryType());
       if (getShardedIterDomain(
               out, ParallelType::Stream, DomainType::kLoop) != nullptr &&
           getShardedIterDomain(
@@ -339,11 +316,8 @@ void lowerSegment(
       if (getShardedIterDomain(
              out, ParallelType::Stream, DomainType::kAllocation) ==
          nullptr) {
-        Val* constant_size = getConstantAllocationSizeIfAvailable(out);
-        auto* allocate = constant_size != nullptr
-            ? IrBuilder::create<kir::Allocate>(
-                  out, out->getMemoryType(), constant_size)
-            : IrBuilder::create<kir::Allocate>(out, out->getMemoryType());
+        auto* allocate =
+            IrBuilder::create<kir::Allocate>(out, out->getMemoryType());
         innermost.parent_scope->insert(
             innermost.parent_insertion_point, allocate);
         // Loop is stream parallelized but allocation is not. Therefore,
@@ -379,11 +353,8 @@ void lowerSegment(
         " must not be an alias, got ",
         alias);
 
-    Val* constant_size = getConstantAllocationSizeIfAvailable(out_tv);
-    auto* allocate = constant_size != nullptr
-        ? IrBuilder::create<kir::Allocate>(
-              out_tv, out_tv->getMemoryType(), constant_size)
-        : IrBuilder::create<kir::Allocate>(out_tv, out_tv->getMemoryType());
+    auto* allocate =
+        IrBuilder::create<kir::Allocate>(out_tv, out_tv->getMemoryType());
 
     innermost_scope.pushBack(allocate);
   }
 
diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py
index fa5b88178a6..be7d28aef43 100644
--- a/tests/python/multidevice/test_communication.py
+++ b/tests/python/multidevice/test_communication.py
@@ -195,12 +195,8 @@ def test_collective_permute(multidevice_test):
     inp = multidevice_test.shard_tensor(inp_ref, inp_tv)
     with torch.profiler.profile() as prof:
         (out,) = fd.execute([inp], _enable_options=["host_ir_lowering"])
-    if multidevice_test.rank == 0:
-        print("\nOriginal input: ", inp_ref)
-
-        print("\nOutput: ", out)
-    # torch.testing.assert_close(out.cpu(), inp_ref)
-    # collective_permute_events = [
-    #     event for event in prof.events() if "ncclDevKernel_SendRecv" in event.name
-    # ]
-    # assert len(collective_permute_events) == (d - 1)
+    torch.testing.assert_close(out.cpu(), inp_ref)
+    collective_permute_events = [
+        event for event in prof.events() if "ncclDevKernel_SendRecv" in event.name
+    ]
+    assert len(collective_permute_events) == (d - 1)
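+    # Expect d - 1 send/recv kernels: the iteration in which a rank would
+    # send to and receive from itself is lowered to a local copy in
+    # postCollectivePermute, so it launches no NCCL kernel.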