From b2c6b0b9feacb99fdcd9289e1a3b2aac84a96bd2 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Thu, 5 Feb 2026 17:32:08 -0800 Subject: [PATCH 01/15] broadcast based allgather --- csrc/host_ir/lower_to_communication.cpp | 60 +++++++++++++++++-- csrc/host_ir/lowering.cpp | 7 ++- csrc/host_ir/ops.cpp | 2 +- csrc/multidevice/communication.cpp | 20 ++++--- csrc/multidevice/communication.h | 3 +- .../python/multidevice/test_communication.py | 35 +++++++++++ 6 files changed, 112 insertions(+), 15 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index d91fd4eda60..db45fa010d9 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -166,6 +166,32 @@ void lowerToBroadcast( backend)); } +void lowerToStreamBroadcast( + TensorView* input_tv, + TensorView* output_tv, + const CommunicatorBackend backend, + std::vector& comms) { + const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); + const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); + NVF_ERROR_EQ( + sender_mesh, + receiver_mesh, + "StreamBroadcast sender and receiver meshes must be the same. Given ", + sender_mesh, + " and ", + receiver_mesh); + Team team = receiver_mesh.vector(); + comms.push_back(IrBuilder::create( + CommunicationType::StreamBroadcast, + output_tv, + input_tv, + team, + /*root=*/-1, // This will be replaced by HostIrLowering with the for-loop + // index + c10d::ReduceOp::RedOpType::UNUSED, + backend)); +} + // Adds several SendRecv communications to the vector 'comms' // For now, we assume that this function is called only if // the input and output have the same sharding. Later we could support more @@ -370,14 +396,16 @@ CommunicationInfo getCommunicationInfo(Expr* e) { const auto c2p_map = pairwise_map.mapConsumerToProducer(); // This ignores device dimensions on reduction axis. - auto producer_pt_to_did = + auto producer_pt_to_id = mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain()); - auto consumer_pt_to_did = + auto consumer_pt_to_id = mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain()); for (ParallelType pt : kParallelTypeDIDs) { - IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt); - IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt); + IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt); + IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt); + IterDomain* c_stream_id = + getOrDefault(consumer_pt_to_id, ParallelType::Stream); if (p_loop_did == nullptr && c_loop_did == nullptr) { // Not sharded on this parallel type @@ -391,6 +419,24 @@ CommunicationInfo getCommunicationInfo(Expr* e) { if (e->isA()) { if (p_loop_did && !c_loop_did) { IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); + // Check if we are going from DIDx -> Stream, which is a ring allgather. + // This can be executed as a broadcast or send recvs, which is decided + // by the presence of a swizzle in the stream id definition. + if (c_stream_id != nullptr) { + IterDomain* c_stream_logical_id = + getLogicalFromLoopId(consumer, c_stream_id); + if (c_stream_logical_id == p2c_map.at(p_logical_id)) { + NVF_CHECK( + same_mesh, + "Broadcast based allgather in stream parallel requires same " + "mesh.") + fill_communication_info( + CommunicationType::StreamBroadcast, + p_logical_id, + c_stream_logical_id); + continue; + } + } CommunicationType type = same_mesh ? 
CommunicationType::Allgather : CommunicationType::Gather; fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id)); @@ -478,7 +524,8 @@ Layout getCommunicationLayout( type == CommunicationType::Allreduce || type == CommunicationType::Broadcast || type == CommunicationType::SendRecv || - type == CommunicationType::AllToAll) { + type == CommunicationType::AllToAll || + type == CommunicationType::StreamBroadcast) { return layout; } @@ -605,6 +652,9 @@ std::vector convertSingleOpToCommunication( case CommunicationType::AllToAll: lowerToAllToAll(input_tv, output_tv, backend, comms); break; + case CommunicationType::StreamBroadcast: + lowerToStreamBroadcast(input_tv, output_tv, backend, comms); + break; } return comms; diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 3924d5658aa..63010a768c2 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -186,7 +186,8 @@ void lowerSegment( auto* communication = c->as(); TensorView* in = communication->in(); TensorView* out = communication->out(); - if (haveDifferentShardings( + if (communication->type() != CommunicationType::StreamBroadcast && + haveDifferentShardings( in, DomainType::kAllocation, out, @@ -219,6 +220,10 @@ void lowerSegment( innermost_scope.pushBack(allocate); } + if (communication->type() == CommunicationType::StreamBroadcast) { + replacement_map[communication->root()] = innermost.loop->index(); + } + Expr* new_c = cloneWithNewOperands(c, replacement_map); innermost_scope.pushBack(new_c); diff --git a/csrc/host_ir/ops.cpp b/csrc/host_ir/ops.cpp index 05fd42e2764..f0c555bf28b 100644 --- a/csrc/host_ir/ops.cpp +++ b/csrc/host_ir/ops.cpp @@ -90,7 +90,7 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) { nullptr, "Destination allocation should be sharded on stream after " "shardAllocationAsLoop: ", - destination); + destination->domain()->toString(0, /*loop_only=*/false)); // Refine the contiguity flags so `out` aliases `in`. This is done similar // to AliasFinder::handle(const SliceOp*). 
We scan through the allocation diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 6778da9da71..7a4a5ac66a9 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -57,6 +57,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) { case CommunicationType::AllToAll: os << "AllToAll"; break; + case CommunicationType::StreamBroadcast: + os << "StreamBroadcast"; + break; } return os; } @@ -149,6 +152,7 @@ bool hasRoot(CommunicationType type) { case CommunicationType::Reduce: case CommunicationType::Broadcast: case CommunicationType::SendRecv: + case CommunicationType::StreamBroadcast: return true; case CommunicationType::Allgather: case CommunicationType::Allreduce: @@ -171,6 +175,7 @@ bool isReduction(CommunicationType type) { case CommunicationType::Broadcast: case CommunicationType::SendRecv: case CommunicationType::AllToAll: + case CommunicationType::StreamBroadcast: return false; default: NVF_THROW("unrecognized CommunicationType: ", type); @@ -231,13 +236,13 @@ Communication::Communication( void Communication::validate() { if (root()->isConstScalar() && root()->isIntegralScalar()) { - auto root_val = root()->evaluate().as(); - NVF_ERROR( - hasRoot(type()) == (root_val >= 0), - "Root ", - root_val, - " is not expected by CommunicationType ", - type()); + // auto root_val = root()->evaluate().as(); + // NVF_ERROR( + // hasRoot(type()) == (root_val >= 0), + // "Root ", + // root_val, + // " is not expected by CommunicationType ", + // type()); } NVF_ERROR(isReduction(type()) == (reduceOp() != RedOpType::UNUSED)); } @@ -716,6 +721,7 @@ c10::intrusive_ptr postSingleCommunication( input_tensor, output_tensor); case CommunicationType::Broadcast: + case CommunicationType::StreamBroadcast: return postBroadcast( communication, my_device_index, diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 1a7f1a1cc4c..24101deeeb4 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -33,7 +33,8 @@ enum class CommunicationType { ReduceScatter, Broadcast, SendRecv, - AllToAll + AllToAll, + StreamBroadcast, }; std::ostream& operator<<(std::ostream& os, const CommunicationType& type); diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index 833ab511ff3..19f32509012 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -171,3 +171,38 @@ def test_alltoall(multidevice_test, inp_axis, out_axis): inp = multidevice_test.shard_tensor(in_ref, inp_tv) (out,) = fd.execute([inp]) torch.testing.assert_close(out, multidevice_test.shard_tensor(out_ref, out_tv)) + + +@pytest.mark.mpi +def test_broadcast_based_allgather(multidevice_test): + d = multidevice_test.size + + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) + + with FusionDefinition() as fd: + inp_tv = fd.define_tensor((d * 3,), contiguity=True, dtype=DataType.Half) + out_tv = fd.ops.set(inp_tv) + fd.add_output(out_tv) + + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) + inp_tv.set_device_mesh(mesh) + inp_tv.outer_split(0, d) + inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x) + + out_tv.set_device_mesh(mesh) + out_tv.outer_split(0, d) + out_tv.axis(0).parallelize(nvfuser.ParallelType.stream) + + unsharded_inp = torch.randn(d * 3, dtype=torch.float16) + inp = multidevice_test.shard_tensor(unsharded_inp, inp_tv) + with torch.profiler.profile(record_shapes=True) as 
profile: + (out,) = fd.execute( + [inp], + _enable_options=["host_ir_lowering"], + _disable_options=["infer_contiguity"], + ) + broadcast_events = [ + event for event in profile.events() if "ncclDevKernel_Broadcast" in event.name + ] + torch.testing.assert_close(out.cpu(), unsharded_inp) + print(broadcast_events, len(broadcast_events)) From 9fed72f3ac5fe4ab7744e87f50e958b04bc55df2 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Thu, 5 Feb 2026 18:10:15 -0800 Subject: [PATCH 02/15] allow host loop index as a variable --- csrc/host_ir/lower_to_communication.cpp | 13 +++++++++---- csrc/host_ir/lower_to_communication.h | 1 + csrc/host_ir/lowering.cpp | 7 ++----- csrc/host_ir/pass/convert_op_to_communication.cpp | 5 ++++- csrc/multidevice/communication.cpp | 14 +++++++------- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index db45fa010d9..960113941c6 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -170,7 +170,8 @@ void lowerToStreamBroadcast( TensorView* input_tv, TensorView* output_tv, const CommunicatorBackend backend, - std::vector& comms) { + std::vector& comms, + Val* root) { const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); NVF_ERROR_EQ( @@ -186,8 +187,7 @@ void lowerToStreamBroadcast( output_tv, input_tv, team, - /*root=*/-1, // This will be replaced by HostIrLowering with the for-loop - // index + root, c10d::ReduceOp::RedOpType::UNUSED, backend)); } @@ -583,6 +583,7 @@ bool isCommunicationLayoutCompliant(Expr* expr) { std::vector convertSingleOpToCommunication( Expr* e, DeviceIdxType my_device_idx, + Val* host_loop_index, const CommunicatorBackend backend) { FusionGuard fg(e->fusion()); @@ -653,7 +654,11 @@ std::vector convertSingleOpToCommunication( lowerToAllToAll(input_tv, output_tv, backend, comms); break; case CommunicationType::StreamBroadcast: - lowerToStreamBroadcast(input_tv, output_tv, backend, comms); + NVF_ERROR( + host_loop_index != nullptr, + "StreamBroadcast requires a host loop index"); + lowerToStreamBroadcast( + input_tv, output_tv, backend, comms, /*root=*/host_loop_index); break; } diff --git a/csrc/host_ir/lower_to_communication.h b/csrc/host_ir/lower_to_communication.h index 65a9182e605..4b1c657e0e5 100644 --- a/csrc/host_ir/lower_to_communication.h +++ b/csrc/host_ir/lower_to_communication.h @@ -54,6 +54,7 @@ Layout getCommunicationLayout( std::vector convertSingleOpToCommunication( Expr* c, DeviceIdxType my_device_idx, + Val* host_loop_index = nullptr, const CommunicatorBackend backend = CommunicatorBackend::kNccl); } // namespace nvfuser diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 63010a768c2..01f271ea636 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -178,7 +178,8 @@ void lowerSegment( // TODO: `replacement_map` should be associated with the scope so // ShardByStream across segments in the same for-loop can be reused. 
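For intuition, the decomposition this lowering produces can be sketched outside nvFuser with plain torch.distributed: an allgather over d ranks becomes a host loop of d broadcasts, where the loop index doubles as the broadcast root and selects the output slice that iteration fills. This is only a minimal sketch assuming an already-initialized process group; the function and variable names are illustrative, not nvFuser APIs.

    import torch
    import torch.distributed as dist

    def allgather_as_broadcast_loop(local_shard: torch.Tensor) -> torch.Tensor:
        """Gather every rank's shard by looping over broadcast roots."""
        d = dist.get_world_size()
        rank = dist.get_rank()
        out = torch.empty((d, *local_shard.shape),
                          dtype=local_shard.dtype, device=local_shard.device)
        for i in range(d):               # host for-loop; `i` plays the role of the loop index
            slot = out[i]                # slice of the output owned by iteration i
            if i == rank:
                slot.copy_(local_shard)  # the root contributes its own shard
            dist.broadcast(slot, src=i)  # root i broadcasts its shard into slice i
        return out.flatten(0, 1)

With one stream per iteration, each broadcast can overlap with whatever consumes the already-filled slices, which is the motivation for lowering the allgather this way.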
std::unordered_map replacement_map; - for (Expr* c : convertSingleOpToCommunication(e, device_id)) { + for (Expr* c : convertSingleOpToCommunication( + e, device_id, innermost.loop->index())) { NVF_ERROR( c->isA(), "Exprs in a Communication group should be Communication: ", @@ -220,10 +221,6 @@ void lowerSegment( innermost_scope.pushBack(allocate); } - if (communication->type() == CommunicationType::StreamBroadcast) { - replacement_map[communication->root()] = innermost.loop->index(); - } - Expr* new_c = cloneWithNewOperands(c, replacement_map); innermost_scope.pushBack(new_c); diff --git a/csrc/host_ir/pass/convert_op_to_communication.cpp b/csrc/host_ir/pass/convert_op_to_communication.cpp index 4ce4c59b0ce..07305fb5e09 100644 --- a/csrc/host_ir/pass/convert_op_to_communication.cpp +++ b/csrc/host_ir/pass/convert_op_to_communication.cpp @@ -35,7 +35,10 @@ void ConvertOpToCommunication::passImplementation(Fusion* fusion) { return new_top_level_exprs.push_back(top_level_expr); } for (auto* expr : nvfuser::convertSingleOpToCommunication( - top_level_expr, my_device_index, params_.communicator_backend)) { + top_level_expr, + my_device_index, + /*host_loop_index=*/nullptr, + params_.communicator_backend)) { // Allocate the recv buffers of communications if (expr->isA()) { auto* communication = expr->as(); diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 7a4a5ac66a9..21bc4994628 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -236,13 +236,13 @@ Communication::Communication( void Communication::validate() { if (root()->isConstScalar() && root()->isIntegralScalar()) { - // auto root_val = root()->evaluate().as(); - // NVF_ERROR( - // hasRoot(type()) == (root_val >= 0), - // "Root ", - // root_val, - // " is not expected by CommunicationType ", - // type()); + auto root_val = root()->evaluate().as(); + NVF_ERROR( + hasRoot(type()) == (root_val >= 0), + "Root ", + root_val, + " is not expected by CommunicationType ", + type()); } NVF_ERROR(isReduction(type()) == (reduceOp() != RedOpType::UNUSED)); } From 2825e4f4379ecd4186735b6681154c7cfd53256f Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Fri, 6 Feb 2026 13:32:32 -0800 Subject: [PATCH 03/15] benchmark bcast + matmul decomposition --- .../python/multidevice/test_communication.py | 35 ----- tests/python/multidevice/test_overlap.py | 126 ++++++++++++++++++ 2 files changed, 126 insertions(+), 35 deletions(-) diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index 19f32509012..833ab511ff3 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -171,38 +171,3 @@ def test_alltoall(multidevice_test, inp_axis, out_axis): inp = multidevice_test.shard_tensor(in_ref, inp_tv) (out,) = fd.execute([inp]) torch.testing.assert_close(out, multidevice_test.shard_tensor(out_ref, out_tv)) - - -@pytest.mark.mpi -def test_broadcast_based_allgather(multidevice_test): - d = multidevice_test.size - - mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) - - with FusionDefinition() as fd: - inp_tv = fd.define_tensor((d * 3,), contiguity=True, dtype=DataType.Half) - out_tv = fd.ops.set(inp_tv) - fd.add_output(out_tv) - - mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) - inp_tv.set_device_mesh(mesh) - inp_tv.outer_split(0, d) - inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x) - - out_tv.set_device_mesh(mesh) - out_tv.outer_split(0, d) 
- out_tv.axis(0).parallelize(nvfuser.ParallelType.stream) - - unsharded_inp = torch.randn(d * 3, dtype=torch.float16) - inp = multidevice_test.shard_tensor(unsharded_inp, inp_tv) - with torch.profiler.profile(record_shapes=True) as profile: - (out,) = fd.execute( - [inp], - _enable_options=["host_ir_lowering"], - _disable_options=["infer_contiguity"], - ) - broadcast_events = [ - event for event in profile.events() if "ncclDevKernel_Broadcast" in event.name - ] - torch.testing.assert_close(out.cpu(), unsharded_inp) - print(broadcast_events, len(broadcast_events)) diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index c6453e513c5..4fb82bf43fc 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -417,6 +417,132 @@ def test_column_parallel_linear_forward_reference_benchmark( benchmark.pedantic(benchmark_fn, rounds=5) +@pytest.mark.mpi +def column_parallel_linear_forward(h: int, d: int): + with FusionDefinition() as fd: + inp_tv = fd.define_tensor((-1, h), contiguity=True, dtype=DataType.BFloat16) + weight_tv = fd.define_tensor( + (4 * h, h), contiguity=True, dtype=DataType.BFloat16 + ) + ag_out = fd.ops.set(inp_tv) + out_tv = fd.ops.linear(ag_out, weight_tv) + fd.add_output(out_tv) + + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) + + for tv in [inp_tv, weight_tv]: + tv.set_device_mesh(mesh) + tv.outer_split(0, d) + tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x) + + ag_out.set_device_mesh(mesh) + ag_out.outer_split(0, d) + ag_out.axis(0).parallelize(nvfuser.ParallelType.stream) + + # Fusion IR before segmentation will look like this: + # [t, h] + # /\. + # d + # (deviceIdx.x) + # | + # | set (lowered to StreamBroadcast. This decomposition is done manually in the definition above. It will later be done by preseg) + # | + # [t, h] [4h, h] + # /\ /\. + # s d + # (streamIdx) + # | + # | linear + # | + # [t, 4h, r{h}] + # /\ /\. + # s* d + + return fd + + +@pytest.mark.mpi +def test_column_parallel_linear_forward(multidevice_test): + # This is a port of CollectiveBasedOverlapTest.ColumnAndSequenceParallelLinear_Forward. + # The difference is we are using broadcast based overlapping instead of send/recv. + h, t = 2, 24 + d = multidevice_test.size + if (h * 4) % d != 0: + pytest.skip( + f"Row-parallel linear requires {h * 4} to be divisible by world size {d}." + ) + if t % d != 0: + pytest.skip( + f"Column-parallel linear requires {t} to be divisible by world size {d}." 
+ ) + + fd = column_parallel_linear_forward(h, d) + + inp_ref = torch.testing.make_tensor(t, h, dtype=torch.int32, device="cpu").to( + torch.bfloat16 + ) + weight_ref = torch.testing.make_tensor( + 4 * h, h, dtype=torch.int32, device="cpu" + ).to(torch.bfloat16) + + inp = multidevice_test.shard_tensor(inp_ref, fd.fusion.inputs()[0]) + weight = multidevice_test.shard_tensor(weight_ref, fd.fusion.inputs()[1]) + + out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight) + + with torch.profiler.profile(record_shapes=True) as prof: + (out,) = fd.execute( + [inp, weight], + _enable_options=["host_ir_lowering"], + _disable_options=["infer_contiguity"], + ) + torch.testing.assert_close(out, out_ref) + with torch.profiler.profile(record_shapes=True) as profile: + (out,) = fd.execute( + [inp], + _enable_options=["host_ir_lowering"], + _disable_options=["infer_contiguity"], + ) + broadcast_events = [ + event for event in profile.events() if "ncclDevKernel_Broadcast" in event.name + ] + assert len(broadcast_events) == d + + +@pytest.mark.mpi +@pytest.mark.benchmark +def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark): + # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward. + h, t = 8192, 8192 + d = multidevice_test.size + if (4 * h) % d != 0: + pytest.skip( + f"Column-parallel linear requires {4 * h} to be divisible by world size {d}." + ) + if t % d != 0: + pytest.skip( + f"Column-parallel linear requires {t} to be divisible by world size {d}." + ) + + fd = column_parallel_linear_forward(h, d) + + inp_ref = torch.randn(t, h, dtype=torch.bfloat16, device="cpu") + weight_ref = torch.randn(4 * h, h, dtype=torch.bfloat16, device="cpu") + + inp = multidevice_test.shard_tensor(inp_ref, fd.fusion.inputs()[0]) + weight = multidevice_test.shard_tensor(weight_ref, fd.fusion.inputs()[1]) + + warmup_fn, benchmark_fn = get_benchmark_fns( + lambda: fd.execute( + [inp, weight], + _enable_options=["host_ir_lowering"], + _disable_options=["infer_contiguity"], + ) + ) + warmup_fn() + benchmark.pedantic(benchmark_fn, rounds=5) + + @pytest.mark.mpi @pytest.mark.parametrize("backend_type", [CommunicatorBackend.nccl]) @pytest.mark.parametrize("s", [1, 8]) From 9df94e839d281678dd4b1370622002cbbd3d6d54 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Fri, 6 Feb 2026 13:34:57 -0800 Subject: [PATCH 04/15] remove disable option --- tests/python/multidevice/test_overlap.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index 4fb82bf43fc..bd5b105cb43 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -497,14 +497,8 @@ def test_column_parallel_linear_forward(multidevice_test): _disable_options=["infer_contiguity"], ) torch.testing.assert_close(out, out_ref) - with torch.profiler.profile(record_shapes=True) as profile: - (out,) = fd.execute( - [inp], - _enable_options=["host_ir_lowering"], - _disable_options=["infer_contiguity"], - ) broadcast_events = [ - event for event in profile.events() if "ncclDevKernel_Broadcast" in event.name + event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name ] assert len(broadcast_events) == d @@ -536,7 +530,6 @@ def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark): lambda: fd.execute( [inp, weight], _enable_options=["host_ir_lowering"], - _disable_options=["infer_contiguity"], ) ) warmup_fn() From 
5ee83650e39ab26df42ac86b340cef544f31ce45 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Fri, 6 Feb 2026 15:27:49 -0800 Subject: [PATCH 05/15] comment --- csrc/host_ir/lower_to_communication.cpp | 1 + csrc/multidevice/communication.h | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 960113941c6..34594a02613 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -422,6 +422,7 @@ CommunicationInfo getCommunicationInfo(Expr* e) { // Check if we are going from DIDx -> Stream, which is a ring allgather. // This can be executed as a broadcast or send recvs, which is decided // by the presence of a swizzle in the stream id definition. + // TODO: Lower to SendRecv if swizzle is present. if (c_stream_id != nullptr) { IterDomain* c_stream_logical_id = getLogicalFromLoopId(consumer, c_stream_id); diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 24101deeeb4..5b08de1db8a 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -190,6 +190,10 @@ class P2PCommunication : public Expr { // - the root has one src buffer, and no or one dst buffer // - non-roots have no src buffer and one dst buffer // - all buffers have the same size +// (*) StreamBroadcast +// Shares the same postBroadcast logic with Broadcast. The difference is the +// root is the for-loop index. I kept it separate from Broadcast I do not need +// to inspect the tensorviews later if we have to distinguish the two. // (*) Gather // Copies each device's source buffer to the root's respective src // buffer. The order of the sender devices matches the order of the From 5d604e12e837a50a0ed6a4a9fdffad47cce0d6ce Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Mon, 9 Feb 2026 13:07:58 -0800 Subject: [PATCH 06/15] comment --- csrc/multidevice/communication.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 5b08de1db8a..ab864c3a44a 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -192,8 +192,8 @@ class P2PCommunication : public Expr { // - all buffers have the same size // (*) StreamBroadcast // Shares the same postBroadcast logic with Broadcast. The difference is the -// root is the for-loop index. I kept it separate from Broadcast I do not need -// to inspect the tensorviews later if we have to distinguish the two. +// root is the for-loop index. I kept it separate from Broadcast so I don't need +// to inspect the tensorviews later to distinguish the two. // (*) Gather // Copies each device's source buffer to the root's respective src // buffer. 
The order of the sender devices matches the order of the From 901f809ea6ccd7a10566388c180f07d8db805633 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Mon, 9 Feb 2026 13:10:25 -0800 Subject: [PATCH 07/15] rm disable option --- tests/python/multidevice/test_overlap.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index bd5b105cb43..e5d287ec61f 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -491,11 +491,7 @@ def test_column_parallel_linear_forward(multidevice_test): out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight) with torch.profiler.profile(record_shapes=True) as prof: - (out,) = fd.execute( - [inp, weight], - _enable_options=["host_ir_lowering"], - _disable_options=["infer_contiguity"], - ) + (out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"]) torch.testing.assert_close(out, out_ref) broadcast_events = [ event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name From d21deedb653db55e2c4f7a58918371bcdf91f750 Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Mon, 9 Feb 2026 13:35:40 -0800 Subject: [PATCH 08/15] Update csrc/host_ir/lower_to_communication.cpp Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- csrc/host_ir/lower_to_communication.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 34594a02613..fd0353decc0 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -424,14 +424,10 @@ CommunicationInfo getCommunicationInfo(Expr* e) { // by the presence of a swizzle in the stream id definition. // TODO: Lower to SendRecv if swizzle is present. if (c_stream_id != nullptr) { - IterDomain* c_stream_logical_id = - getLogicalFromLoopId(consumer, c_stream_id); - if (c_stream_logical_id == p2c_map.at(p_logical_id)) { NVF_CHECK( same_mesh, "Broadcast based allgather in stream parallel requires same " - "mesh.") - fill_communication_info( + "mesh."); CommunicationType::StreamBroadcast, p_logical_id, c_stream_logical_id); From 6e862cb705a6f4764375efcc44fa16e93782378a Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Mon, 9 Feb 2026 15:14:29 -0800 Subject: [PATCH 09/15] fix broken greptile update --- csrc/host_ir/lower_to_communication.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index fd0353decc0..c963163caef 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -424,10 +424,14 @@ CommunicationInfo getCommunicationInfo(Expr* e) { // by the presence of a swizzle in the stream id definition. // TODO: Lower to SendRecv if swizzle is present. 
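The TODO above refers to a send/recv alternative for the swizzled (ring) case. As a point of comparison only, and not how nvFuser would lower it, a ring allgather built from point-to-point ops could look roughly like the sketch below in plain torch.distributed (process group assumed initialized; names are illustrative):

    import torch
    import torch.distributed as dist

    def ring_allgather(local_shard: torch.Tensor) -> torch.Tensor:
        """Ring allgather: each step forwards one chunk to the right neighbor."""
        d = dist.get_world_size()
        rank = dist.get_rank()
        out = torch.empty((d, *local_shard.shape),
                          dtype=local_shard.dtype, device=local_shard.device)
        out[rank].copy_(local_shard)
        right, left = (rank + 1) % d, (rank - 1) % d
        for step in range(d - 1):
            send_idx = (rank - step) % d       # chunk this rank currently holds
            recv_idx = (rank - step - 1) % d   # chunk arriving from the left neighbor
            ops = [dist.P2POp(dist.isend, out[send_idx], right),
                   dist.P2POp(dist.irecv, out[recv_idx], left)]
            for req in dist.batch_isend_irecv(ops):
                req.wait()
        return out.flatten(0, 1)

Which of the two forms is chosen is exactly what the swizzle check on the stream id's definition is meant to decide.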
if (c_stream_id != nullptr) { + IterDomain* c_stream_logical_id = + getLogicalFromLoopId(consumer, c_stream_id); + if (c_stream_logical_id == p2c_map.at(p_logical_id)) { NVF_CHECK( same_mesh, "Broadcast based allgather in stream parallel requires same " "mesh."); + fill_communication_info( CommunicationType::StreamBroadcast, p_logical_id, c_stream_logical_id); From 33ff8fcf5101d3697403b3b9f3ea3e995e166251 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Mon, 9 Feb 2026 15:17:38 -0800 Subject: [PATCH 10/15] remove pytest marker from fusion function --- tests/python/multidevice/test_overlap.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index e5d287ec61f..24f6234edcc 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -417,7 +417,6 @@ def test_column_parallel_linear_forward_reference_benchmark( benchmark.pedantic(benchmark_fn, rounds=5) -@pytest.mark.mpi def column_parallel_linear_forward(h: int, d: int): with FusionDefinition() as fd: inp_tv = fd.define_tensor((-1, h), contiguity=True, dtype=DataType.BFloat16) From be4f66b9dd7a7bb63ce478a0bc4a508c26ae9a17 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Tue, 10 Feb 2026 14:12:09 -0800 Subject: [PATCH 11/15] review --- csrc/host_ir/lower_to_communication.cpp | 13 ++++++------- csrc/host_ir/lower_to_communication.h | 7 ++++++- csrc/host_ir/ops.cpp | 4 +++- csrc/host_ir/pass/convert_op_to_communication.cpp | 2 +- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index c963163caef..6bd394f8b7f 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -396,9 +396,9 @@ CommunicationInfo getCommunicationInfo(Expr* e) { const auto c2p_map = pairwise_map.mapConsumerToProducer(); // This ignores device dimensions on reduction axis. - auto producer_pt_to_id = + const std::unordered_map& producer_pt_to_id = mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain()); - auto consumer_pt_to_id = + const std::unordered_map& consumer_pt_to_id = mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain()); for (ParallelType pt : kParallelTypeDIDs) { @@ -584,7 +584,7 @@ bool isCommunicationLayoutCompliant(Expr* expr) { std::vector convertSingleOpToCommunication( Expr* e, DeviceIdxType my_device_idx, - Val* host_loop_index, + Val* root, const CommunicatorBackend backend) { FusionGuard fg(e->fusion()); @@ -656,10 +656,9 @@ std::vector convertSingleOpToCommunication( break; case CommunicationType::StreamBroadcast: NVF_ERROR( - host_loop_index != nullptr, - "StreamBroadcast requires a host loop index"); - lowerToStreamBroadcast( - input_tv, output_tv, backend, comms, /*root=*/host_loop_index); + root != nullptr, + "StreamBroadcast requires a root value passed in through lowering"); + lowerToStreamBroadcast(input_tv, output_tv, backend, comms, root); break; } diff --git a/csrc/host_ir/lower_to_communication.h b/csrc/host_ir/lower_to_communication.h index 4b1c657e0e5..8c789377478 100644 --- a/csrc/host_ir/lower_to_communication.h +++ b/csrc/host_ir/lower_to_communication.h @@ -51,10 +51,15 @@ Layout getCommunicationLayout( const CommunicationType type, IterDomain* sharded_id); +// Creates a communication expr corresponding to the given +// resharding expr. In most cases, `root` is inferred based +// on communication type. 
However, in some cases, for e.g. +// decomposing allgather as broadcast in a host for-loop, `root` +// may be passed in through lowering. std::vector convertSingleOpToCommunication( Expr* c, DeviceIdxType my_device_idx, - Val* host_loop_index = nullptr, + Val* root = nullptr, const CommunicatorBackend backend = CommunicatorBackend::kNccl); } // namespace nvfuser diff --git a/csrc/host_ir/ops.cpp b/csrc/host_ir/ops.cpp index f0c555bf28b..8e6441c5d8a 100644 --- a/csrc/host_ir/ops.cpp +++ b/csrc/host_ir/ops.cpp @@ -89,7 +89,9 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) { destination, ParallelType::Stream, DomainType::kAllocation) != nullptr, "Destination allocation should be sharded on stream after " - "shardAllocationAsLoop: ", + "shardAllocationAsLoop. ", + destination->name(), + ":", destination->domain()->toString(0, /*loop_only=*/false)); // Refine the contiguity flags so `out` aliases `in`. This is done similar diff --git a/csrc/host_ir/pass/convert_op_to_communication.cpp b/csrc/host_ir/pass/convert_op_to_communication.cpp index 07305fb5e09..c8402811230 100644 --- a/csrc/host_ir/pass/convert_op_to_communication.cpp +++ b/csrc/host_ir/pass/convert_op_to_communication.cpp @@ -37,7 +37,7 @@ void ConvertOpToCommunication::passImplementation(Fusion* fusion) { for (auto* expr : nvfuser::convertSingleOpToCommunication( top_level_expr, my_device_index, - /*host_loop_index=*/nullptr, + /*root=*/nullptr, params_.communicator_backend)) { // Allocate the recv buffers of communications if (expr->isA()) { From 493c4ad026122c1e0e3dc2a62a4d578959c13847 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Fri, 13 Feb 2026 13:26:26 -0800 Subject: [PATCH 12/15] fix condition for 1 device --- tests/python/multidevice/test_overlap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index 24f6234edcc..696a6412929 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -495,7 +495,7 @@ def test_column_parallel_linear_forward(multidevice_test): broadcast_events = [ event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name ] - assert len(broadcast_events) == d + assert len(broadcast_events) == (d if d > 1 else 0) @pytest.mark.mpi From 4bf2d11d064a0fa089eba9ce195082d0d49691ac Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Tue, 17 Feb 2026 14:29:30 -0800 Subject: [PATCH 13/15] test with removing stale exprs --- csrc/host_ir/evaluator.cpp | 4 ++++ csrc/host_ir/lowering.cpp | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 2396767d5b0..8ea2fd158ab 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -478,6 +478,10 @@ void HostIrEvaluator::handle(hir::ForLoop* for_loop) { auto stop = expr_evaluator_.evaluate(for_loop->stop()).as(); for (auto i = start; i < stop; i++) { + expr_evaluator_.invalidate(for_loop->index()); + for (auto consumer : allConsumerValsOf(for_loop->index())) { + expr_evaluator_.invalidate(consumer); + } expr_evaluator_.bind(for_loop->index(), i); for (Expr* e : for_loop->body().exprs()) { dispatch(e); diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 01f271ea636..58243c859c4 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -145,7 +145,10 @@ Expr* cloneWithNewOperands( return e; } - return 
e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes()); + auto* new_e = + e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes()); + e->container()->removeExpr(e); + return new_e; } void lowerSegment( From 1c23e50997c332ab8deaf425b2a3e72ad728e261 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Tue, 17 Feb 2026 18:10:32 -0800 Subject: [PATCH 14/15] invalidate only allocations for now --- csrc/host_ir/evaluator.cpp | 13 +++++++++++++ csrc/host_ir/lower_to_communication.cpp | 5 +++-- csrc/host_ir/lowering.cpp | 4 ++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 5d52d06e87f..30cbbc3dc0c 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -532,8 +532,21 @@ void HostIrEvaluator::handle(hir::ForLoop* for_loop) { auto stop = expr_evaluator_.evaluate(for_loop->stop()).as(); for (auto i = start; i < stop; i++) { + // Expressions dependent on loop index and all allocations + // inside the loop body should be invalidated. We cannot + // simply use allConsumerValsOf because the loop index can be an input to + // fusion outputs or buffers allocated outside the loop. + std::unordered_set allocations; + for (Expr* e : for_loop->body().exprs()) { + if (auto* alloc = dynamic_cast(e)) { + allocations.insert(alloc->buffer()); + } + } expr_evaluator_.invalidate(for_loop->index()); for (auto consumer : allConsumerValsOf(for_loop->index())) { + if (consumer->isA() && !allocations.contains(consumer)) { + continue; + } expr_evaluator_.invalidate(consumer); } expr_evaluator_.bind(for_loop->index(), i); diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index f393a335c1b..2aea046ff2b 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -402,11 +402,12 @@ CommunicationInfo getCommunicationInfo(Expr* e) { const std::unordered_map& consumer_pt_to_id = mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain()); + IterDomain* c_stream_id = + getOrDefault(consumer_pt_to_id, ParallelType::Stream); + for (ParallelType pt : kParallelTypeDIDs) { IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt); IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt); - IterDomain* c_stream_id = - getOrDefault(consumer_pt_to_id, ParallelType::Stream); if (p_loop_did == nullptr && c_loop_did == nullptr) { // Not sharded on this parallel type diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 58243c859c4..95ea945e895 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -181,8 +181,8 @@ void lowerSegment( // TODO: `replacement_map` should be associated with the scope so // ShardByStream across segments in the same for-loop can be reused. std::unordered_map replacement_map; - for (Expr* c : convertSingleOpToCommunication( - e, device_id, innermost.loop->index())) { + Val* root = loop_nest.empty() ? 
nullptr : innermost.loop->index(); + for (Expr* c : convertSingleOpToCommunication(e, device_id, root)) { NVF_ERROR( c->isA(), "Exprs in a Communication group should be Communication: ", From 6df1d46e2d159fe326d069ae6cd4b955606da9ce Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Thu, 19 Feb 2026 13:51:59 -0800 Subject: [PATCH 15/15] when sharding by stream input maintains its did parallelization --- csrc/host_ir/lowering.cpp | 41 ++++++++--- csrc/host_ir/ops.cpp | 21 ++++-- csrc/multidevice/propagation.cpp | 76 ++++++++++++++++++++ csrc/multidevice/propagation.h | 8 +++ csrc/preseg_passes/decompose_reshardings.cpp | 61 +--------------- 5 files changed, 130 insertions(+), 77 deletions(-) diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 95ea945e895..a08f965c5d8 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -190,18 +190,22 @@ void lowerSegment( auto* communication = c->as(); TensorView* in = communication->in(); TensorView* out = communication->out(); - if (communication->type() != CommunicationType::StreamBroadcast && - haveDifferentShardings( + if (haveDifferentShardings( in, DomainType::kAllocation, out, DomainType::kLoop, {ParallelType::Stream})) { - Val*& sharded_in = replacement_map[in]; - if (sharded_in == nullptr) { - sharded_in = + if (!replacement_map.contains(in)) { + TensorView* sharded_in = hir::shardByStream(in, innermost.loop->index(), communication); - innermost_scope.pushBack(sharded_in->definition()); + if (sharded_in != nullptr) { + // `sharded_in` is nullptr if the input cannot be sharded by + // stream such as in broadcast or collective-permute based + // decomposition of allgather. + replacement_map[in] = sharded_in; + innermost_scope.pushBack(sharded_in->definition()); + } } } @@ -215,11 +219,18 @@ void lowerSegment( nullptr) { innermost.parent_scope->insert( innermost.parent_insertion_point, allocate); - auto [i, inserted] = replacement_map.emplace( - out, - hir::shardByStream(out, innermost.loop->index(), communication)); - NVF_ERROR(inserted, "The input segmented fusion should be SSA."); - innermost_scope.pushBack(i->second->definition()); + NVF_ERROR_EQ( + replacement_map.contains(out), + false, + "The input segmented fusion should be SSA."); + TensorView* sharded_out = + hir::shardByStream(out, innermost.loop->index(), communication); + NVF_ERROR( + sharded_out != nullptr, + "Output could not be sharded by stream: ", + out); + replacement_map[out] = sharded_out; + innermost_scope.pushBack(sharded_out->definition()); } else { innermost_scope.pushBack(allocate); } @@ -301,6 +312,10 @@ void lowerSegment( {ParallelType::Stream})) { TensorView* sharded_in = hir::shardByStream(in, innermost.loop->index(), e); + NVF_ERROR( + sharded_in != nullptr, + "Input could not be sharded by stream: ", + in); replacement_map[in] = sharded_in; innermost_scope.pushBack(sharded_in->definition()); } @@ -321,6 +336,10 @@ void lowerSegment( // `out` should be allocated outside the loop. 
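Spelled out by hand, the pattern produced here, with the full output allocated once outside the host loop and each iteration broadcasting one input shard and writing one stream slice of the output, mirrors the column_parallel_linear_forward test. The sketch below is plain PyTorch under an assumed initialized process group; names are illustrative, and in the lowered host IR each iteration would typically run on its own stream to enable the overlap measured in the benchmarks.

    import torch
    import torch.distributed as dist

    def column_parallel_linear(inp_shard: torch.Tensor,
                               weight_shard: torch.Tensor) -> torch.Tensor:
        """inp_shard: [t/d, h] rows on this rank; weight_shard: [4h/d, h] columns."""
        d = dist.get_world_size()
        rank = dist.get_rank()
        # `out` is allocated once, outside the loop; dim 0 is the stream axis.
        out = torch.empty(d, inp_shard.shape[0], weight_shard.shape[0],
                          dtype=inp_shard.dtype, device=inp_shard.device)
        for i in range(d):
            gathered = inp_shard if i == rank else torch.empty_like(inp_shard)
            dist.broadcast(gathered, src=i)                        # per-iteration broadcast
            torch.matmul(gathered, weight_shard.t(), out=out[i])   # fills the i-th slice
        return out.flatten(0, 1)                                   # [t, 4h/d]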
TensorView* sharded_out = hir::shardByStream(out, innermost.loop->index(), e); + NVF_ERROR( + sharded_out != nullptr, + "Output could not be sharded by stream: ", + out); replacement_map[out] = sharded_out; innermost_scope.pushBack(sharded_out->definition()); } diff --git a/csrc/host_ir/ops.cpp b/csrc/host_ir/ops.cpp index d3fe0cc8740..72642335ca4 100644 --- a/csrc/host_ir/ops.cpp +++ b/csrc/host_ir/ops.cpp @@ -35,13 +35,13 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) { ops::newValLike(source, *source->getDataType())->as(); if (std::ranges::find(e->inputs(), source) != e->inputs().end()) { - // Propagate the allocation domain from `source` to `destination`. - // Consider adding a config to TransformReplay::selfReplay to control what - // to propagate, so we don't have to reset the loop domain. + // Propagate the domain from `source` to `destination`. + // Unparallelize the destination on `ParallelType::Stream` which + // will be inferred based on the output of the expression. TransformReplay::selfReplay(source->domain(), destination->domain()); - destination->setLoopDomain(destination->getLogicalDomain()); + unshard(destination, {ParallelType::Stream}); - // Propagate the loop domain from `e` to `destination`. There are two + // Propagate ParallelType::Stream from `e` to `destination`. There are two // technical challenges: // 1. Loop domains are associated with TensorViews, not Exprs. So we // find e's reference output, `ref_out`, and propagate its loop domain. @@ -58,7 +58,7 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) { shardLoopLike( ref_out, destination, - deviceAndStreamParallelTypes(), + {ParallelType::Stream}, PropagateDirection::kBackward); temp_e->fusion()->removeExpr(temp_e); // Fusion::removeExpr sets all outputs' definitions to nullptr, so we need @@ -68,6 +68,15 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) { for (auto* out : e->outputs()) { out->setDefinition(e); } + + // It is possible that destination's loop domain could not be + // stream-parallelized. This happens when the corresponding id is already + // sharded such as in broadcast or collective-permute based decomposition of + // allgather. + if (getShardedIterDomain( + destination, ParallelType::Stream, DomainType::kLoop) == nullptr) { + return nullptr; + } } else { NVF_ERROR( std::ranges::find(e->outputs(), source) != e->outputs().end(), diff --git a/csrc/multidevice/propagation.cpp b/csrc/multidevice/propagation.cpp index 9bf24df334f..a95e0a68717 100644 --- a/csrc/multidevice/propagation.cpp +++ b/csrc/multidevice/propagation.cpp @@ -340,4 +340,80 @@ void shardLoopLike( transformLoopDomain(target, ref, device_or_stream_ids, direction); } +// Canonicalizes tv's loop domain for simplicity and working around schedulers' +// limitations. Many schedulers panic when seeing the input fusion segment +// contains non-DID loop splits. For example, an rFactor tensor may look like +// the following: +// +// r{k} +// / \. +// [i{m} i{n} iDIDx{d} r{k/d}] +// / \. +// i{d} i{n/d} +// +// The split of i{n} is unnecessary because i{d} and i{n/d} are both +// ParallelType::Serial. This function replaces the two with i{n} in the loop +// domain. 
+void canonicalizeLoopDomain(TensorView* tv) { + LinkedHashMap loop; + for (IterDomain* id : tv->getLoopDomain()) { + loop.pushBack(id, std::monostate()); + } + + for (Expr* transform : + DependencyCheck::getAllExprsBetween( + {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}, + {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}) | + std::views::reverse) { + if (auto* swizzle1d = dynamic_cast(transform)) { + if (swizzle1d->out()->isParallelized()) { + continue; + } + const auto it = loop.erase(swizzle1d->out()).second; + loop.insert(it, swizzle1d->in(), std::monostate()); + continue; + } + if (auto* split = dynamic_cast(transform)) { + NVF_ERROR( + split != nullptr, + "Only splits are expected so far, but found: ", + transform); + + if (split->outer()->isParallelized() || + split->inner()->isParallelized()) { + continue; + } + + if (!loop.contains(split->outer()) || !loop.contains(split->inner())) { + continue; + } + + loop.erase(split->outer()); + const auto inner_i = loop.erase(split->inner()).second; + // `inner_i` is picked arbitrarily as the insertion point. Given `in`, + // `outer` and `inner` are all serial, `in`'s position in the loop domain + // doesn't matter. + loop.insert(inner_i, split->in(), std::monostate()); + continue; + } + NVF_THROW("Expected a split or swizzle1d transform. Got: ", transform); + } + + auto new_loop = std::views::keys(loop); + tv->setLoopDomain({new_loop.begin(), new_loop.end()}); +} + +void unshard( + TensorView* tv, + const std::unordered_set& parallel_types) { + tv->setDeviceMesh(DeviceMesh()); + for (IterDomain* id : tv->getLoopDomain()) { + if (parallel_types.count(id->getParallelType()) == 0) { + continue; + } + id->parallelize(ParallelType::Serial); + } + canonicalizeLoopDomain(tv); +} + } // namespace nvfuser diff --git a/csrc/multidevice/propagation.h b/csrc/multidevice/propagation.h index a7ca9237bc1..74c9b7c1605 100644 --- a/csrc/multidevice/propagation.h +++ b/csrc/multidevice/propagation.h @@ -37,4 +37,12 @@ void shardLoopLike( const std::unordered_set& selected_parallel_types, PropagateDirection direction); +// Canonicalizes the loop domain of the given tensor view. +void canonicalizeLoopDomain(TensorView* tv); + +// Removes the given parallel types and canonicalizes the loop domain. +void unshard( + TensorView* tv, + const std::unordered_set& parallel_types); + } // namespace nvfuser diff --git a/csrc/preseg_passes/decompose_reshardings.cpp b/csrc/preseg_passes/decompose_reshardings.cpp index 6c0e684df5e..6f34d24bd77 100644 --- a/csrc/preseg_passes/decompose_reshardings.cpp +++ b/csrc/preseg_passes/decompose_reshardings.cpp @@ -82,65 +82,6 @@ bool isLowerableToCommunication(Expr* e) { return false; } -// Canonicalizes tv's loop domain for simplicity and working around schedulers' -// limitations. Many schedulers panic when seeing the input fusion segment -// contains non-DID loop splits. For example, an rFactor tensor may look like -// the following: -// -// r{k} -// / \. -// [i{m} i{n} iDIDx{d} r{k/d}] -// / \. -// i{d} i{n/d} -// -// The split of i{n} is unnecessary because i{d} and i{n/d} are both -// ParallelType::Serial. This function replaces the two with i{n} in the loop -// domain. 
-void canonicalizeLoopDomain(TensorView* tv) { - LinkedHashMap loop; - for (IterDomain* id : tv->getLoopDomain()) { - loop.pushBack(id, std::monostate()); - } - - for (Expr* transform : - DependencyCheck::getAllExprsBetween( - {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}, - {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}) | - std::views::reverse) { - auto* split = dynamic_cast(transform); - NVF_ERROR( - split != nullptr, - "Only splits are expected so far, but found: ", - transform); - - if (split->outer()->isParallelized() || split->inner()->isParallelized()) { - continue; - } - - if (!loop.contains(split->outer()) || !loop.contains(split->inner())) { - continue; - } - - loop.erase(split->outer()); - const auto inner_i = loop.erase(split->inner()).second; - // `inner_i` is picked arbitrarily as the insertion point. Given `in`, - // `outer` and `inner` are all serial, `in`'s position in the loop domain - // doesn't matter. - loop.insert(inner_i, split->in(), std::monostate()); - } - - auto new_loop = std::views::keys(loop); - tv->setLoopDomain({new_loop.begin(), new_loop.end()}); -} - -void unshard(TensorView* tv) { - tv->setDeviceMesh(DeviceMesh()); - for (IterDomain* id : tv->getLoopDomain()) { - id->parallelize(ParallelType::Serial); - } - canonicalizeLoopDomain(tv); -} - void insertReshardingSetsBefore(Fusion* fusion) { // Remove this after we refactor this as a pre-segmenter pass. FusionGuard fg(fusion); @@ -246,7 +187,7 @@ void insertReshardingSetsAfter(Fusion* fusion) { // Remove existing shardings from output so we can shard it like // input. `shardLoopLike` does not overwrite existing shardings. - unshard(output); + unshard(output, deviceAndStreamParallelTypes()); shardLoopLike( /*ref=*/resharding_input,