17 changes: 17 additions & 0 deletions csrc/host_ir/evaluator.cpp
@@ -532,6 +532,23 @@ void HostIrEvaluator::handle(hir::ForLoop* for_loop) {
auto stop = expr_evaluator_.evaluate(for_loop->stop()).as<int64_t>();

for (auto i = start; i < stop; i++) {
// Expressions dependent on loop index and all allocations
// inside the loop body should be invalidated. We cannot
// simply use allConsumerValsOf because the loop index can be an input to
// fusion outputs or buffers allocated outside the loop.
std::unordered_set<Val*> allocations;
for (Expr* e : for_loop->body().exprs()) {
if (auto* alloc = dynamic_cast<kir::Allocate*>(e)) {
allocations.insert(alloc->buffer());
}
}
expr_evaluator_.invalidate(for_loop->index());
for (auto consumer : allConsumerValsOf(for_loop->index())) {
if (consumer->isA<TensorView>() && !allocations.contains(consumer)) {
continue;
}
expr_evaluator_.invalidate(consumer);
}
expr_evaluator_.bind(for_loop->index(), i);
for (Expr* e : for_loop->body().exprs()) {
dispatch(e);
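The hunk above invalidates cached values that depend on the loop index before re-binding it each iteration. Below is a minimal, self-contained illustration of the staleness this prevents; it uses invented names and a toy cache, not the real nvFuser ExpressionEvaluator API.

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

// Toy memoizing evaluator: values are recomputed only when absent from the cache.
struct ToyEvaluator {
  std::unordered_map<std::string, int64_t> cache;

  void bind(const std::string& name, int64_t v) { cache[name] = v; }
  void invalidate(const std::string& name) { cache.erase(name); }

  // "offset" models an expression derived from the loop index "i".
  int64_t offset() {
    auto it = cache.find("offset");
    if (it != cache.end()) {
      return it->second;  // stale if "i" changed and "offset" was not invalidated
    }
    int64_t v = cache.at("i") * 2;
    cache["offset"] = v;
    return v;
  }
};

int main() {
  ToyEvaluator ev;
  for (int64_t i = 0; i < 3; ++i) {
    ev.invalidate("offset");  // without this, every iteration would reuse i == 0's value
    ev.bind("i", i);
    std::cout << "i=" << i << " offset=" << ev.offset() << "\n";
  }
}
```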
66 changes: 61 additions & 5 deletions csrc/host_ir/lower_to_communication.cpp
@@ -167,6 +167,32 @@ void lowerToBroadcast(
backend));
}

void lowerToStreamBroadcast(
TensorView* input_tv,
TensorView* output_tv,
const CommunicatorBackend backend,
std::vector<Expr*>& comms,
Val* root) {
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
NVF_ERROR_EQ(
sender_mesh,
receiver_mesh,
"StreamBroadcast sender and receiver meshes must be the same. Given ",
sender_mesh,
" and ",
receiver_mesh);
Team team = receiver_mesh.vector();
Contributor:
team contains absolute device IDs (e.g. {4, 5, 6, 7}), but root is a loop index Val* that evaluates to 0, 1, 2, ... at runtime. In postBroadcast (communication.cpp:474), the code checks if (my_device_index == root_index) and calls getRootRelativeIndex(root_index) which does std::find(team.begin(), team.end(), root_index). Both expect root_index to be an actual device ID from the team, not a loop iteration index. For a mesh like {4,5,6,7}, loop index 0 won't match any device and the assertion at communication.cpp:255 will fire.

Root should be the device ID at the loop index position: receiver_mesh.at(loop_index). Since loop_index is evaluated at runtime, you'll need to compute this device ID lookup at evaluation time (e.g., via an IR expression that indexes into the mesh).
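A minimal sketch of the lookup this comment suggests, assuming the mesh is available as a plain vector of absolute device IDs; the helper name is hypothetical, not nvFuser API.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using DeviceIdx = int64_t;

// The team holds absolute device IDs, so the broadcast root for iteration
// `loop_index` is the ID stored at that position, not the index itself.
DeviceIdx rootDeviceForIteration(
    const std::vector<DeviceIdx>& team, int64_t loop_index) {
  assert(loop_index >= 0 && loop_index < static_cast<int64_t>(team.size()));
  return team[loop_index];  // e.g. team = {4, 5, 6, 7}, loop_index = 1 -> 5
}
```

In the actual lowering this lookup would have to happen at evaluation time, since the loop index is only known then.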

comms.push_back(IrBuilder::create<Communication>(
CommunicationType::StreamBroadcast,
output_tv,
input_tv,
team,
root,
c10d::ReduceOp::RedOpType::UNUSED,
backend));
Comment on lines +185 to +193 (Contributor):
StreamBroadcast root is a raw loop index, not a device ID

root is the for-loop index—an integer in [0, d). However team is receiver_mesh.vector(), which contains absolute device IDs. Inside postBroadcast the root is used in two ways that both assume it equals an absolute device ID:

  1. if (my_device_index == root_index) — decides which device does the local copy.
  2. getRootRelativeIndex(root_index) — calls std::find(team.begin(), team.end(), root_index) and asserts the value is present in the team.

For a mesh such as DeviceMesh({4, 5, 6, 7}), loop index 0 is neither equal to any my_device_index in {4,5,6,7} nor present in the team vector, so the assert fires at runtime.

The root should be the mesh device at position loop_index, i.e. receiver_mesh.at(loop_index). Since loop_index is a Val* evaluated at runtime, one approach is to look up the mesh device ID at evaluation time (e.g. via a GetItem/helper expression on the mesh tensor), or document this as a hard requirement (mesh must be arange(d)) and add a validation check at communication-creation time.
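A hedged sketch of the second option, validating at communication-creation time that the mesh is arange(d) so a raw loop index and a device ID coincide; the function name is illustrative only.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Throws if the team is not {0, 1, ..., d-1}; under that restriction a loop
// index can be used directly as the broadcast root.
void validateMeshIsArange(const std::vector<int64_t>& team) {
  for (int64_t i = 0; i < static_cast<int64_t>(team.size()); ++i) {
    if (team[i] != i) {
      throw std::runtime_error(
          "StreamBroadcast with a loop-index root requires mesh == arange(d); "
          "found device " + std::to_string(team[i]) + " at position " +
          std::to_string(i));
    }
  }
}
```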

Collaborator (author):
Communication can accept a Val as root or a DeviceIdx

}

// Adds several SendRecv communications to the vector 'comms'
// For now, we assume that this function is called only if
// the input and output have the same sharding. Later we could support more
@@ -371,14 +397,17 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
const auto c2p_map = pairwise_map.mapConsumerToProducer();

// This ignores device dimensions on reduction axis.
auto producer_pt_to_did =
const std::unordered_map<ParallelType, IterDomain*>& producer_pt_to_id =
mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain());
auto consumer_pt_to_did =
const std::unordered_map<ParallelType, IterDomain*>& consumer_pt_to_id =
mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain());

IterDomain* c_stream_id =
getOrDefault(consumer_pt_to_id, ParallelType::Stream);

for (ParallelType pt : kParallelTypeDIDs) {
IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt);
IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt);
IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt);
IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt);

if (p_loop_did == nullptr && c_loop_did == nullptr) {
// Not sharded on this parallel type
@@ -392,6 +421,25 @@
if (e->isA<LoadStoreOp>()) {
if (p_loop_did && !c_loop_did) {
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
// Check if we are going from DIDx -> Stream, which is a ring allgather.
// This can be executed as a broadcast or send recvs, which is decided
// by the presence of a swizzle in the stream id definition.
// TODO: Lower to SendRecv if swizzle is present.
if (c_stream_id != nullptr) {
IterDomain* c_stream_logical_id =
getLogicalFromLoopId(consumer, c_stream_id);
if (c_stream_logical_id == p2c_map.at(p_logical_id)) {
NVF_CHECK(
same_mesh,
"Broadcast based allgather in stream parallel requires same "
"mesh.");
fill_communication_info(
CommunicationType::StreamBroadcast,
p_logical_id,
c_stream_logical_id);
continue;
}
}
CommunicationType type = same_mesh ? CommunicationType::Allgather
: CommunicationType::Gather;
fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id));
@@ -479,7 +527,8 @@ Layout getCommunicationLayout(
type == CommunicationType::Allreduce ||
type == CommunicationType::Broadcast ||
type == CommunicationType::SendRecv ||
type == CommunicationType::AllToAll) {
type == CommunicationType::AllToAll ||
type == CommunicationType::StreamBroadcast) {
Collaborator:
I understood the motivation but can this be consolidated into the same Broadcast?

Collaborator (author):
I kept it separate so I don't need to check for the Stream parallel type in lowerToBroadcast when deciding the root. Posting the communication uses a common function.
I also wanted to integrate the SendRecv-based decomposition first and then reconsider the design based on what is needed for both of these comms.

return layout;
}

@@ -537,6 +586,7 @@ bool isCommunicationLayoutCompliant(Expr* expr) {
std::vector<Expr*> convertSingleOpToCommunication(
Expr* e,
DeviceIdxType my_device_idx,
Val* root,
const CommunicatorBackend backend) {
FusionGuard fg(e->fusion());

@@ -606,6 +656,12 @@ std::vector<Expr*> convertSingleOpToCommunication(
case CommunicationType::AllToAll:
lowerToAllToAll(input_tv, output_tv, backend, comms);
break;
case CommunicationType::StreamBroadcast:
NVF_ERROR(
root != nullptr,
"StreamBroadcast requires a root value passed in through lowering");
lowerToStreamBroadcast(input_tv, output_tv, backend, comms, root);
break;
}

return comms;
6 changes: 6 additions & 0 deletions csrc/host_ir/lower_to_communication.h
@@ -51,9 +51,15 @@ Layout getCommunicationLayout(
const CommunicationType type,
IterDomain* sharded_id);

// Creates a communication expr corresponding to the given
// resharding expr. In most cases, `root` is inferred from the
// communication type. However, in some cases, e.g. when decomposing
// an allgather into broadcasts in a host for-loop, `root` may be
// passed in through lowering.
std::vector<Expr*> convertSingleOpToCommunication(
Expr* c,
DeviceIdxType my_device_idx,
Val* root = nullptr,
const CommunicatorBackend backend = CommunicatorBackend::kNccl);

} // namespace nvfuser
11 changes: 8 additions & 3 deletions csrc/host_ir/lowering.cpp
@@ -145,7 +145,10 @@ Expr* cloneWithNewOperands(
return e;
}

return e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes());
auto* new_e =
e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes());
e->container()->removeExpr(e);
return new_e;
}

void lowerSegment(
@@ -178,15 +181,17 @@ void lowerSegment(
// TODO: `replacement_map` should be associated with the scope so
// ShardByStream across segments in the same for-loop can be reused.
std::unordered_map<Val*, Val*> replacement_map;
for (Expr* c : convertSingleOpToCommunication(e, device_id)) {
Val* root = loop_nest.empty() ? nullptr : innermost.loop->index();
for (Expr* c : convertSingleOpToCommunication(e, device_id, root)) {
NVF_ERROR(
c->isA<Communication>(),
"Exprs in a Communication group should be Communication: ",
c);
auto* communication = c->as<Communication>();
TensorView* in = communication->in();
TensorView* out = communication->out();
if (haveDifferentShardings(
if (communication->type() != CommunicationType::StreamBroadcast &&
Collaborator:
While I understand the motivation and the tests pass, I'm thinking about how to make this cleaner.

Is it possible to frame this as an optimization? For example, if in can be sharded on Stream in the same way as communication, insert a shardByStream.

Collaborator (author):
Yeah, I do think this should be merged into shardByStream or some other logic.
For now, I kept it simple since I am not sure what it will look like with the Collective Permute representation (a composite Communication, P2P comms corresponding to SendRecv, etc.), so I took the verbose approach as an interim step.

Let me see what I can do in this PR itself.

haveDifferentShardings(
in,
DomainType::kAllocation,
out,
6 changes: 4 additions & 2 deletions csrc/host_ir/ops.cpp
@@ -89,8 +89,10 @@ TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) {
destination, ParallelType::Stream, DomainType::kAllocation) !=
nullptr,
"Destination allocation should be sharded on stream after "
"shardAllocationAsLoop: ",
destination);
"shardAllocationAsLoop. ",
destination->name(),
":",
destination->domain()->toString(0, /*loop_only=*/false));
Collaborator:
I guess destination is still worth printing in addition to the domain?

Collaborator (author):
Printing destination only shows the loop domain. I added the name above so it is printed in addition to the complete domain.


// Refine the contiguity flags so `out` aliases `in`. This is done similar
// to AliasFinder::handle(const SliceOp*). We scan through the allocation
5 changes: 4 additions & 1 deletion csrc/host_ir/pass/convert_op_to_communication.cpp
@@ -35,7 +35,10 @@ void ConvertOpToCommunication::passImplementation(Fusion* fusion) {
return new_top_level_exprs.push_back(top_level_expr);
}
for (auto* expr : nvfuser::convertSingleOpToCommunication(
top_level_expr, my_device_index, params_.communicator_backend)) {
top_level_expr,
my_device_index,
/*root=*/nullptr,
params_.communicator_backend)) {
// Allocate the recv buffers of communications
if (expr->isA<Communication>()) {
auto* communication = expr->as<Communication>();
6 changes: 6 additions & 0 deletions csrc/multidevice/communication.cpp
@@ -57,6 +57,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) {
case CommunicationType::AllToAll:
os << "AllToAll";
break;
case CommunicationType::StreamBroadcast:
os << "StreamBroadcast";
break;
}
return os;
}
@@ -149,6 +152,7 @@ bool hasRoot(CommunicationType type) {
case CommunicationType::Reduce:
case CommunicationType::Broadcast:
case CommunicationType::SendRecv:
case CommunicationType::StreamBroadcast:
return true;
case CommunicationType::Allgather:
case CommunicationType::Allreduce:
@@ -171,6 +175,7 @@ bool isReduction(CommunicationType type) {
case CommunicationType::Broadcast:
case CommunicationType::SendRecv:
case CommunicationType::AllToAll:
case CommunicationType::StreamBroadcast:
return false;
default:
NVF_THROW("unrecognized CommunicationType: ", type);
@@ -853,6 +858,7 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
input_tensor,
output_tensor);
case CommunicationType::Broadcast:
case CommunicationType::StreamBroadcast:
Contributor:
StreamBroadcast reuses broadcast logic but csrc/host_ir/evaluator.cpp:334-336 and :405-407 still reject it for the CUDA backend multicast path. If CUDA backend is intended to work with StreamBroadcast, add it to those checks; otherwise this will fail at runtime when using CommunicatorBackend::kCuda.

Contributor:
StreamBroadcast shares broadcast logic but csrc/multidevice/cuda_p2p.cpp:645-667 and :689-704 don't handle it in their switch statements. This will cause runtime errors with CUDA backend. Either add StreamBroadcast cases (treating them like Broadcast) or ensure CUDA backend is never used with this communication type.
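A minimal sketch of the change both comments ask for, using a stand-in enum rather than the actual cuda_p2p.cpp code: wherever the CUDA backend switches on the communication type, StreamBroadcast would fall through to the Broadcast handling.

```cpp
enum class CommType { Broadcast, StreamBroadcast, Allgather };

// Stand-in for the backend's dispatch: StreamBroadcast takes the same
// multicast path as Broadcast.
bool takesBroadcastPath(CommType type) {
  switch (type) {
    case CommType::Broadcast:
    case CommType::StreamBroadcast:  // fall through, handled like Broadcast
      return true;
    default:
      return false;
  }
}
```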

return postBroadcast(
communication,
my_device_index,
7 changes: 6 additions & 1 deletion csrc/multidevice/communication.h
@@ -33,7 +33,8 @@ enum class CommunicationType {
ReduceScatter,
Broadcast,
SendRecv,
AllToAll
AllToAll,
StreamBroadcast,
};

std::ostream& operator<<(std::ostream& os, const CommunicationType& type);
@@ -347,6 +348,10 @@ class MoeCombine : public Expr {
// - the root has one src buffer, and no or one dst buffer
// - non-roots have no src buffer and one dst buffer
// - all buffers have the same size
// (*) StreamBroadcast
// Shares the postBroadcast logic with Broadcast. The difference is that the
// root is the host for-loop index. It is kept separate from Broadcast so the
// TensorViews do not need to be inspected later to distinguish the two.
// (*) Gather
// Copies each device's source buffer to the root's respective src
// buffer. The order of the sender devices matches the order of the
114 changes: 114 additions & 0 deletions tests/python/multidevice/test_overlap.py
@@ -417,6 +417,120 @@ def test_column_parallel_linear_forward_reference_benchmark(
benchmark.pedantic(benchmark_fn, rounds=5)


def column_parallel_linear_forward(h: int, d: int):
with FusionDefinition() as fd:
inp_tv = fd.define_tensor((-1, h), contiguity=True, dtype=DataType.BFloat16)
weight_tv = fd.define_tensor(
(4 * h, h), contiguity=True, dtype=DataType.BFloat16
)
ag_out = fd.ops.set(inp_tv)
out_tv = fd.ops.linear(ag_out, weight_tv)
fd.add_output(out_tv)

mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))

for tv in [inp_tv, weight_tv]:
tv.set_device_mesh(mesh)
tv.outer_split(0, d)
tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x)

ag_out.set_device_mesh(mesh)
ag_out.outer_split(0, d)
ag_out.axis(0).parallelize(nvfuser.ParallelType.stream)

# Fusion IR before segmentation will look like this:
# [t, h]
# /\.
# d
# (deviceIdx.x)
# |
# | set (lowered to StreamBroadcast. This decomposition is done manually in the definition above. It will later be done by preseg)
# |
# [t, h] [4h, h]
# /\ /\.
# s d
# (streamIdx)
# |
# | linear
# |
# [t, 4h, r{h}]
# /\ /\.
# s* d

return fd


@pytest.mark.mpi
def test_column_parallel_linear_forward(multidevice_test):
# This is a port of CollectiveBasedOverlapTest.ColumnAndSequenceParallelLinear_Forward.
# The difference is we are using broadcast based overlapping instead of send/recv.
h, t = 2, 24
d = multidevice_test.size
if (h * 4) % d != 0:
pytest.skip(
f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
)
if t % d != 0:
pytest.skip(
f"Column-parallel linear requires {t} to be divisible by world size {d}."
)

fd = column_parallel_linear_forward(h, d)

inp_ref = torch.testing.make_tensor(t, h, dtype=torch.int32, device="cpu").to(
torch.bfloat16
)
weight_ref = torch.testing.make_tensor(
4 * h, h, dtype=torch.int32, device="cpu"
).to(torch.bfloat16)

inp = multidevice_test.shard_tensor(inp_ref, fd.fusion.inputs()[0])
weight = multidevice_test.shard_tensor(weight_ref, fd.fusion.inputs()[1])

out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight)

with torch.profiler.profile(record_shapes=True) as prof:
(out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"])
torch.testing.assert_close(out, out_ref)
broadcast_events = [
event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name
]
assert len(broadcast_events) == (d if d > 1 else 0)


@pytest.mark.mpi
@pytest.mark.benchmark
def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark):
# This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
h, t = 8192, 8192
d = multidevice_test.size
if (4 * h) % d != 0:
pytest.skip(
f"Column-parallel linear requires {4 * h} to be divisible by world size {d}."
)
if t % d != 0:
pytest.skip(
f"Column-parallel linear requires {t} to be divisible by world size {d}."
)

fd = column_parallel_linear_forward(h, d)

inp_ref = torch.randn(t, h, dtype=torch.bfloat16, device="cpu")
weight_ref = torch.randn(4 * h, h, dtype=torch.bfloat16, device="cpu")

inp = multidevice_test.shard_tensor(inp_ref, fd.fusion.inputs()[0])
weight = multidevice_test.shard_tensor(weight_ref, fd.fusion.inputs()[1])

warmup_fn, benchmark_fn = get_benchmark_fns(
lambda: fd.execute(
[inp, weight],
_enable_options=["host_ir_lowering"],
)
)
warmup_fn()
benchmark.pedantic(benchmark_fn, rounds=5)


@pytest.mark.mpi
@pytest.mark.parametrize("backend_type", [CommunicatorBackend.nccl])
@pytest.mark.parametrize("s", [1, 8])