diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp
index e77b8908a82..30cbbc3dc0c 100644
--- a/csrc/host_ir/evaluator.cpp
+++ b/csrc/host_ir/evaluator.cpp
@@ -532,6 +532,23 @@ void HostIrEvaluator::handle(hir::ForLoop* for_loop) {
   auto stop = expr_evaluator_.evaluate(for_loop->stop()).as<int64_t>();
 
   for (auto i = start; i < stop; i++) {
+    // Expressions dependent on the loop index and all allocations
+    // inside the loop body should be invalidated. We cannot simply
+    // use allConsumerValsOf because the loop index can be an input to
+    // fusion outputs or buffers allocated outside the loop.
+    std::unordered_set<Val*> allocations;
+    for (Expr* e : for_loop->body().exprs()) {
+      if (auto* alloc = dynamic_cast<kir::Allocate*>(e)) {
+        allocations.insert(alloc->buffer());
+      }
+    }
+    expr_evaluator_.invalidate(for_loop->index());
+    for (auto consumer : allConsumerValsOf(for_loop->index())) {
+      if (consumer->isA<TensorView>() && !allocations.contains(consumer)) {
+        continue;
+      }
+      expr_evaluator_.invalidate(consumer);
+    }
     expr_evaluator_.bind(for_loop->index(), i);
     for (Expr* e : for_loop->body().exprs()) {
       dispatch(e);
diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp
index 94415e31a80..2aea046ff2b 100644
--- a/csrc/host_ir/lower_to_communication.cpp
+++ b/csrc/host_ir/lower_to_communication.cpp
@@ -167,6 +167,32 @@ void lowerToBroadcast(
       backend));
 }
 
+void lowerToStreamBroadcast(
+    TensorView* input_tv,
+    TensorView* output_tv,
+    const CommunicatorBackend backend,
+    std::vector<Expr*>& comms,
+    Val* root) {
+  const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
+  const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
+  NVF_ERROR_EQ(
+      sender_mesh,
+      receiver_mesh,
+      "StreamBroadcast sender and receiver meshes must be the same. Given ",
+      sender_mesh,
+      " and ",
+      receiver_mesh);
+  Team team = receiver_mesh.vector();
+  comms.push_back(IrBuilder::create<Communication>(
+      CommunicationType::StreamBroadcast,
+      output_tv,
+      input_tv,
+      team,
+      root,
+      c10d::ReduceOp::RedOpType::UNUSED,
+      backend));
+}
+
 // Adds several SendRecv communications to the vector 'comms'
 // For now, we assume that this function is called only if
 // the input and output have the same sharding. Later we could support more
@@ -371,14 +397,17 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
   const auto c2p_map = pairwise_map.mapConsumerToProducer();
 
   // This ignores device dimensions on reduction axis.
-  auto producer_pt_to_did =
+  const std::unordered_map<ParallelType, IterDomain*>& producer_pt_to_id =
       mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain());
-  auto consumer_pt_to_did =
+  const std::unordered_map<ParallelType, IterDomain*>& consumer_pt_to_id =
       mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain());
 
+  IterDomain* c_stream_id =
+      getOrDefault(consumer_pt_to_id, ParallelType::Stream);
+
   for (ParallelType pt : kParallelTypeDIDs) {
-    IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt);
-    IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt);
+    IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt);
+    IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt);
 
     if (p_loop_did == nullptr && c_loop_did == nullptr) {
       // Not sharded on this parallel type
@@ -392,6 +421,25 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
     if (e->isA<LoadStoreOp>()) {
       if (p_loop_did && !c_loop_did) {
         IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
+        // Check if we are going from DIDx -> Stream, which is a ring allgather.
+        // This can be executed as a broadcast or send/recvs, which is decided
+        // by the presence of a swizzle in the stream id definition.
+        // TODO: Lower to SendRecv if swizzle is present.
+        if (c_stream_id != nullptr) {
+          IterDomain* c_stream_logical_id =
+              getLogicalFromLoopId(consumer, c_stream_id);
+          if (c_stream_logical_id == p2c_map.at(p_logical_id)) {
+            NVF_CHECK(
+                same_mesh,
+                "Broadcast based allgather in stream parallel requires same "
+                "mesh.");
+            fill_communication_info(
+                CommunicationType::StreamBroadcast,
+                p_logical_id,
+                c_stream_logical_id);
+            continue;
+          }
+        }
         CommunicationType type =
             same_mesh ? CommunicationType::Allgather : CommunicationType::Gather;
         fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id));
@@ -479,7 +527,8 @@ Layout getCommunicationLayout(
       type == CommunicationType::Allreduce ||
       type == CommunicationType::Broadcast ||
       type == CommunicationType::SendRecv ||
-      type == CommunicationType::AllToAll) {
+      type == CommunicationType::AllToAll ||
+      type == CommunicationType::StreamBroadcast) {
     return layout;
   }
 
@@ -537,6 +586,7 @@ bool isCommunicationLayoutCompliant(Expr* expr) {
 std::vector<Expr*> convertSingleOpToCommunication(
     Expr* e,
     DeviceIdxType my_device_idx,
+    Val* root,
     const CommunicatorBackend backend) {
   FusionGuard fg(e->fusion());
 
@@ -606,6 +656,12 @@ std::vector<Expr*> convertSingleOpToCommunication(
     case CommunicationType::AllToAll:
       lowerToAllToAll(input_tv, output_tv, backend, comms);
       break;
+    case CommunicationType::StreamBroadcast:
+      NVF_ERROR(
+          root != nullptr,
+          "StreamBroadcast requires a root value passed in through lowering");
+      lowerToStreamBroadcast(input_tv, output_tv, backend, comms, root);
+      break;
   }
 
   return comms;
diff --git a/csrc/host_ir/lower_to_communication.h b/csrc/host_ir/lower_to_communication.h
index 65a9182e605..8c789377478 100644
--- a/csrc/host_ir/lower_to_communication.h
+++ b/csrc/host_ir/lower_to_communication.h
@@ -51,9 +51,15 @@ Layout getCommunicationLayout(
     const CommunicationType type,
     IterDomain* sharded_id);
 
+// Creates a communication expr corresponding to the given
+// resharding expr. In most cases, `root` is inferred based
+// on the communication type. However, in some cases, e.g.,
+// decomposing allgather as broadcast in a host for-loop, `root`
+// may be passed in through lowering.
 std::vector<Expr*> convertSingleOpToCommunication(
     Expr* c,
     DeviceIdxType my_device_idx,
+    Val* root = nullptr,
     const CommunicatorBackend backend = CommunicatorBackend::kNccl);
 
 } // namespace nvfuser
diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp
index 3924d5658aa..95ea945e895 100644
--- a/csrc/host_ir/lowering.cpp
+++ b/csrc/host_ir/lowering.cpp
@@ -145,7 +145,10 @@ Expr* cloneWithNewOperands(
     return e;
   }
 
-  return e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes());
+  auto* new_e =
+      e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes());
+  e->container()->removeExpr(e);
+  return new_e;
 }
 
 void lowerSegment(
@@ -178,7 +181,8 @@ void lowerSegment(
       // TODO: `replacement_map` should be associated with the scope so
       // ShardByStream across segments in the same for-loop can be reused.
       std::unordered_map<Val*, Val*> replacement_map;
-      for (Expr* c : convertSingleOpToCommunication(e, device_id)) {
+      Val* root = loop_nest.empty() ? nullptr : innermost.loop->index();
+      for (Expr* c : convertSingleOpToCommunication(e, device_id, root)) {
         NVF_ERROR(
             c->isA<Communication>(),
             "Exprs in a Communication group should be Communication: ",
@@ -186,7 +190,8 @@ void lowerSegment(
       auto* communication = c->as<Communication>();
       TensorView* in = communication->in();
       TensorView* out = communication->out();
-      if (haveDifferentShardings(
+      if (communication->type() != CommunicationType::StreamBroadcast &&
+          haveDifferentShardings(
               in,
               DomainType::kAllocation,
               out,
diff --git a/csrc/host_ir/ops.cpp b/csrc/host_ir/ops.cpp
index a11ffb0e652..d3fe0cc8740 100644
--- a/csrc/host_ir/ops.cpp
+++ b/csrc/host_ir/ops.cpp
@@ -89,8 +89,10 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) {
           destination, ParallelType::Stream, DomainType::kAllocation) !=
           nullptr,
       "Destination allocation should be sharded on stream after "
-      "shardAllocationAsLoop: ",
-      destination);
+      "shardAllocationAsLoop. ",
+      destination->name(),
+      ":",
+      destination->domain()->toString(0, /*loop_only=*/false));
 
   // Refine the contiguity flags so `out` aliases `in`. This is done similar
   // to AliasFinder::handle(const SliceOp*). We scan through the allocation
diff --git a/csrc/host_ir/pass/convert_op_to_communication.cpp b/csrc/host_ir/pass/convert_op_to_communication.cpp
index 4ce4c59b0ce..c8402811230 100644
--- a/csrc/host_ir/pass/convert_op_to_communication.cpp
+++ b/csrc/host_ir/pass/convert_op_to_communication.cpp
@@ -35,7 +35,10 @@ void ConvertOpToCommunication::passImplementation(Fusion* fusion) {
       return new_top_level_exprs.push_back(top_level_expr);
     }
     for (auto* expr : nvfuser::convertSingleOpToCommunication(
-             top_level_expr, my_device_index, params_.communicator_backend)) {
+             top_level_expr,
+             my_device_index,
+             /*root=*/nullptr,
+             params_.communicator_backend)) {
       // Allocate the recv buffers of communications
       if (expr->isA<Communication>()) {
         auto* communication = expr->as<Communication>();
diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp
index 9557374bf6c..539f3a2e94a 100644
--- a/csrc/multidevice/communication.cpp
+++ b/csrc/multidevice/communication.cpp
@@ -57,6 +57,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) {
     case CommunicationType::AllToAll:
       os << "AllToAll";
       break;
+    case CommunicationType::StreamBroadcast:
+      os << "StreamBroadcast";
+      break;
   }
   return os;
 }
@@ -149,6 +152,7 @@ bool hasRoot(CommunicationType type) {
     case CommunicationType::Reduce:
     case CommunicationType::Broadcast:
    case CommunicationType::SendRecv:
+    case CommunicationType::StreamBroadcast:
       return true;
     case CommunicationType::Allgather:
     case CommunicationType::Allreduce:
@@ -171,6 +175,7 @@ bool isReduction(CommunicationType type) {
     case CommunicationType::Broadcast:
     case CommunicationType::SendRecv:
     case CommunicationType::AllToAll:
+    case CommunicationType::StreamBroadcast:
       return false;
     default:
       NVF_THROW("unrecognized CommunicationType: ", type);
@@ -853,6 +858,7 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
           input_tensor,
           output_tensor);
     case CommunicationType::Broadcast:
+    case CommunicationType::StreamBroadcast:
       return postBroadcast(
           communication,
           my_device_index,
diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h
index ce217a01adc..4f74fcf1cbe 100644
--- a/csrc/multidevice/communication.h
+++ b/csrc/multidevice/communication.h
@@ -33,7 +33,8 @@ enum class CommunicationType {
   ReduceScatter,
   Broadcast,
   SendRecv,
-  AllToAll
+  AllToAll,
+  StreamBroadcast,
 };
 
 std::ostream& operator<<(std::ostream& os, const CommunicationType& type);
@@ -347,6 +348,10 @@ class MoeCombine : public Expr {
 // - the root has one src buffer, and no or one dst buffer
 // - non-roots have no src buffer and one dst buffer
 // - all buffers have the same size
+// (*) StreamBroadcast
+// Shares the postBroadcast logic with Broadcast; the difference is that the
+// root is the for-loop index. It is kept separate from Broadcast so the
+// TensorViews don't need to be inspected later to distinguish the two.
 // (*) Gather
 // Copies each device's source buffer to the root's respective src
 // buffer. The order of the sender devices matches the order of the
diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py
index c6453e513c5..696a6412929 100644
--- a/tests/python/multidevice/test_overlap.py
+++ b/tests/python/multidevice/test_overlap.py
@@ -417,6 +417,120 @@ def test_column_parallel_linear_forward_reference_benchmark(
     benchmark.pedantic(benchmark_fn, rounds=5)
 
 
+def column_parallel_linear_forward(h: int, d: int):
+    with FusionDefinition() as fd:
+        inp_tv = fd.define_tensor((-1, h), contiguity=True, dtype=DataType.BFloat16)
+        weight_tv = fd.define_tensor(
+            (4 * h, h), contiguity=True, dtype=DataType.BFloat16
+        )
+        ag_out = fd.ops.set(inp_tv)
+        out_tv = fd.ops.linear(ag_out, weight_tv)
+        fd.add_output(out_tv)
+
+        mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))
+
+        for tv in [inp_tv, weight_tv]:
+            tv.set_device_mesh(mesh)
+            tv.outer_split(0, d)
+            tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x)
+
+        ag_out.set_device_mesh(mesh)
+        ag_out.outer_split(0, d)
+        ag_out.axis(0).parallelize(nvfuser.ParallelType.stream)
+
+    # Fusion IR before segmentation will look like this:
+    #   [t, h]
+    #    /\.
+    #   d
+    #   (deviceIdx.x)
+    #      |
+    #      | set (lowered to StreamBroadcast. This decomposition is done manually in the definition above. It will later be done by preseg)
+    #      |
+    #   [t, h]            [4h, h]
+    #    /\                /\.
+    #   s                 d
+    #   (streamIdx)
+    #      |
+    #      | linear
+    #      |
+    #   [t, 4h, r{h}]
+    #    /\   /\.
+    #   s*    d
+
+    return fd
+
+
+@pytest.mark.mpi
+def test_column_parallel_linear_forward(multidevice_test):
+    # This is a port of CollectiveBasedOverlapTest.ColumnAndSequenceParallelLinear_Forward.
+    # The difference is that we use broadcast-based overlapping instead of send/recv.
+    h, t = 2, 24
+    d = multidevice_test.size
+    if (h * 4) % d != 0:
+        pytest.skip(
+            f"Column-parallel linear requires {h * 4} to be divisible by world size {d}."
+        )
+    if t % d != 0:
+        pytest.skip(
+            f"Column-parallel linear requires {t} to be divisible by world size {d}."
+        )
+
+    fd = column_parallel_linear_forward(h, d)
+
+    inp_ref = torch.testing.make_tensor(t, h, dtype=torch.int32, device="cpu").to(
+        torch.bfloat16
+    )
+    weight_ref = torch.testing.make_tensor(
+        4 * h, h, dtype=torch.int32, device="cpu"
+    ).to(torch.bfloat16)
+
+    inp = multidevice_test.shard_tensor(inp_ref, fd.fusion.inputs()[0])
+    weight = multidevice_test.shard_tensor(weight_ref, fd.fusion.inputs()[1])
+
+    out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight)
+
+    with torch.profiler.profile(record_shapes=True) as prof:
+        (out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"])
+    torch.testing.assert_close(out, out_ref)
+    broadcast_events = [
+        event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name
+    ]
+    assert len(broadcast_events) == (d if d > 1 else 0)
+
+
+@pytest.mark.mpi
+@pytest.mark.benchmark
+def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark):
+    # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
+    h, t = 8192, 8192
+    d = multidevice_test.size
+    if (4 * h) % d != 0:
+        pytest.skip(
+            f"Column-parallel linear requires {4 * h} to be divisible by world size {d}."
+        )
+    if t % d != 0:
+        pytest.skip(
+            f"Column-parallel linear requires {t} to be divisible by world size {d}."
+        )
+
+    fd = column_parallel_linear_forward(h, d)
+
+    inp_ref = torch.randn(t, h, dtype=torch.bfloat16, device="cpu")
+    weight_ref = torch.randn(4 * h, h, dtype=torch.bfloat16, device="cpu")
+
+    inp = multidevice_test.shard_tensor(inp_ref, fd.fusion.inputs()[0])
+    weight = multidevice_test.shard_tensor(weight_ref, fd.fusion.inputs()[1])
+
+    warmup_fn, benchmark_fn = get_benchmark_fns(
+        lambda: fd.execute(
+            [inp, weight],
+            _enable_options=["host_ir_lowering"],
+        )
+    )
+    warmup_fn()
+    benchmark.pedantic(benchmark_fn, rounds=5)
+
+
 @pytest.mark.mpi
 @pytest.mark.parametrize("backend_type", [CommunicatorBackend.nccl])
 @pytest.mark.parametrize("s", [1, 8])
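For context on the decomposition exercised by the new tests: the sketch below is a plain-PyTorch reference, not part of this patch, of what the lowered host for-loop computes. It assumes an initialized `torch.distributed` process group that spans all `d` ranks; the function and variable names (`broadcast_overlapped_linear`, `group`, `weight_shard`) are illustrative placeholders, not APIs introduced by this change. In iteration `j`, rank `j`'s input shard is broadcast (root = loop index) and every rank immediately applies its local column-parallel weight shard; with per-iteration streams in the lowered host IR, the GEMM of iteration `j` can overlap with the broadcast of iteration `j + 1`.

```python
# Illustrative sketch only, not part of the patch. Assumes torch.distributed is
# initialized and `group` spans all d ranks (group rank == global rank).
import torch
import torch.distributed as dist
import torch.nn.functional as F


def broadcast_overlapped_linear(inp_shard, weight_shard, group=None):
    """Reference for the StreamBroadcast-decomposed allgather + column-parallel linear.

    inp_shard:    [t/d, h]   this rank's shard of the input
    weight_shard: [4h/d, h]  this rank's column-parallel weight shard
    returns:      [t, 4h/d]  this rank's shard of the output
    """
    d = dist.get_world_size(group)
    rank = dist.get_rank(group)
    outs = []
    for j in range(d):
        # StreamBroadcast rooted at the loop index: every rank receives
        # rank j's input shard.
        if rank == j:
            buf = inp_shard.contiguous()
        else:
            buf = torch.empty_like(inp_shard)
        dist.broadcast(buf, src=j, group=group)
        # Local GEMM on the received shard with this rank's weight shard.
        outs.append(F.linear(buf, weight_shard))
    # Concatenating over iterations reassembles the stream (t) axis.
    return torch.cat(outs, dim=0)
```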