Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class Val;
f(SdpaFwdOp); \
f(SdpaBwdOp); \
f(EmbeddingFwdOp); \
f(CollectivePermute); \
f(Communication); \
f(P2PCommunication);
#define DISPATCH_FOR_ALL_KIR_EXPRS(f) \
Expand Down
127 changes: 127 additions & 0 deletions csrc/host_ir/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,38 @@ void HostIrEvaluator::handle(ShareMemHandles* share_mem_handles) {
ipc_handle_cache_.exchangeHandles(share_mem_handles->communications());
}

// Posts a CollectivePermute: each rank sends its input tensor to
// `sendPeer()` and receives into its output tensor from `recvPeer()`,
// where both peers are evaluated at runtime (they may depend on a host
// for-loop index). The pending work handle is recorded in `works_` so a
// later Wait can block on completion.
void HostIrEvaluator::handle(CollectivePermute* communication) {
  NVF_ERROR(
      communicator_ != nullptr && communicator_->is_available(),
      "A valid communicator must be provided");

  at::Tensor input_tensor = getKnownTensorOrUndefined(communication->input(0));
  at::Tensor output_tensor =
      getKnownTensorOrUndefined(communication->output(0));

#ifndef NDEBUG
  // Debug-only: verify the concrete tensors match the IR's expected
  // sizes/strides before posting the communication.
  validateSizesAndStrides(
      {input_tensor, output_tensor},
      {communication->in(), communication->out()},
      expr_evaluator_);
#endif

  CommunicatorBackend backend_type = communication->backend();
  // CollectivePermute is only supported with NCCL backend because
  // UCC does not support coalescing.
  NVF_CHECK_EQ(
      backend_type,
      CommunicatorBackend::kNccl,
      "CollectivePermute is only supported with the NCCL backend because "
      "UCC does not support coalescing.");
  c10d::Backend* backend =
      communicator_->getBackendForTeam(communication->team(), backend_type);
  // Peers are Vals, not constants: evaluate them under the current
  // expression-evaluator bindings (e.g. the enclosing loop index).
  works_[communication] = postSingleCommunication(
      communication,
      communicator_->deviceId(),
      backend,
      input_tensor,
      output_tensor,
      expr_evaluator_.evaluate(communication->sendPeer()).as<int64_t>(),
      expr_evaluator_.evaluate(communication->recvPeer()).as<int64_t>());
}

