Reduce-based MM+RS in MultiDeviceExecutor #5923
base: main
Changes from all commits: 704cf58, cd9efb5, e4410b6, 5f605b2, 92cca3e, a67afdb, 1367127, a363e0d, e603771, 3ada821
@@ -59,14 +59,6 @@ void validateStreamAxis(IterDomain* stream_axis, const TensorView* tv) {
    it_logical_stream_axis != tv->getLogicalDomain().end(),
    "Cannot stream parallelize on a split/merge axis ",
    stream_axis);

// Verify stream axis is an iteration or broadcast axis
NVF_CHECK(
    stream_axis->getIterType() == IterType::Iteration ||
        stream_axis->getIterType() == IterType::Broadcast,
    "Stream axis ",
    stream_axis,
    " should be an iteration or broadcast axis.");
Author's comment: We stream parallelize the reduced axis in the sum op.
}

// Checks if two iteration domains are mapped in the ID model
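To illustrate what the removed check used to forbid, here is a minimal sketch in the style of the HostIr tests touched later in this PR (HostIrContainer, FusionGuard, makeContigTensor and sum are assumed to be the same helpers those tests use; the shape and axis positions are illustrative, not taken from the PR): the reduced axis of a sum op carries ParallelType::Stream, which is the pattern the reduce-based MM+RS lowering relies on.

// Minimal sketch, assuming the HostIr test scaffolding used elsewhere in this PR.
auto hic = std::make_unique<HostIrContainer>();
FusionGuard fg(hic.get());
TensorView* in = makeContigTensor(2);             // [rows, cols]; shape is illustrative
TensorView* out = sum(in, {1});                   // axis 1 of `out` is a Reduction IterDomain
out->axis(1)->parallelize(ParallelType::Stream);  // previously rejected: neither Iteration nor Broadcast

With the NVF_CHECK above removed, validateStreamAxis no longer rejects this pattern up front; the later lowering decides how to handle it.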
@@ -371,108 +363,149 @@ std::list<Expr*> processForLoopBodies(

// Lower to MM + RS algorithm
if (did_to_stream && stream_to_did) {
NVF_ERROR(
body_expr->isA<LoadStoreOp>() &&
body_expr->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set,
"expected a set operation but got ",
body_expr);
NVF_ERROR(
body_expr->isA<LoadStoreOp>(),
"expected a Tv operation but got ",
body_expr);
NVF_ERROR(
params.offset_stream_indexing_by_rank == true,
"offset_stream_indexing_by_rank==false not supported for "
"ReduceScatter patterns");
auto* set_op = body_expr->as<LoadStoreOp>();
auto* input_tv = set_op->in()->as<TensorView>();
auto* output_tv = set_op->out()->as<TensorView>();
NVF_ERROR(
input_tv->axis(0)->isDeviceDim(),
"expected a sharded first axis on the input but got ",
input_tv);
NVF_ERROR(
output_tv->axis(0)->getParallelType() == ParallelType::Stream,
"expected a stream parallelized first axis on the output but got ",
output_tv);
NVF_ERROR(
input_tv->axis(1)->getParallelType() == ParallelType::Stream,
"expected a stream parallelized second axis on the input but got ",
input_tv);
NVF_ERROR(
output_tv->axis(1)->isDeviceDim(),
"expected a sharded second axis on the output but got ",
output_tv);
auto* is_sending_to_self =
IrBuilder::create<kir::Predicate>(eq(tensor_index, my_device_id));
auto if_sending_to_self =
IrBuilder::create<kir::IfThenElse>(is_sending_to_self);
auto [slicing_input, is_new] = tensor_slicing_cache.get(
input_tv,
/*dim*/
findStreamAxisIndex(input_tv, for_loop->iterDomain(), id_model),
/*index=*/tensor_index);
auto [slicing_output, is_new_] =
tensor_slicing_cache.get(output_tv, /*dim*/ 0, /*index=*/recv_peer);
auto* local_copy = IrBuilder::create<LoadStoreOp>(
LoadStoreOpType::Set, slicing_output->out(), slicing_input->out());
if_sending_to_self->thenBody().pushBack(local_copy);
auto recv = IrBuilder::create<P2PCommunication>(
P2PCommunicationType::RECV,
slicing_output->out(),
recv_peer,
communicator_backend);
auto send = IrBuilder::create<P2PCommunication>(
P2PCommunicationType::SEND,
slicing_input->out(),
tensor_index,
communicator_backend);
if (communicator_backend == CommunicatorBackend::kNccl) {
auto start_coalescing = IrBuilder::create<hir::StartCoalescing>();
auto end_coalescing = IrBuilder::create<hir::EndCoalescing>();
auto wait = IrBuilder::create<hir::Wait>(end_coalescing);

if_sending_to_self->elseBody().pushBack(start_coalescing);
if_sending_to_self->elseBody().pushBack(recv);
if_sending_to_self->elseBody().pushBack(send);
if_sending_to_self->elseBody().pushBack(end_coalescing);
if_sending_to_self->elseBody().pushBack(wait);
} else if (communicator_backend == CommunicatorBackend::kCuda) {
auto share_mem_handles = IrBuilder::create<hir::ShareMemHandles>(
std::vector<P2PCommunication*>({recv, send}));
auto wait_send = IrBuilder::create<hir::Wait>(send);
auto wait_recv = IrBuilder::create<hir::Wait>(recv);

if_sending_to_self->elseBody().pushBack(share_mem_handles);
switch (getP2pProtocol()) {
case P2pProtocol::Get: {
if_sending_to_self->elseBody().pushBack(send);
if_sending_to_self->elseBody().pushBack(recv);
break;
}
case P2pProtocol::Put: {
if_sending_to_self->elseBody().pushBack(recv);
if_sending_to_self->elseBody().pushBack(send);
break;
if (params.offset_stream_indexing_by_rank) {
// Lower to MM + RS p2p based algorithm
Author's comment: This block is the same as before. The only difference is the indentation level.
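To make the per-iteration structure of this p2p block easier to follow, here is a standalone sketch in plain C++, with no nvFuser types. The peer formulas are assumptions made only for the illustration; in the pass, tensor_index and recv_peer are computed elsewhere, and offset_stream_indexing_by_rank offsets the stream index by the rank. The sketch mirrors the branch layout built below: a local copy when the target slice belongs to this rank, otherwise a paired recv/send.

#include <cstdio>

int main() {
  constexpr int kNumDevices = 4;
  constexpr int kMyRank = 1;  // pretend we are device 1
  for (int j = 0; j < kNumDevices; ++j) {
    // Assumed rank-offset schedule; not lifted from the pass itself.
    const int send_peer = (kMyRank + j) % kNumDevices;                // plays the role of tensor_index
    const int recv_peer = (kMyRank - j + kNumDevices) % kNumDevices;  // plays the role of recv_peer
    if (send_peer == kMyRank) {
      // "is_sending_to_self": lowered to a plain local copy (LoadStoreOp).
      std::printf("iter %d: local copy of my own slice\n", j);
    } else {
      // else branch: RECV from recv_peer and SEND to send_peer. With NCCL the pair is
      // wrapped in start/end coalescing plus a single wait; with the CUDA backend the
      // order depends on the P2P protocol (Get: send then recv, Put: recv then send)
      // and the wait on the send is deferred to the loop epilogue.
      std::printf("iter %d: recv from %d, send to %d\n", j, recv_peer, send_peer);
    }
  }
  return 0;
}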
NVF_ERROR(
body_expr->isA<LoadStoreOp>() &&
body_expr->as<LoadStoreOp>()->opType() ==
LoadStoreOpType::Set,
"expected a set operation but got ",
body_expr);
NVF_ERROR(
body_expr->isA<LoadStoreOp>(),
"expected a Tv operation but got ",
body_expr);
auto* set_op = body_expr->as<LoadStoreOp>();
auto* input_tv = set_op->in()->as<TensorView>();
auto* output_tv = set_op->out()->as<TensorView>();
NVF_ERROR(
input_tv->axis(0)->isDeviceDim(),
"expected a sharded first axis on the input but got ",
input_tv);
NVF_ERROR(
output_tv->axis(0)->getParallelType() == ParallelType::Stream,
"expected a stream parallelized first axis on the output but "
"got ",
output_tv);
NVF_ERROR(
input_tv->axis(1)->getParallelType() == ParallelType::Stream,
"expected a stream parallelized second axis on the input but "
"got ",
input_tv);
NVF_ERROR(
output_tv->axis(1)->isDeviceDim(),
"expected a sharded second axis on the output but got ",
output_tv);
auto* is_sending_to_self =
IrBuilder::create<kir::Predicate>(eq(tensor_index, my_device_id));
auto if_sending_to_self =
IrBuilder::create<kir::IfThenElse>(is_sending_to_self);
auto [slicing_input, is_new] = tensor_slicing_cache.get(
input_tv,
/*dim*/
findStreamAxisIndex(input_tv, for_loop->iterDomain(), id_model),
/*index=*/tensor_index);
auto [slicing_output, is_new_] = tensor_slicing_cache.get(
output_tv, /*dim*/ 0, /*index=*/recv_peer);
auto* local_copy = IrBuilder::create<LoadStoreOp>(
LoadStoreOpType::Set,
slicing_output->out(),
slicing_input->out());
if_sending_to_self->thenBody().pushBack(local_copy);
auto recv = IrBuilder::create<P2PCommunication>(
P2PCommunicationType::RECV,
slicing_output->out(),
recv_peer,
communicator_backend);
auto send = IrBuilder::create<P2PCommunication>(
P2PCommunicationType::SEND,
slicing_input->out(),
tensor_index,
communicator_backend);
if (communicator_backend == CommunicatorBackend::kNccl) {
auto start_coalescing = IrBuilder::create<hir::StartCoalescing>();
auto end_coalescing = IrBuilder::create<hir::EndCoalescing>();
auto wait = IrBuilder::create<hir::Wait>(end_coalescing);

if_sending_to_self->elseBody().pushBack(start_coalescing);
if_sending_to_self->elseBody().pushBack(recv);
if_sending_to_self->elseBody().pushBack(send);
if_sending_to_self->elseBody().pushBack(end_coalescing);
if_sending_to_self->elseBody().pushBack(wait);
} else if (communicator_backend == CommunicatorBackend::kCuda) {
auto share_mem_handles = IrBuilder::create<hir::ShareMemHandles>(
std::vector<P2PCommunication*>({recv, send}));
auto wait_send = IrBuilder::create<hir::Wait>(send);
auto wait_recv = IrBuilder::create<hir::Wait>(recv);

if_sending_to_self->elseBody().pushBack(share_mem_handles);
switch (getP2pProtocol()) {
case P2pProtocol::Get: {
if_sending_to_self->elseBody().pushBack(send);
if_sending_to_self->elseBody().pushBack(recv);
break;
}
case P2pProtocol::Put: {
if_sending_to_self->elseBody().pushBack(recv);
if_sending_to_self->elseBody().pushBack(send);
break;
}
}
if_sending_to_self->elseBody().pushBack(wait_recv);
// Defer the wait on send to the loop epilogue under the same
// predicate
auto* deferred_wait_if = IrBuilder::create<kir::IfThenElse>(
if_sending_to_self->input(0)->as<kir::Predicate>());
deferred_wait_if->elseBody().pushBack(wait_send);
new_loop_body_epilogue.push_back(deferred_wait_if);
} else {
NVF_THROW(
"Unsupported communicator backend for lowering stream parallel "
"type into p2p: ",
communicator_backend);
}
if_sending_to_self->elseBody().pushBack(wait_recv);
// Defer the wait on send to the loop epilogue under the same
// predicate
auto* deferred_wait_if = IrBuilder::create<kir::IfThenElse>(
if_sending_to_self->input(0)->as<kir::Predicate>());
deferred_wait_if->elseBody().pushBack(wait_send);
new_loop_body_epilogue.push_back(deferred_wait_if);
new_loop_body.push_back(slicing_input);
new_loop_body.push_back(slicing_output);
new_loop_body.push_back(if_sending_to_self);
} else {
NVF_THROW(
"Unsupported communicator backend for lowering stream parallel "
"type into p2p: ",
// Lower to the MM+RS reduce-collective-based algorithm
Author's comment: This block is the core change.
NVF_ERROR(
body_expr->isA<ReductionOp>(),
"expected a reduction operation but got ",
body_expr);
NVF_ERROR(
body_expr->as<ReductionOp>()->getReductionOpType() ==
BinaryOpType::Add,
"expected a reduce operation but got ",
body_expr);
auto* reduction_op = body_expr->as<ReductionOp>();
auto* input_tv = reduction_op->in()->as<TensorView>();
auto* output_tv = reduction_op->out()->as<TensorView>();
auto [slicing_input, is_new] = tensor_slicing_cache.get(
input_tv,
/*dim*/
findStreamAxisIndex(input_tv, for_loop->iterDomain(), id_model),
/*index=*/tensor_index);
auto [slicing_output, is_new_] = tensor_slicing_cache.get(
output_tv,
/*dim*/
findStreamAxisIndex(output_tv, for_loop->iterDomain(), id_model),
/*index=*/tensor_index);
auto reduce = IrBuilder::create<Communication>(
CommunicationType::Reduce,
slicing_output->out(),
slicing_input->out(),
input_tv->getDeviceMesh().vector(),
tensor_index,
c10d::ReduceOp::RedOpType::SUM,
communicator_backend);
auto wait = IrBuilder::create<hir::Wait>(reduce);
new_loop_body.push_back(slicing_input);
new_loop_body.push_back(slicing_output);
new_loop_body.push_back(reduce);
new_loop_body.push_back(wait);
}
new_loop_body.push_back(slicing_input);
new_loop_body.push_back(slicing_output);
new_loop_body.push_back(if_sending_to_self);
} else if (did_to_stream) {
} else if (did_to_stream && !stream_to_did) {
// Lower to AG+MM algorithm if did_to_stream=true && stream_to_did=false
//
// We have special handling for when an axis passes from DIDx to Stream
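To clarify the reduce-collective-based algorithm above, here is a standalone numerical sketch in plain C++ (no nvFuser or c10d types; the device count and data are made up). It shows the underlying idea: a reduce-scatter is decomposed into one Reduce per output slice, rooted at the device that owns that slice, which is what lets each Reduce be pipelined with the matmul producing its input slice.

#include <cstdio>
#include <vector>

int main() {
  constexpr int kNumDevices = 4;  // size of the device mesh
  constexpr int kSliceElems = 3;  // elements per output slice
  // partial[r][j] = device r's local partial result for output slice j
  // (in the real pattern this would come from r's shard of the matmul).
  std::vector<std::vector<std::vector<float>>> partial(
      kNumDevices,
      std::vector<std::vector<float>>(kNumDevices, std::vector<float>(kSliceElems)));
  for (int r = 0; r < kNumDevices; ++r)
    for (int j = 0; j < kNumDevices; ++j)
      for (int e = 0; e < kSliceElems; ++e)
        partial[r][j][e] = static_cast<float>(r + 1);  // dummy data

  // One "stream iteration" per output slice j: a Reduce rooted at device j,
  // where every device contributes its partial slice j and only the root
  // keeps the sum. After kNumDevices iterations this equals a reduce-scatter.
  std::vector<std::vector<float>> scattered(kNumDevices, std::vector<float>(kSliceElems, 0.f));
  for (int j = 0; j < kNumDevices; ++j)
    for (int r = 0; r < kNumDevices; ++r)
      for (int e = 0; e < kSliceElems; ++e)
        scattered[j][e] += partial[r][j][e];

  for (int j = 0; j < kNumDevices; ++j)
    std::printf("device %d owns reduced slice starting with %.0f\n", j, scattered[j][0]);
  return 0;
}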
@@ -438,24 +438,6 @@ TEST_F(HirLowerStreamTest, Matmul_N) {
    << "Output: " << output << " Expected: " << expected_output;
}

TEST_F(HirLowerStreamTest, Matmul_K) {
auto hic = std::make_unique<HostIrContainer>();
FusionGuard fg(hic.get());
TensorView* a = makeContigTensor(2);
TensorView* b = makeContigTensor(2);
TensorView* c = matmul(a, b);
hic->addInput(a);
hic->addInput(b);
hic->addOutput(c);
hic->pushBackTopLevelExprs(c->definition());
a->setMemoryType(MemoryType::Global);
b->setMemoryType(MemoryType::Global);
c->setMemoryType(MemoryType::Global);
c->axis(-1)->parallelize(ParallelType::Stream);

EXPECT_ANY_THROW(hir_pass::StreamParallelType().runPass(hic.get()));
}
Author's comment: Since the reduced axis can be stream parallelized, this test fails. So here I removed it.
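For context, a hypothetical sketch (not part of the PR) of what the deleted test would look like if it were kept with the opposite expectation; the author deleted it instead, and this is shown only to illustrate why the original EXPECT_ANY_THROW started failing once the IterType check was removed.

// Hypothetical, not in the PR: same setup as the removed Matmul_K test,
// but expecting the pass to accept a stream-parallelized reduced axis.
TEST_F(HirLowerStreamTest, Matmul_K_NoLongerRejected) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* a = makeContigTensor(2);
  TensorView* b = makeContigTensor(2);
  TensorView* c = matmul(a, b);
  hic->addInput(a);
  hic->addInput(b);
  hic->addOutput(c);
  hic->pushBackTopLevelExprs(c->definition());
  a->setMemoryType(MemoryType::Global);
  b->setMemoryType(MemoryType::Global);
  c->setMemoryType(MemoryType::Global);
  c->axis(-1)->parallelize(ParallelType::Stream);  // the reduced K axis

  EXPECT_NO_THROW(hir_pass::StreamParallelType().runPass(hic.get()));
}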
// We don't support PostOnStream because it does not handle pre-allocated
// outputs well. There is no strong motivation to support PostOnStream.
TEST_F(HirLowerStreamTest, DoNotSupportPostOnStream) {
Author's comment: I removed the broadcast op from the fusion. It failed here because it was trying to select on the D axis, which is 1 locally.

Reviewer: So why add the case of a broadcast axis here? (Which, by the way, looks good to me.)

Reviewer: So why not detect whether the axis is sharded? I think that checking that the axis is of size 1 is not correct. Firstly, if the dimension is DIDx, then the symbolic size will be D and not 1. Secondly, if the axis is neither broadcast nor sharded but just happens to be of size 1, then we want to error out. Does that make sense?
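A minimal sketch of the check the reviewer seems to be suggesting (the call site and the `axis` handle are hypothetical; only the idea is taken from the comment): decide based on the axis kind rather than on its local extent, since a DIDx-sharded axis has symbolic size D rather than 1, and an unrelated axis that merely happens to have size 1 should still be an error.

// `axis` is a hypothetical IterDomain* at the call site under discussion.
NVF_CHECK(
    axis->isBroadcast() || axis->isDeviceDim(),  // check the kind, not extent == 1
    "Expected a broadcast or device-sharded axis but got ",
    axis);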