19 changes: 16 additions & 3 deletions csrc/host_ir/evaluator.cpp
@@ -718,12 +718,17 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) {
void HostIrEvaluator::handle(HirAliasSelect* hir_alias_select) {
auto indexed_id =
hir_alias_select->in()->getLogicalDomain().at(hir_alias_select->axis());
auto index = indexed_id->isBroadcast()
? 0
: expr_evaluator_.evaluate(hir_alias_select->index()).as<int64_t>();
auto input = getKnownConcreteValue(hir_alias_select->in()->as<TensorView>())
.as<at::Tensor>();

// If the axis being selected is a reduction axis, the tensor doesn't have
// that dimension (it was skipped during allocation). The select is a no-op -
// just bind the input tensor directly to the output.
if (indexed_id->isReduction()) {
expr_evaluator_.bind(hir_alias_select->out(), input);
return;
}

// Count reduction axes up to the target axis
int64_t reduction_count = std::count_if(
hir_alias_select->in()->getLogicalDomain().begin(),
@@ -732,6 +737,14 @@ void HostIrEvaluator::handle(HirAliasSelect* hir_alias_select) {
[](const IterDomain* id) { return id->isReduction(); });
// Adjust the ATen axis by subtracting the number of reduction axes
int64_t axis = hir_alias_select->axis() - reduction_count;

// Use index 0 if the IterDomain is marked as broadcast, or if the actual
// tensor dimension has size 1 (behaves like broadcast at runtime even if
// not marked as such in the IR)
auto index = (indexed_id->isBroadcast() || input.size(axis) == 1)
? 0
: expr_evaluator_.evaluate(hir_alias_select->index()).as<int64_t>();

Member Author:
I removed the broadcast op from the fusion. It failed here because it was trying to select on the D axis, which is 1 locally.

Collaborator:

> I removed the broadcast op from the fusion.

So why add the broadcast-axis case here (which, btw, looks good to me)?

> It failed here because it was trying to select on the D axis, which is 1 locally.

So why not detect whether the axis is sharded? I don't think checking that the axis is of size 1 is correct. Firstly, if the dimension is DIDx, then the symbolic size will be D, not 1. Secondly, if the axis is neither broadcast nor sharded but just happens to be of size 1, we want to error out.

Does that make sense?
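For concreteness, a minimal sketch of the alternative being suggested (not part of this PR): branch on the IterDomain's properties rather than on the runtime extent, using the existing `isBroadcast()`/`isDeviceDim()` queries, and report an unexpected size-1 axis as an error. It is a drop-in fragment for the index computation above, reusing the surrounding function's locals (`indexed_id`, `input`, `axis`); the exact error message is illustrative.

```cpp
// Hedged sketch, not this PR's code: derive the index from the IterDomain's
// properties instead of the runtime extent.
int64_t index = 0;
if (!indexed_id->isBroadcast() && !indexed_id->isDeviceDim()) {
  // Neither broadcast nor sharded: a size-1 axis here should be reported as
  // an error instead of being silently mapped to index 0.
  NVF_ERROR(
      input.size(axis) != 1,
      "Selecting on a size-1 axis that is neither broadcast nor sharded: ",
      hir_alias_select->in());
  index = expr_evaluator_.evaluate(hir_alias_select->index()).as<int64_t>();
}
```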

expr_evaluator_.bind(hir_alias_select->out(), input.select(axis, index));
}

243 changes: 138 additions & 105 deletions csrc/host_ir/pass/stream_parallel_type.cpp
@@ -59,14 +59,6 @@ void validateStreamAxis(IterDomain* stream_axis, const TensorView* tv) {
it_logical_stream_axis != tv->getLogicalDomain().end(),
"Cannot stream parallelize on a split/merge axis ",
stream_axis);

// Verify stream axis is an iteration or broadcast axis
NVF_CHECK(
stream_axis->getIterType() == IterType::Iteration ||
stream_axis->getIterType() == IterType::Broadcast,
"Stream axis ",
stream_axis,
" should be an iteration or broadcast axis.");
Member Author:
We now stream-parallelize the reduced axis in the sum op, so this check no longer applies.
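For context, a minimal sketch of the schedule this now permits, mirroring the Matmul_K test removed later in this PR: the trailing reduced (K) axis of a matmul output carries ParallelType::Stream, which the deleted IterType check used to reject.

```cpp
// Sketch based on the removed Matmul_K test: stream-parallelize the reduced
// K axis of a matmul output. validateStreamAxis must now accept an
// IterType::Reduction stream axis for this schedule to lower.
TensorView* a = makeContigTensor(2); // [M, K]
TensorView* b = makeContigTensor(2); // [K, N]
TensorView* c = matmul(a, b);        // trailing axis is the reduced K axis
c->axis(-1)->parallelize(ParallelType::Stream);
```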

}

// Checks if two iteration domains are mapped in the ID model
@@ -371,108 +363,149 @@ std::list<Expr*> processForLoopBodies(

// Lower to MM + RS algorithm
if (did_to_stream && stream_to_did) {
NVF_ERROR(
body_expr->isA<LoadStoreOp>() &&
body_expr->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set,
"expected a set operation but got ",
body_expr);
NVF_ERROR(
body_expr->isA<LoadStoreOp>(),
"expected a Tv operation but got ",
body_expr);
NVF_ERROR(
params.offset_stream_indexing_by_rank == true,
"offset_stream_indexing_by_rank==false not supported for "
"ReduceScatter patterns");
auto* set_op = body_expr->as<LoadStoreOp>();
auto* input_tv = set_op->in()->as<TensorView>();
auto* output_tv = set_op->out()->as<TensorView>();
NVF_ERROR(
input_tv->axis(0)->isDeviceDim(),
"expected a sharded first axis on the input but got ",
input_tv);
NVF_ERROR(
output_tv->axis(0)->getParallelType() == ParallelType::Stream,
"expected a stream parallelized first axis on the output but got ",
output_tv);
NVF_ERROR(
input_tv->axis(1)->getParallelType() == ParallelType::Stream,
"expected a stream parallelized second axis on the input but got ",
input_tv);
NVF_ERROR(
output_tv->axis(1)->isDeviceDim(),
"expected a sharded second axis on the output but got ",
output_tv);
auto* is_sending_to_self =
IrBuilder::create<kir::Predicate>(eq(tensor_index, my_device_id));
auto if_sending_to_self =
IrBuilder::create<kir::IfThenElse>(is_sending_to_self);
auto [slicing_input, is_new] = tensor_slicing_cache.get(
input_tv,
/*dim*/
findStreamAxisIndex(input_tv, for_loop->iterDomain(), id_model),
/*index=*/tensor_index);
auto [slicing_output, is_new_] =
tensor_slicing_cache.get(output_tv, /*dim*/ 0, /*index=*/recv_peer);
auto* local_copy = IrBuilder::create<LoadStoreOp>(
LoadStoreOpType::Set, slicing_output->out(), slicing_input->out());
if_sending_to_self->thenBody().pushBack(local_copy);
auto recv = IrBuilder::create<P2PCommunication>(
P2PCommunicationType::RECV,
slicing_output->out(),
recv_peer,
communicator_backend);
auto send = IrBuilder::create<P2PCommunication>(
P2PCommunicationType::SEND,
slicing_input->out(),
tensor_index,
communicator_backend);
if (communicator_backend == CommunicatorBackend::kNccl) {
auto start_coalescing = IrBuilder::create<hir::StartCoalescing>();
auto end_coalescing = IrBuilder::create<hir::EndCoalescing>();
auto wait = IrBuilder::create<hir::Wait>(end_coalescing);

if_sending_to_self->elseBody().pushBack(start_coalescing);
if_sending_to_self->elseBody().pushBack(recv);
if_sending_to_self->elseBody().pushBack(send);
if_sending_to_self->elseBody().pushBack(end_coalescing);
if_sending_to_self->elseBody().pushBack(wait);
} else if (communicator_backend == CommunicatorBackend::kCuda) {
auto share_mem_handles = IrBuilder::create<hir::ShareMemHandles>(
std::vector<P2PCommunication*>({recv, send}));
auto wait_send = IrBuilder::create<hir::Wait>(send);
auto wait_recv = IrBuilder::create<hir::Wait>(recv);

if_sending_to_self->elseBody().pushBack(share_mem_handles);
switch (getP2pProtocol()) {
case P2pProtocol::Get: {
if_sending_to_self->elseBody().pushBack(send);
if_sending_to_self->elseBody().pushBack(recv);
break;
}
case P2pProtocol::Put: {
if_sending_to_self->elseBody().pushBack(recv);
if_sending_to_self->elseBody().pushBack(send);
break;
if (params.offset_stream_indexing_by_rank) {
// Lower to MM + RS p2p based algorithm
Member Author:
This block is the same as before; the only difference is the indentation level.

NVF_ERROR(
body_expr->isA<LoadStoreOp>() &&
body_expr->as<LoadStoreOp>()->opType() ==
LoadStoreOpType::Set,
"expected a set operation but got ",
body_expr);
NVF_ERROR(
body_expr->isA<LoadStoreOp>(),
"expected a Tv operation but got ",
body_expr);
auto* set_op = body_expr->as<LoadStoreOp>();
auto* input_tv = set_op->in()->as<TensorView>();
auto* output_tv = set_op->out()->as<TensorView>();
NVF_ERROR(
input_tv->axis(0)->isDeviceDim(),
"expected a sharded first axis on the input but got ",
input_tv);
NVF_ERROR(
output_tv->axis(0)->getParallelType() == ParallelType::Stream,
"expected a stream parallelized first axis on the output but "
"got ",
output_tv);
NVF_ERROR(
input_tv->axis(1)->getParallelType() == ParallelType::Stream,
"expected a stream parallelized second axis on the input but "
"got ",
input_tv);
NVF_ERROR(
output_tv->axis(1)->isDeviceDim(),
"expected a sharded second axis on the output but got ",
output_tv);
auto* is_sending_to_self =
IrBuilder::create<kir::Predicate>(eq(tensor_index, my_device_id));
auto if_sending_to_self =
IrBuilder::create<kir::IfThenElse>(is_sending_to_self);
auto [slicing_input, is_new] = tensor_slicing_cache.get(
input_tv,
/*dim*/
findStreamAxisIndex(input_tv, for_loop->iterDomain(), id_model),
/*index=*/tensor_index);
auto [slicing_output, is_new_] = tensor_slicing_cache.get(
output_tv, /*dim*/ 0, /*index=*/recv_peer);
auto* local_copy = IrBuilder::create<LoadStoreOp>(
LoadStoreOpType::Set,
slicing_output->out(),
slicing_input->out());
if_sending_to_self->thenBody().pushBack(local_copy);
auto recv = IrBuilder::create<P2PCommunication>(
P2PCommunicationType::RECV,
slicing_output->out(),
recv_peer,
communicator_backend);
auto send = IrBuilder::create<P2PCommunication>(
P2PCommunicationType::SEND,
slicing_input->out(),
tensor_index,
communicator_backend);
if (communicator_backend == CommunicatorBackend::kNccl) {
auto start_coalescing = IrBuilder::create<hir::StartCoalescing>();
auto end_coalescing = IrBuilder::create<hir::EndCoalescing>();
auto wait = IrBuilder::create<hir::Wait>(end_coalescing);

if_sending_to_self->elseBody().pushBack(start_coalescing);
if_sending_to_self->elseBody().pushBack(recv);
if_sending_to_self->elseBody().pushBack(send);
if_sending_to_self->elseBody().pushBack(end_coalescing);
if_sending_to_self->elseBody().pushBack(wait);
} else if (communicator_backend == CommunicatorBackend::kCuda) {
auto share_mem_handles = IrBuilder::create<hir::ShareMemHandles>(
std::vector<P2PCommunication*>({recv, send}));
auto wait_send = IrBuilder::create<hir::Wait>(send);
auto wait_recv = IrBuilder::create<hir::Wait>(recv);

if_sending_to_self->elseBody().pushBack(share_mem_handles);
switch (getP2pProtocol()) {
case P2pProtocol::Get: {
if_sending_to_self->elseBody().pushBack(send);
if_sending_to_self->elseBody().pushBack(recv);
break;
}
case P2pProtocol::Put: {
if_sending_to_self->elseBody().pushBack(recv);
if_sending_to_self->elseBody().pushBack(send);
break;
}
}
if_sending_to_self->elseBody().pushBack(wait_recv);
// Defer the wait on send to the loop epilogue under the same
// predicate
auto* deferred_wait_if = IrBuilder::create<kir::IfThenElse>(
if_sending_to_self->input(0)->as<kir::Predicate>());
deferred_wait_if->elseBody().pushBack(wait_send);
new_loop_body_epilogue.push_back(deferred_wait_if);
} else {
NVF_THROW(
"Unsupported communicator backend for lowering stream parallel "
"type into p2p: ",
communicator_backend);
}
if_sending_to_self->elseBody().pushBack(wait_recv);
// Defer the wait on send to the loop epilogue under the same
// predicate
auto* deferred_wait_if = IrBuilder::create<kir::IfThenElse>(
if_sending_to_self->input(0)->as<kir::Predicate>());
deferred_wait_if->elseBody().pushBack(wait_send);
new_loop_body_epilogue.push_back(deferred_wait_if);
new_loop_body.push_back(slicing_input);
new_loop_body.push_back(slicing_output);
new_loop_body.push_back(if_sending_to_self);
} else {
NVF_THROW(
"Unsupported communicator backend for lowering stream parallel "
"type into p2p: ",
// Lower to the MM+RS reduce-collective-based algorithm
Member Author:
This block is the core change

NVF_ERROR(
body_expr->isA<ReductionOp>(),
"expected a reduction operation but got ",
body_expr);
NVF_ERROR(
body_expr->as<ReductionOp>()->getReductionOpType() ==
BinaryOpType::Add,
"expected a reduce operation but got ",
body_expr);
auto* reduction_op = body_expr->as<ReductionOp>();
auto* input_tv = reduction_op->in()->as<TensorView>();
auto* output_tv = reduction_op->out()->as<TensorView>();
auto [slicing_input, is_new] = tensor_slicing_cache.get(
input_tv,
/*dim*/
findStreamAxisIndex(input_tv, for_loop->iterDomain(), id_model),
/*index=*/tensor_index);
auto [slicing_output, is_new_] = tensor_slicing_cache.get(
output_tv,
/*dim*/
findStreamAxisIndex(output_tv, for_loop->iterDomain(), id_model),
/*index=*/tensor_index);
auto reduce = IrBuilder::create<Communication>(
CommunicationType::Reduce,
slicing_output->out(),
slicing_input->out(),
input_tv->getDeviceMesh().vector(),
tensor_index,
c10d::ReduceOp::RedOpType::SUM,
communicator_backend);
auto wait = IrBuilder::create<hir::Wait>(reduce);
new_loop_body.push_back(slicing_input);
new_loop_body.push_back(slicing_output);
new_loop_body.push_back(reduce);
new_loop_body.push_back(wait);
}
new_loop_body.push_back(slicing_input);
new_loop_body.push_back(slicing_output);
new_loop_body.push_back(if_sending_to_self);
} else if (did_to_stream) {
} else if (did_to_stream && !stream_to_did) {
// Lower to AG+MM algorithm if did_to_stream=true && stream_to_did=false
//
// We have special handling for when an axis passes from DIDx to Stream
18 changes: 0 additions & 18 deletions tests/cpp/test_host_ir_stream_lowering.cpp
@@ -438,24 +438,6 @@ TEST_F(HirLowerStreamTest, Matmul_N) {
<< "Output: " << output << " Expected: " << expected_output;
}

TEST_F(HirLowerStreamTest, Matmul_K) {
auto hic = std::make_unique<HostIrContainer>();
FusionGuard fg(hic.get());
TensorView* a = makeContigTensor(2);
TensorView* b = makeContigTensor(2);
TensorView* c = matmul(a, b);
hic->addInput(a);
hic->addInput(b);
hic->addOutput(c);
hic->pushBackTopLevelExprs(c->definition());
a->setMemoryType(MemoryType::Global);
b->setMemoryType(MemoryType::Global);
c->setMemoryType(MemoryType::Global);
c->axis(-1)->parallelize(ParallelType::Stream);

EXPECT_ANY_THROW(hir_pass::StreamParallelType().runPass(hic.get()));
}

Member Author:
Since the reduced axis can now be stream-parallelized, this test fails, so I removed it.
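If the test were updated instead of deleted, a hedged sketch of the flipped expectation (not in this PR) could reuse the deleted setup and assert that the pass now succeeds:

```cpp
// Hypothetical update to the deleted Matmul_K test: same fusion and
// parallelization, but the pass is now expected to run without throwing,
// since a reduction axis may carry ParallelType::Stream.
c->axis(-1)->parallelize(ParallelType::Stream);
EXPECT_NO_THROW(hir_pass::StreamParallelType().runPass(hic.get()));
```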

// We don't support PostOnStream because it does not handle pre-allocated
// outputs well. There is no strong motivation to support PostOnStream.
TEST_F(HirLowerStreamTest, DoNotSupportPostOnStream) {