void HostIrEvaluator::handle(Communication* communication) {
NVF_ERROR(
communicator_ != nullptr && communicator_->is_available(),
Expand Down Expand Up @@ -532,6 +564,101 @@ void HostIrEvaluator::handle(hir::ForLoop* for_loop) {
auto stop = expr_evaluator_.evaluate(for_loop->stop()).as<int64_t>();

for (auto i = start; i < stop; i++) {
// This is not ideal. In lowering, we create communication expr.
// The collective permute has the output tensorview, and input vals of
// send_peer and recv_peer. While the definition of output_tv is not
// modified and remains `set`, this output_tv is a use of the vals Even
// though we shardByStream, the use of vals is not modified and has a
// dependency on T1. Cloned e: T1_g_float[istreamIdx6{1}, iS5{3}]
// (DeviceMesh{0})
// = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}),
// cache_op=Streaming )

// c: CollectivePermute 77 (team=(0), send_peer=( ( 1 + ( 0 - i140 ) ) % 1
// ), recv_peer=( ( i140 + 0 ) % 1 ), input=T0_g_float[ideviceIdx.x2{1},
// iS3{3}] (DeviceMesh{0}), output=T1_g_float[istreamIdx6{1}, iS5{3}]
// (DeviceMesh{0}), backend=NCCL)

// %HostIrContainer { (T0_g_float[ideviceIdx.x2{1}, iS3{3}]
// (DeviceMesh{0})) -> (T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}))
// :
// T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}) =
// ALLOCATE(buffer=T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}),
// mem_type=global, size=3, zero_init=false, resets_to_zero=false) Stream
// 0x281a6c60 = GetCurrentStream() FOR i140 from 0 to 1:
// SetCurrentStream(Stream i140)
// Synchronize(Stream 0x281a6c60)
// T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) =
// ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}),
// stream_index=i140) CollectivePermute 82 (team=(0), send_peer=( ( 1 +
// ( 0 - i140 ) ) % 1 ), recv_peer=( ( i140 + 0 ) % 1 ),
// input=T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}),
// output=T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}),
// backend=NCCL) Wait(Communication 82)
// SetCurrentStream(Stream 0x281a6c60)
// FOR i140 from 0 to 1:
// Synchronize(Stream i140)
// } // %HostIrContainer

// Invalidating index: i140
// allConsumerValsOf(i140)
// Visited val: i140
// Consumer of i140: i163 definition: i163 = 0 - i140;

// Visited val: i163
// Consumer of i163: i165 definition: i165 = 1 + i163;

// Visited val: i165
// Consumer of i165: i167 definition: i167 = i165 % 1;

// Visited val: i167
// Consumer of i167: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})
// definition: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})
// = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}),
// cache_op=Streaming )

// Visited val: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})
// Consumer of i167: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0})
// definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) =
// ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}),
// stream_index=i140)

// Visited val: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0})
// Consumer of i140: i169 definition: i169 = i140 + 0;

// Visited val: i169
// Consumer of i169: i171 definition: i171 = i169 % 1;

// Visited val: i171
// Consumer of i171: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})
// definition: T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0})
// = Set( T0_g_float[ideviceIdx.x2{1}, iS3{3}] (DeviceMesh{0}),
// cache_op=Streaming )

// Consumer of i171: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0})
// definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) =
// ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}),
// stream_index=i140)

// Consumer of i140: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0})
// definition: T2_l_float[istreamIdx10{1}, iS9{3}] (DeviceMesh{0}) =
// ShardByStream(T1_g_float[istreamIdx6{1}, iS5{3}] (DeviceMesh{0}),
// stream_index=i140)

// consumer_vals: 8
// Invalidating consumer: i169
// Invalidating consumer: T2_l_float[istreamIdx10{1}, iS9{3}]
// (DeviceMesh{0}) Invalidating consumer: T1_g_float[istreamIdx6{1},
// iS5{3}] (DeviceMesh{0}) Invalidating consumer: i167 Invalidating
// consumer: i165 Invalidating consumer: i163 Invalidating consumer: i171
// Invalidating consumer: i140

expr_evaluator_.invalidate(for_loop->index());
for (auto consumer : allConsumerValsOf(for_loop->index())) {
if (!consumer->isA<TensorView>()) {
expr_evaluator_.invalidate(consumer);
}
}
expr_evaluator_.bind(for_loop->index(), i);
for (Expr* e : for_loop->body().exprs()) {
dispatch(e);
Expand Down
1 change: 1 addition & 0 deletions csrc/host_ir/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch {
void handle(Synchronize*) override;
void handle(PostOnStream*) override;
void handle(LaunchKernel*) override;
void handle(CollectivePermute*) override;
void handle(Communication*) override;
void handle(P2PCommunication*) override;
void handle(MoeDispatch*) override;
Expand Down
10 changes: 7 additions & 3 deletions csrc/host_ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,13 @@ Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
this,
"must be registered in a HostIrContainer");
NVF_ERROR(
(expr->isOneOf<Communication, P2PCommunication, EndCoalescing>()),
expr,
" must be a Communication, a P2PCommunication, or a EndCoalescing");
(expr->isOneOf<
Communication,
CollectivePermute,
P2PCommunication,
EndCoalescing>()),
"Got: ",
expr);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(Wait)
Expand Down
102 changes: 97 additions & 5 deletions csrc/host_ir/lower_to_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,32 @@ void lowerToBroadcast(
backend));
}

// Appends a StreamBroadcast communication to `comms`: a broadcast rooted at
// `root` over the device mesh shared by `input_tv` and `output_tv`. The two
// tensors are required to live on the same mesh.
void lowerToStreamBroadcast(
    TensorView* input_tv,
    TensorView* output_tv,
    const CommunicatorBackend backend,
    std::vector<Expr*>& comms,
    Val* root) {
  const DeviceMesh& in_mesh = input_tv->getDeviceMesh();
  const DeviceMesh& out_mesh = output_tv->getDeviceMesh();
  NVF_ERROR_EQ(
      in_mesh,
      out_mesh,
      "StreamBroadcast sender and receiver meshes must be the same. Given ",
      in_mesh,
      " and ",
      out_mesh);
  // The team is every device in the (shared) receiver mesh.
  comms.push_back(IrBuilder::create<Communication>(
      CommunicationType::StreamBroadcast,
      output_tv,
      input_tv,
      /*team=*/out_mesh.vector(),
      root,
      c10d::ReduceOp::RedOpType::UNUSED,
      backend));
}

// Adds several SendRecv communications to the vector 'comms'
// For now, we assume that this function is called only if
// the input and output have the same sharding. Later we could support more
Expand Down Expand Up @@ -319,6 +345,33 @@ void lowerToAllToAll(
backend));
}

// Appends a CollectivePermute communication to `comms`. The send/recv peers
// for `my_device_idx` are derived from the Swizzle1D that defines the
// output's stream-parallel IterDomain, evaluated via dispatchSwizzle1D.
// NOTE(review): despite the name, `root` here appears to be the host loop
// index (see the FIXME at the CollectivePermute case in
// convertSingleOpToCommunication) — confirm with callers.
void lowerToCollectivePermute(
    TensorView* input_tv,
    TensorView* output_tv,
    const CommunicatorBackend backend,
    std::vector<Expr*>& comms,
    Val* root,
    DeviceIdxType my_device_idx) {
  NVF_ERROR_EQ(
      input_tv->getDeviceMesh(),
      output_tv->getDeviceMesh(),
      "CollectivePermute sender and receiver meshes must be the same. Given ",
      input_tv->getDeviceMesh(),
      " and ",
      output_tv->getDeviceMesh());

  // The output must have a stream-parallel loop IterDomain whose definition
  // is a Swizzle1D; `as<Swizzle1D>()` asserts that downcast. The swizzle's
  // parallel type selects how peers rotate around the mesh.
  IterDomain* stream_id =
      getShardedIterDomain(output_tv, ParallelType::Stream, DomainType::kLoop);
  Swizzle1D* swizzle = stream_id->definition()->as<Swizzle1D>();
  ParallelType pt = swizzle->parallelType();

  // Compute this device's receive and send peers for the current loop
  // iteration (`root`).
  const auto& [recv_peer, send_peer] =
      dispatchSwizzle1D(root, my_device_idx, pt, input_tv->getDeviceMesh());
  Team team = input_tv->getDeviceMesh().vector();
  comms.push_back(IrBuilder::create<CollectivePermute>(
      output_tv, input_tv, team, send_peer, recv_peer, backend));
}

IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) {
std::unordered_set<IterDomain*> logical_ids =
getInputsInTargetDomain({loop_id}, tv->getLogicalDomain());
Expand Down Expand Up @@ -370,14 +423,16 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
const auto c2p_map = pairwise_map.mapConsumerToProducer();

// This ignores device dimensions on reduction axis.
auto producer_pt_to_did =
const std::unordered_map<ParallelType, IterDomain*>& producer_pt_to_id =
mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain());
auto consumer_pt_to_did =
const std::unordered_map<ParallelType, IterDomain*>& consumer_pt_to_id =
mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain());

for (ParallelType pt : kParallelTypeDIDs) {
IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt);
IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt);
IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt);
IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt);
IterDomain* c_stream_id =
getOrDefault(consumer_pt_to_id, ParallelType::Stream);

if (p_loop_did == nullptr && c_loop_did == nullptr) {
// Not sharded on this parallel type
Expand All @@ -391,6 +446,25 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
if (e->isA<LoadStoreOp>()) {
if (p_loop_did && !c_loop_did) {
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
// Check if we are going from DIDx -> Stream, which is a ring allgather.
// This can be executed as a broadcast or send recvs, which is decided
// by the presence of a swizzle in the stream id definition.
if (c_stream_id != nullptr) {
IterDomain* c_stream_logical_id =
getLogicalFromLoopId(consumer, c_stream_id);
if (c_stream_logical_id == p2c_map.at(p_logical_id)) {
NVF_CHECK(
same_mesh,
"Broadcast based allgather in stream parallel requires same "
"mesh.");
auto* swizzle = dynamic_cast<Swizzle1D*>(c_stream_id->definition());
CommunicationType type = swizzle != nullptr
? CommunicationType::CollectivePermute
: CommunicationType::StreamBroadcast;
fill_communication_info(type, p_logical_id, c_stream_logical_id);
continue;
}
}
CommunicationType type = same_mesh ? CommunicationType::Allgather
: CommunicationType::Gather;
fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id));
Expand Down Expand Up @@ -478,7 +552,9 @@ Layout getCommunicationLayout(
type == CommunicationType::Allreduce ||
type == CommunicationType::Broadcast ||
type == CommunicationType::SendRecv ||
type == CommunicationType::AllToAll) {
type == CommunicationType::CollectivePermute ||
type == CommunicationType::AllToAll ||
type == CommunicationType::StreamBroadcast) {
return layout;
}

Expand Down Expand Up @@ -536,6 +612,7 @@ bool isCommunicationLayoutCompliant(Expr* expr) {
std::vector<Expr*> convertSingleOpToCommunication(
Expr* e,
DeviceIdxType my_device_idx,
Val* root,
const CommunicatorBackend backend) {
FusionGuard fg(e->fusion());

Expand Down Expand Up @@ -605,6 +682,21 @@ std::vector<Expr*> convertSingleOpToCommunication(
case CommunicationType::AllToAll:
lowerToAllToAll(input_tv, output_tv, backend, comms);
break;
case CommunicationType::StreamBroadcast:
NVF_ERROR(
root != nullptr,
"StreamBroadcast requires a root value passed in through lowering");
lowerToStreamBroadcast(input_tv, output_tv, backend, comms, root);
break;
case CommunicationType::CollectivePermute:
// FIXME: Rename this to host loop index. Collective Permute has no root.
// The send and recv peer indices are computed using the host loop index.
NVF_ERROR(
root != nullptr,
"CollectivePermute requires a root value passed in through lowering");
lowerToCollectivePermute(
input_tv, output_tv, backend, comms, root, my_device_idx);
break;
}

return comms;
Expand Down
6 changes: 6 additions & 0 deletions csrc/host_ir/lower_to_communication.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,15 @@ Layout getCommunicationLayout(
const CommunicationType type,
IterDomain* sharded_id);

// Creates a communication expr corresponding to the given
// resharding expr. In most cases, `root` is inferred from the
// communication type. However, in some cases — e.g., when decomposing
// an allgather into broadcasts inside a host for-loop — `root`
// may be passed in explicitly by the lowering.
std::vector<Expr*> convertSingleOpToCommunication(
Expr* c,
DeviceIdxType my_device_idx,
Val* root = nullptr,
const CommunicatorBackend backend = CommunicatorBackend::kNccl);

} // namespace nvfuser
Loading
Loading