diff --git a/csrc/dispatch.h b/csrc/dispatch.h
index f6614d5df00..cdd9c9ffd34 100644
--- a/csrc/dispatch.h
+++ b/csrc/dispatch.h
@@ -128,6 +128,7 @@ class Val;
   f(SdpaFwdOp); \
   f(SdpaBwdOp); \
   f(EmbeddingFwdOp); \
+  f(CollectivePermute); \
   f(Communication); \
   f(P2PCommunication);
 #define DISPATCH_FOR_ALL_KIR_EXPRS(f) \
diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp
index e77b8908a82..99851e21db4 100644
--- a/csrc/host_ir/evaluator.cpp
+++ b/csrc/host_ir/evaluator.cpp
@@ -311,6 +311,38 @@ void HostIrEvaluator::handle(ShareMemHandles* share_mem_handles) {
   ipc_handle_cache_.exchangeHandles(share_mem_handles->communications());
 }
 
+void HostIrEvaluator::handle(CollectivePermute* communication) {
+  NVF_ERROR(
+      communicator_ != nullptr && communicator_->is_available(),
+      "A valid communicator must be provided");
+
+  at::Tensor input_tensor = getKnownTensorOrUndefined(communication->input(0));
+  at::Tensor output_tensor =
+      getKnownTensorOrUndefined(communication->output(0));
+
+#ifndef NDEBUG
+  validateSizesAndStrides(
+      {input_tensor, output_tensor},
+      {communication->in(), communication->out()},
+      expr_evaluator_);
+#endif
+
+  CommunicatorBackend backend_type = communication->backend();
+  // CollectivePermute is only supported with the NCCL backend because
+  // UCC does not support coalescing.
+  NVF_CHECK_EQ(backend_type, CommunicatorBackend::kNccl);
+  c10d::Backend* backend =
+      communicator_->getBackendForTeam(communication->team(), backend_type);
+  works_[communication] = postSingleCommunication(
+      communication,
+      communicator_->deviceId(),
+      backend,
+      input_tensor,
+      output_tensor,
+      expr_evaluator_.evaluate(communication->sendPeer()).as(),
+      expr_evaluator_.evaluate(communication->recvPeer()).as());
+}
+
 void HostIrEvaluator::handle(Communication* communication) {
   NVF_ERROR(
       communicator_ != nullptr && communicator_->is_available(),
@@ -532,6 +564,101 @@ void HostIrEvaluator::handle(hir::ForLoop* for_loop) {
   auto stop = expr_evaluator_.evaluate(for_loop->stop()).as();
   for (auto i = start; i < stop; i++) {
+    // This is not ideal. In lowering, we create a communication expr. The
+    // CollectivePermute has the output TensorView as its output and the Vals
+    // send_peer and recv_peer as inputs. While the definition of output_tv
+    // is not modified and remains a `set`, output_tv is a consumer of those
+    // Vals. Even though we shardByStream, that use of the Vals is not
+    // updated and still carries a dependency on T1.
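+    // As a workaround, before binding the loop index for each iteration we
+    // invalidate the index together with the consumer Vals computed from it
+    // (see the invalidation loop below), so the send/recv peer expressions
+    // are re-evaluated with the new index value. The trace that motivated
+    // this is kept below for reference: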
Cloned e: T1_g_float[istreamIdx6{1}, iS5{3}] + // (DeviceMesh{0}) + // = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}), + // cache_op=Streaming ) + + // c: CollectivePermute 77 (team=(0), send_peer=( ( 1 + ( 0 - i140 ) ) % 1 + // ), recv_peer=( ( i140 + 0 ) % 1 ), input=T0_g_float[ideviceIdx.x2{1}, + // iS3{3}] (DeviceMesh{0}), output=T1_g_float[istreamIdx6{1}, iS5{3}] + // (DeviceMesh{0}), backend=NCCL) + + // %HostIrContainer { (T0_g_float[ideviceIdx.x2{1}, iS3{3}] + // (DeviceMesh{0})) -> (T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})) + // : + // T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) = + // ALLOCATE(buffer=T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), + // mem_type=global, size=3, zero_init=false, resets_to_zero=false) Stream + // 0x281a6c60 = GetCurrentStream() FOR i140 from 0 to 1: + // SetCurrentStream(Stream i140) + // Synchronize(Stream 0x281a6c60) + // T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) = + // ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), + // stream_index=i140) CollectivePermute 82 (team=(0), send_peer=( ( 1 + + // ( 0 - i140 ) ) % 1 ), recv_peer=( ( i140 + 0 ) % 1 ), + // input=T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}), + // output=T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}), + // backend=NCCL) Wait(Communication 82) + // SetCurrentStream(Stream 0x281a6c60) + // FOR i140 from 0 to 1: + // Synchronize(Stream i140) + // } // %HostIrContainer + + // Invalidating index: i140 + // allConsumerValsOf(i140) + // Visited val: i140 + // Consumer of i140: i163 definition: i163 = 0 - i140; + + // Visited val: i163 + // Consumer of i163: i165 definition: i165 = 1 + i163; + + // Visited val: i165 + // Consumer of i165: i167 definition: i167 = i165 % 1; + + // Visited val: i167 + // Consumer of i167: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) + // definition: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) + // = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}), + // cache_op=Streaming ) + + // Visited val: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) + // Consumer of i167: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) + // definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) = + // ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), + // stream_index=i140) + + // Visited val: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) + // Consumer of i140: i169 definition: i169 = i140 + 0; + + // Visited val: i169 + // Consumer of i169: i171 definition: i171 = i169 % 1; + + // Visited val: i171 + // Consumer of i171: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) + // definition: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) + // = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}), + // cache_op=Streaming ) + + // Consumer of i171: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) + // definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) = + // ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), + // stream_index=i140) + + // Consumer of i140: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) + // definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) = + // ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}), + // stream_index=i140) + + // consumer_vals: 8 + // Invalidating consumer: i169 + // Invalidating consumer: T2_l_float[istreamIdx10{1}, iS9{3}] + // (DeviceMesh{0}) Invalidating consumer: T1_g_float[istreamIdx6{1}, + // iS5{3}] (DeviceMesh{0}) Invalidating consumer: i167 
Invalidating + // consumer: i165 Invalidating consumer: i163 Invalidating consumer: i171 + // Invalidating consumer: i140 + + expr_evaluator_.invalidate(for_loop->index()); + for (auto consumer : allConsumerValsOf(for_loop->index())) { + if (!consumer->isA()) { + expr_evaluator_.invalidate(consumer); + } + } expr_evaluator_.bind(for_loop->index(), i); for (Expr* e : for_loop->body().exprs()) { dispatch(e); diff --git a/csrc/host_ir/evaluator.h b/csrc/host_ir/evaluator.h index 4a1929ba1bd..a039215ef67 100644 --- a/csrc/host_ir/evaluator.h +++ b/csrc/host_ir/evaluator.h @@ -96,6 +96,7 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch { void handle(Synchronize*) override; void handle(PostOnStream*) override; void handle(LaunchKernel*) override; + void handle(CollectivePermute*) override; void handle(Communication*) override; void handle(P2PCommunication*) override; void handle(MoeDispatch*) override; diff --git a/csrc/host_ir/ir.cpp b/csrc/host_ir/ir.cpp index 198601355fb..bc16a39931d 100644 --- a/csrc/host_ir/ir.cpp +++ b/csrc/host_ir/ir.cpp @@ -257,9 +257,13 @@ Wait::Wait(IrBuilderPasskey passkey, Expr* expr) this, "must be registered in a HostIrContainer"); NVF_ERROR( - (expr->isOneOf()), - expr, - " must be a Communication, a P2PCommunication, or a EndCoalescing"); + (expr->isOneOf< + Communication, + CollectivePermute, + P2PCommunication, + EndCoalescing>()), + "Got: ", + expr); } NVFUSER_DEFINE_CLONE_AND_CREATE(Wait) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index d91fd4eda60..ae12472cce4 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -166,6 +166,32 @@ void lowerToBroadcast( backend)); } +void lowerToStreamBroadcast( + TensorView* input_tv, + TensorView* output_tv, + const CommunicatorBackend backend, + std::vector& comms, + Val* root) { + const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); + const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); + NVF_ERROR_EQ( + sender_mesh, + receiver_mesh, + "StreamBroadcast sender and receiver meshes must be the same. Given ", + sender_mesh, + " and ", + receiver_mesh); + Team team = receiver_mesh.vector(); + comms.push_back(IrBuilder::create( + CommunicationType::StreamBroadcast, + output_tv, + input_tv, + team, + root, + c10d::ReduceOp::RedOpType::UNUSED, + backend)); +} + // Adds several SendRecv communications to the vector 'comms' // For now, we assume that this function is called only if // the input and output have the same sharding. Later we could support more @@ -319,6 +345,33 @@ void lowerToAllToAll( backend)); } +void lowerToCollectivePermute( + TensorView* input_tv, + TensorView* output_tv, + const CommunicatorBackend backend, + std::vector& comms, + Val* root, + DeviceIdxType my_device_idx) { + NVF_ERROR_EQ( + input_tv->getDeviceMesh(), + output_tv->getDeviceMesh(), + "CollectivePermute sender and receiver meshes must be the same. 
Given ", + input_tv->getDeviceMesh(), + " and ", + output_tv->getDeviceMesh()); + + IterDomain* stream_id = + getShardedIterDomain(output_tv, ParallelType::Stream, DomainType::kLoop); + Swizzle1D* swizzle = stream_id->definition()->as(); + ParallelType pt = swizzle->parallelType(); + + const auto& [recv_peer, send_peer] = + dispatchSwizzle1D(root, my_device_idx, pt, input_tv->getDeviceMesh()); + Team team = input_tv->getDeviceMesh().vector(); + comms.push_back(IrBuilder::create( + output_tv, input_tv, team, send_peer, recv_peer, backend)); +} + IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) { std::unordered_set logical_ids = getInputsInTargetDomain({loop_id}, tv->getLogicalDomain()); @@ -370,14 +423,16 @@ CommunicationInfo getCommunicationInfo(Expr* e) { const auto c2p_map = pairwise_map.mapConsumerToProducer(); // This ignores device dimensions on reduction axis. - auto producer_pt_to_did = + const std::unordered_map& producer_pt_to_id = mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain()); - auto consumer_pt_to_did = + const std::unordered_map& consumer_pt_to_id = mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain()); for (ParallelType pt : kParallelTypeDIDs) { - IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt); - IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt); + IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt); + IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt); + IterDomain* c_stream_id = + getOrDefault(consumer_pt_to_id, ParallelType::Stream); if (p_loop_did == nullptr && c_loop_did == nullptr) { // Not sharded on this parallel type @@ -391,6 +446,25 @@ CommunicationInfo getCommunicationInfo(Expr* e) { if (e->isA()) { if (p_loop_did && !c_loop_did) { IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); + // Check if we are going from DIDx -> Stream, which is a ring allgather. + // This can be executed as a broadcast or send recvs, which is decided + // by the presence of a swizzle in the stream id definition. + if (c_stream_id != nullptr) { + IterDomain* c_stream_logical_id = + getLogicalFromLoopId(consumer, c_stream_id); + if (c_stream_logical_id == p2c_map.at(p_logical_id)) { + NVF_CHECK( + same_mesh, + "Broadcast based allgather in stream parallel requires same " + "mesh."); + auto* swizzle = dynamic_cast(c_stream_id->definition()); + CommunicationType type = swizzle != nullptr + ? CommunicationType::CollectivePermute + : CommunicationType::StreamBroadcast; + fill_communication_info(type, p_logical_id, c_stream_logical_id); + continue; + } + } CommunicationType type = same_mesh ? 
CommunicationType::Allgather : CommunicationType::Gather; fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id)); @@ -478,7 +552,9 @@ Layout getCommunicationLayout( type == CommunicationType::Allreduce || type == CommunicationType::Broadcast || type == CommunicationType::SendRecv || - type == CommunicationType::AllToAll) { + type == CommunicationType::CollectivePermute || + type == CommunicationType::AllToAll || + type == CommunicationType::StreamBroadcast) { return layout; } @@ -536,6 +612,7 @@ bool isCommunicationLayoutCompliant(Expr* expr) { std::vector convertSingleOpToCommunication( Expr* e, DeviceIdxType my_device_idx, + Val* root, const CommunicatorBackend backend) { FusionGuard fg(e->fusion()); @@ -605,6 +682,21 @@ std::vector convertSingleOpToCommunication( case CommunicationType::AllToAll: lowerToAllToAll(input_tv, output_tv, backend, comms); break; + case CommunicationType::StreamBroadcast: + NVF_ERROR( + root != nullptr, + "StreamBroadcast requires a root value passed in through lowering"); + lowerToStreamBroadcast(input_tv, output_tv, backend, comms, root); + break; + case CommunicationType::CollectivePermute: + // FIXME: Rename this to host loop index. Collective Permute has no root. + // The send and recv peer indices are computed using the host loop index. + NVF_ERROR( + root != nullptr, + "CollectivePermute requires a root value passed in through lowering"); + lowerToCollectivePermute( + input_tv, output_tv, backend, comms, root, my_device_idx); + break; } return comms; diff --git a/csrc/host_ir/lower_to_communication.h b/csrc/host_ir/lower_to_communication.h index 65a9182e605..8c789377478 100644 --- a/csrc/host_ir/lower_to_communication.h +++ b/csrc/host_ir/lower_to_communication.h @@ -51,9 +51,15 @@ Layout getCommunicationLayout( const CommunicationType type, IterDomain* sharded_id); +// Creates a communication expr corresponding to the given +// resharding expr. In most cases, `root` is inferred based +// on communication type. However, in some cases, for e.g. +// decomposing allgather as broadcast in a host for-loop, `root` +// may be passed in through lowering. std::vector convertSingleOpToCommunication( Expr* c, DeviceIdxType my_device_idx, + Val* root = nullptr, const CommunicatorBackend backend = CommunicatorBackend::kNccl); } // namespace nvfuser diff --git a/csrc/host_ir/lowering.cpp b/csrc/host_ir/lowering.cpp index 3924d5658aa..7c6a18865b1 100644 --- a/csrc/host_ir/lowering.cpp +++ b/csrc/host_ir/lowering.cpp @@ -12,8 +12,10 @@ #include "host_ir/ir.h" #include "host_ir/lower_to_communication.h" #include "host_ir/ops.h" +#include "ir/builder.h" #include "ir/iostream.h" #include "ir/utils.h" +#include "kernel_ir.h" #include "multidevice/propagation.h" #include "multidevice/resharding.h" #include "multidevice/utils.h" @@ -174,19 +176,28 @@ void lowerSegment( // If a value is already cloned, IrCloner::clone returns the cloned value // without cloning the value again. Expr* e = ir_cloner.clone(group.exprs().front()); + debug() << "Cloned e: " << e << std::endl; // TODO: `replacement_map` should be associated with the scope so // ShardByStream across segments in the same for-loop can be reused. 
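+      // replacement_map maps each original TensorView to its stream-sharded
+      // counterpart, so a shard is created at most once per loop body and
+      // reused by subsequent expressions in the same iteration.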
std::unordered_map replacement_map; - for (Expr* c : convertSingleOpToCommunication(e, device_id)) { + for (Expr* c : convertSingleOpToCommunication( + e, device_id, innermost.loop->index())) { NVF_ERROR( - c->isA(), - "Exprs in a Communication group should be Communication: ", + c->isA() || c->isA(), + "Exprs in a Communication group should be Communication or " + "CollectivePermute: ", c); - auto* communication = c->as(); - TensorView* in = communication->in(); - TensorView* out = communication->out(); - if (haveDifferentShardings( + TensorView* in = c->input(0)->as(); + TensorView* out = c->output(0)->as(); + bool can_shard_in = true; + if (c->isA() || + c->as()->type() == + CommunicationType::StreamBroadcast) { + can_shard_in = false; + } + if (can_shard_in && + haveDifferentShardings( in, DomainType::kAllocation, out, @@ -194,13 +205,11 @@ void lowerSegment( {ParallelType::Stream})) { Val*& sharded_in = replacement_map[in]; if (sharded_in == nullptr) { - sharded_in = - hir::shardByStream(in, innermost.loop->index(), communication); + sharded_in = hir::shardByStream(in, innermost.loop->index(), c); innermost_scope.pushBack(sharded_in->definition()); } } - // Allocate the recv buffers of communications auto* allocate = IrBuilder::create(out, out->getMemoryType()); if (getShardedIterDomain( @@ -211,8 +220,7 @@ void lowerSegment( innermost.parent_scope->insert( innermost.parent_insertion_point, allocate); auto [i, inserted] = replacement_map.emplace( - out, - hir::shardByStream(out, innermost.loop->index(), communication)); + out, hir::shardByStream(out, innermost.loop->index(), c)); NVF_ERROR(inserted, "The input segmented fusion should be SSA."); innermost_scope.pushBack(i->second->definition()); } else { diff --git a/csrc/host_ir/ops.cpp b/csrc/host_ir/ops.cpp index a11ffb0e652..d3fe0cc8740 100644 --- a/csrc/host_ir/ops.cpp +++ b/csrc/host_ir/ops.cpp @@ -89,8 +89,10 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) { destination, ParallelType::Stream, DomainType::kAllocation) != nullptr, "Destination allocation should be sharded on stream after " - "shardAllocationAsLoop: ", - destination); + "shardAllocationAsLoop. ", + destination->name(), + ":", + destination->domain()->toString(0, /*loop_only=*/false)); // Refine the contiguity flags so `out` aliases `in`. This is done similar // to AliasFinder::handle(const SliceOp*). 
We scan through the allocation diff --git a/csrc/host_ir/pass/convert_op_to_communication.cpp b/csrc/host_ir/pass/convert_op_to_communication.cpp index 4ce4c59b0ce..c8402811230 100644 --- a/csrc/host_ir/pass/convert_op_to_communication.cpp +++ b/csrc/host_ir/pass/convert_op_to_communication.cpp @@ -35,7 +35,10 @@ void ConvertOpToCommunication::passImplementation(Fusion* fusion) { return new_top_level_exprs.push_back(top_level_expr); } for (auto* expr : nvfuser::convertSingleOpToCommunication( - top_level_expr, my_device_index, params_.communicator_backend)) { + top_level_expr, + my_device_index, + /*root=*/nullptr, + params_.communicator_backend)) { // Allocate the recv buffers of communications if (expr->isA()) { auto* communication = expr->as(); diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 9557374bf6c..78521149cab 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -57,6 +57,12 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) { case CommunicationType::AllToAll: os << "AllToAll"; break; + case CommunicationType::StreamBroadcast: + os << "StreamBroadcast"; + break; + case CommunicationType::CollectivePermute: + os << "CollectivePermute"; + break; } return os; } @@ -149,11 +155,13 @@ bool hasRoot(CommunicationType type) { case CommunicationType::Reduce: case CommunicationType::Broadcast: case CommunicationType::SendRecv: + case CommunicationType::StreamBroadcast: return true; case CommunicationType::Allgather: case CommunicationType::Allreduce: case CommunicationType::ReduceScatter: case CommunicationType::AllToAll: + case CommunicationType::CollectivePermute: return false; } std::unreachable(); @@ -171,6 +179,8 @@ bool isReduction(CommunicationType type) { case CommunicationType::Broadcast: case CommunicationType::SendRecv: case CommunicationType::AllToAll: + case CommunicationType::StreamBroadcast: + case CommunicationType::CollectivePermute: return false; default: NVF_THROW("unrecognized CommunicationType: ", type); @@ -321,6 +331,47 @@ std::string P2PCommunication::toString(int indent_size) const { return toInlineString(indent_size) + "\n"; } +CollectivePermute::CollectivePermute( + IrBuilderPasskey passkey, + TensorView* out, + TensorView* in, + Team team, + Val* send_peer, + Val* recv_peer, + CommunicatorBackend backend) + : Expr(passkey) { + NVF_ERROR( + in->getDeviceMesh().size() > 0, + "The input mesh size must be greater than 0."); + NVF_ERROR( + out->getDeviceMesh().size() > 0, + "The output mesh size must be greater than 0."); + addInput(in); + addInput(send_peer); + addInput(recv_peer); + addOutput(out); + addDataAttribute(CommunicationType::CollectivePermute); + addDataAttribute(team); + addDataAttribute(backend); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(CollectivePermute) + +std::string CollectivePermute::toInlineString(const int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "CollectivePermute " << name() << " (" + << "team=(" << team() << ")" + << ", send_peer=" << sendPeer()->toInlineString() + << ", recv_peer=" << recvPeer()->toInlineString() + << ", input=" << in() << ", output=" << out() + << ", backend=" << backend() << ")"; + return ss.str(); +} + +std::string CollectivePermute::toString(int indent_size) const { + return toInlineString(indent_size) + "\n"; +} + MoeDispatch::MoeDispatch( IrBuilderPasskey passkey, TensorView* out_x, @@ -782,6 +833,33 @@ c10::intrusive_ptr postAllToAll( empty_split_sizes, /*options=*/{}); } + 
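+// Posts a CollectivePermute as a coalesced send/recv pair on the given
+// backend: the local input is sent to send_peer_index and the output is
+// received from recv_peer_index. When this device is both its own send and
+// receive peer, the exchange reduces to a local copy and nothing is posted
+// to the backend.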
+c10::intrusive_ptr postCollectivePermute( + CollectivePermute* communication, + DeviceIdxType my_device_index, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor) { + if (my_device_index == send_peer_index && + my_device_index == recv_peer_index) { + doLocalCopy(output_tensor, input_tensor); + return nullptr; + } + backend->startCoalescing(); + std::vector send_tensors = {input_tensor}; + backend->send( + send_tensors, + send_peer_index, + /*tag=*/0); + std::vector recv_tensors = {output_tensor}; + backend->recv( + recv_tensors, + recv_peer_index, + /*tag=*/0); + return backend->endCoalescing(); +} } // namespace c10::intrusive_ptr postSingleCommunication( @@ -853,6 +931,7 @@ c10::intrusive_ptr postSingleCommunication( input_tensor, output_tensor); case CommunicationType::Broadcast: + case CommunicationType::StreamBroadcast: return postBroadcast( communication, my_device_index, @@ -877,6 +956,39 @@ c10::intrusive_ptr postSingleCommunication( } } +c10::intrusive_ptr postSingleCommunication( + CollectivePermute* communication, + DeviceIdxType my_device_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index) { + const Team& team = communication->team(); + if (std::find(team.begin(), team.end(), my_device_index) == team.end()) { + return nullptr; + } + NVF_CHECK(backend != nullptr); + + if (isDebugDumpEnabled(DebugDumpOption::Communication) && + my_device_index == 0) { + debug() << "Posting " << communication->toInlineString() + << " with input_tensor " << input_tensor.sizes() + << " and output_tensor " << output_tensor.sizes() + << " send_peer=" << send_peer_index + << " recv_peer=" << recv_peer_index << std::endl; + } + + return postCollectivePermute( + communication, + my_device_index, + send_peer_index, + recv_peer_index, + backend, + input_tensor, + output_tensor); +} + namespace { c10::intrusive_ptr postSend( diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index ce217a01adc..9b37c5fadf9 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -33,7 +33,9 @@ enum class CommunicationType { ReduceScatter, Broadcast, SendRecv, - AllToAll + AllToAll, + StreamBroadcast, + CollectivePermute, }; std::ostream& operator<<(std::ostream& os, const CommunicationType& type); @@ -129,6 +131,62 @@ class Communication : public Expr { void validate(); }; +// CollectivePermute: send to send_peer, recv from recv_peer. Separate from +// Communication (no root, no reduce op). Layout: inputs [in, send_peer, +// recv_peer], output [out], attributes [type, team, backend]. 
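+// A CollectivePermute is typically constructed during host-IR lowering, e.g.
+//   IrBuilder::create<CollectivePermute>(
+//       out, in, team, send_peer, recv_peer, backend);
+// see lowerToCollectivePermute() in lower_to_communication.cpp.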
+class CollectivePermute : public Expr { + public: + using Expr::Expr; + + CollectivePermute( + IrBuilderPasskey passkey, + TensorView* out, + TensorView* in, + Team team, + Val* send_peer, + Val* recv_peer, + CommunicatorBackend backend = CommunicatorBackend::kNccl); + + CollectivePermute(const CollectivePermute& other) = delete; + CollectivePermute& operator=(const CollectivePermute& other) = delete; + CollectivePermute(CollectivePermute&& other) = delete; + CollectivePermute& operator=(CollectivePermute&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "CollectivePermute"; + } + + CommunicationType type() const { + return attribute(0); + } + + TensorView* in() const { + return input(0)->as(); + } + TensorView* out() const { + return output(0)->as(); + } + Val* sendPeer() const { + return input(1); + } + Val* recvPeer() const { + return input(2); + } + const Team& team() const { + return attribute(1); + } + int64_t team_size() const { + return static_cast(team().size()); + } + CommunicatorBackend backend() const { + return attribute(2); + } +}; + enum class P2PCommunicationType { SEND, RECV }; std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type); @@ -347,6 +405,10 @@ class MoeCombine : public Expr { // - the root has one src buffer, and no or one dst buffer // - non-roots have no src buffer and one dst buffer // - all buffers have the same size +// (*) StreamBroadcast +// Shares the same postBroadcast logic with Broadcast. The difference is the +// root is the for-loop index. I kept it separate from Broadcast so I don't need +// to inspect the tensorviews later to distinguish the two. // (*) Gather // Copies each device's source buffer to the root's respective src // buffer. 
The order of the sender devices matches the order of the @@ -399,6 +461,15 @@ c10::intrusive_ptr postSingleCommunication( at::Tensor output_tensor, DeviceIdxType root_index = -1); +c10::intrusive_ptr postSingleCommunication( + CollectivePermute* communication, + DeviceIdxType my_device_index, + c10d::Backend* backend, + at::Tensor input_tensor, + at::Tensor output_tensor, + DeviceIdxType send_peer_index, + DeviceIdxType recv_peer_index); + c10::intrusive_ptr postSingleCommunication( P2PCommunication* communication, DeviceIdxType my_device_index, diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index beb7283c5a1..1b08456d753 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -14,8 +14,10 @@ #include #include "compute_at_map.h" +#include "ir/builder.h" #include "ir/internal_base_nodes.h" #include "ir/internal_nodes.h" +#include "ops/arith.h" #include "transform_replay.h" #include "type.h" @@ -355,4 +357,21 @@ int64_t getRFactorDeviceDimensionIndex(const TensorView* tv) { return rfactor_did_idx; } +std::pair dispatchSwizzle1D( + Val* host_loop_index, + DeviceIdxType device_id, + ParallelType pt, + const DeviceMesh& mesh) { + int64_t team_size = mesh.size(pt); + at::Tensor md_index = mesh.multiDimensionalIndexOf(device_id); + auto pt_axis = mesh.parallelTypeToAxis(pt); + int64_t team_index = md_index[pt_axis].item(); + Val* team_size_val = IrBuilder::create(team_size, DataType::Index); + Val* team_index_val = IrBuilder::create(team_index, DataType::Index); + return std::make_pair( + mod(add(host_loop_index, team_index_val), team_size_val), + mod(add(team_size_val, sub(team_index_val, host_loop_index)), + team_size_val)); +} + } // namespace nvfuser diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index e924e7fcc75..bad7730fee1 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -91,4 +91,10 @@ bool isValidDeviceSplit(Expr* expr); // See tests/python/test_multidevice.py/test_matmul_allreduce_loop_split int64_t getRFactorDeviceDimensionIndex(const TensorView* tv); +std::pair dispatchSwizzle1D( + Val* my_rank, + DeviceIdxType device_id, + ParallelType pt, + const DeviceMesh& mesh); + } // namespace nvfuser diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp index 676e27e8805..e66b3b8e793 100644 --- a/csrc/tensor_metadata.cpp +++ b/csrc/tensor_metadata.cpp @@ -95,11 +95,24 @@ class ForwardTraverseFromLogicalToAlloc { .second); } + void handle(Swizzle1D* swizzle1d) { + // Swizzle1D does not affect allocation (same size/stride, just reindexing). + auto in = swizzle1d->in(); + auto out = swizzle1d->out(); + auto in_it = active_ids_.find(in); + auto [in_size, in_stride] = in_it->second; + NVF_ERROR(active_ids_.erase(in) == 1); + NVF_ERROR( + active_ids_.emplace(out, std::make_pair(in_size, in_stride)).second); + } + void handle(Expr* expr) { if (auto split = dynamic_cast(expr)) { handle(split); } else if (auto merge = dynamic_cast(expr)) { handle(merge); + } else if (auto swizzle1d = dynamic_cast(expr)) { + handle(swizzle1d); } else { NVF_THROW("Unsupported transormation in allocation domain"); } @@ -190,11 +203,24 @@ class BackwardTraverseFromLogicalToAlloc { .second); } + void handle(Swizzle1D* swizzle1d) { + // Swizzle1D does not affect allocation (same size/stride, just reindexing). 
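+    // Backward direction: transfer the (size, stride) recorded for the
+    // swizzle output back onto its input, mirroring the forward handler.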
+ auto in = swizzle1d->in(); + auto out = swizzle1d->out(); + auto out_it = active_ids_.find(out); + auto [out_size, out_stride] = out_it->second; + NVF_ERROR(active_ids_.erase(out) == 1); + NVF_ERROR( + active_ids_.emplace(in, std::make_pair(out_size, out_stride)).second); + } + void handle(Expr* expr) { if (auto split = dynamic_cast(expr)) { handle(split); } else if (auto merge = dynamic_cast(expr)) { handle(merge); + } else if (auto swizzle1d = dynamic_cast(expr)) { + handle(swizzle1d); } else { NVF_THROW("Unsupported transormation in allocation domain"); } diff --git a/python/python_direct/ir.cpp b/python/python_direct/ir.cpp index 93c032e4ef6..c149663da38 100644 --- a/python/python_direct/ir.cpp +++ b/python/python_direct/ir.cpp @@ -501,6 +501,27 @@ Returns ------- TensorView A TensorView with the swizzled axes in its loop domain. +)") + .def( + "swizzle1d", + [](TensorView* self, int64_t x, ParallelType parallel_type) { + return self->swizzle1d(x, parallel_type); + }, + py::return_value_policy::reference, + py::arg("x"), + py::arg("parallel_type"), + R"( +Swizzle the specified axis with the device index corresponding to the given parallel type. +Parameters +---------- +x : int +The axis to swizzle. +parallel_type : ParallelType +The device parallel type for the 1D swizzle. +Returns +------- +TensorView +A TensorView with the swizzled axis in its loop domain. )") .def( "rfactor", diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index 833ab511ff3..be7d28aef43 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -171,3 +171,32 @@ def test_alltoall(multidevice_test, inp_axis, out_axis): inp = multidevice_test.shard_tensor(in_ref, inp_tv) (out,) = fd.execute([inp]) torch.testing.assert_close(out, multidevice_test.shard_tensor(out_ref, out_tv)) + + +def test_collective_permute(multidevice_test): + d = multidevice_test.size + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) + + with FusionDefinition() as fd: + inp_tv = fd.define_tensor((d * 3,), contiguity=True, dtype=DataType.Float) + out_tv = fd.ops.set(inp_tv) + fd.add_output(out_tv) + + inp_tv.set_device_mesh(mesh) + inp_tv.outer_split(0, d) + inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x) + + out_tv.set_device_mesh(mesh) + out_tv.outer_split(0, d) + out_tv.swizzle1d(0, nvfuser.ParallelType.mesh_x) + out_tv.axis(0).parallelize(nvfuser.ParallelType.stream) + + inp_ref = torch.randn(d * 3) + inp = multidevice_test.shard_tensor(inp_ref, inp_tv) + with torch.profiler.profile() as prof: + (out,) = fd.execute([inp], _enable_options=["host_ir_lowering"]) + torch.testing.assert_close(out.cpu(), inp_ref) + collective_permute_events = [ + event for event in prof.events() if "ncclDevKernel_SendRecv" in event.name + ] + assert len(collective_permute_events) == (d - 1) diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index c6453e513c5..24f6234edcc 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -417,6 +417,120 @@ def test_column_parallel_linear_forward_reference_benchmark( benchmark.pedantic(benchmark_fn, rounds=5) +def column_parallel_linear_forward(h: int, d: int): + with FusionDefinition() as fd: + inp_tv = fd.define_tensor((-1, h), contiguity=True, dtype=DataType.BFloat16) + weight_tv = fd.define_tensor( + (4 * h, h), contiguity=True, dtype=DataType.BFloat16 + ) + ag_out = fd.ops.set(inp_tv) + out_tv = 
fd.ops.linear(ag_out, weight_tv) + fd.add_output(out_tv) + + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d)) + + for tv in [inp_tv, weight_tv]: + tv.set_device_mesh(mesh) + tv.outer_split(0, d) + tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x) + + ag_out.set_device_mesh(mesh) + ag_out.outer_split(0, d) + ag_out.axis(0).parallelize(nvfuser.ParallelType.stream) + + # Fusion IR before segmentation will look like this: + # [t, h] + # /\. + # d + # (deviceIdx.x) + # | + # | set (lowered to StreamBroadcast. This decomposition is done manually in the definition above. It will later be done by preseg) + # | + # [t, h] [4h, h] + # /\ /\. + # s d + # (streamIdx) + # | + # | linear + # | + # [t, 4h, r{h}] + # /\ /\. + # s* d + + return fd + + +@pytest.mark.mpi +def test_column_parallel_linear_forward(multidevice_test): + # This is a port of CollectiveBasedOverlapTest.ColumnAndSequenceParallelLinear_Forward. + # The difference is we are using broadcast based overlapping instead of send/recv. + h, t = 2, 24 + d = multidevice_test.size + if (h * 4) % d != 0: + pytest.skip( + f"Row-parallel linear requires {h * 4} to be divisible by world size {d}." + ) + if t % d != 0: + pytest.skip( + f"Column-parallel linear requires {t} to be divisible by world size {d}." + ) + + fd = column_parallel_linear_forward(h, d) + + inp_ref = torch.testing.make_tensor(t, h, dtype=torch.int32, device="cpu").to( + torch.bfloat16 + ) + weight_ref = torch.testing.make_tensor( + 4 * h, h, dtype=torch.int32, device="cpu" + ).to(torch.bfloat16) + + inp = multidevice_test.shard_tensor(inp_ref, fd.fusion.inputs()[0]) + weight = multidevice_test.shard_tensor(weight_ref, fd.fusion.inputs()[1]) + + out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight) + + with torch.profiler.profile(record_shapes=True) as prof: + (out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"]) + torch.testing.assert_close(out, out_ref) + broadcast_events = [ + event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name + ] + assert len(broadcast_events) == d + + +@pytest.mark.mpi +@pytest.mark.benchmark +def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark): + # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward. + h, t = 8192, 8192 + d = multidevice_test.size + if (4 * h) % d != 0: + pytest.skip( + f"Column-parallel linear requires {4 * h} to be divisible by world size {d}." + ) + if t % d != 0: + pytest.skip( + f"Column-parallel linear requires {t} to be divisible by world size {d}." + ) + + fd = column_parallel_linear_forward(h, d) + + inp_ref = torch.randn(t, h, dtype=torch.bfloat16, device="cpu") + weight_ref = torch.randn(4 * h, h, dtype=torch.bfloat16, device="cpu") + + inp = multidevice_test.shard_tensor(inp_ref, fd.fusion.inputs()[0]) + weight = multidevice_test.shard_tensor(weight_ref, fd.fusion.inputs()[1]) + + warmup_fn, benchmark_fn = get_benchmark_fns( + lambda: fd.execute( + [inp, weight], + _enable_options=["host_ir_lowering"], + ) + ) + warmup_fn() + benchmark.pedantic(benchmark_fn, rounds=5) + + @pytest.mark.mpi @pytest.mark.parametrize("backend_type", [CommunicatorBackend.nccl]) @pytest.mark.parametrize("s", [1, 8])