diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp
index e77b8908a82..63b7cbe5ab9 100644
--- a/csrc/host_ir/evaluator.cpp
+++ b/csrc/host_ir/evaluator.cpp
@@ -718,12 +718,17 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) {
 void HostIrEvaluator::handle(HirAliasSelect* hir_alias_select) {
   auto indexed_id =
       hir_alias_select->in()->getLogicalDomain().at(hir_alias_select->axis());
-  auto index = indexed_id->isBroadcast()
-      ? 0
-      : expr_evaluator_.evaluate(hir_alias_select->index()).as<int64_t>();
   auto input = getKnownConcreteValue(hir_alias_select->in()->as<TensorView>())
                    .as<at::Tensor>();
 
+  // If the axis being selected is a reduction axis, the tensor doesn't have
+  // that dimension (it was skipped during allocation). The select is a no-op -
+  // just bind the input tensor directly to the output.
+  if (indexed_id->isReduction()) {
+    expr_evaluator_.bind(hir_alias_select->out(), input);
+    return;
+  }
+
   // Count reduction axes up to the target axis
   int64_t reduction_count = std::count_if(
       hir_alias_select->in()->getLogicalDomain().begin(),
@@ -732,6 +737,14 @@ void HostIrEvaluator::handle(HirAliasSelect* hir_alias_select) {
       [](const IterDomain* id) { return id->isReduction(); });
   // Adjust the ATen axis by subtracting the number of reduction axes
   int64_t axis = hir_alias_select->axis() - reduction_count;
+
+  // Use index 0 if the IterDomain is marked as broadcast, or if the actual
+  // tensor dimension has size 1 (behaves like broadcast at runtime even if
+  // not marked as such in the IR)
+  auto index = (indexed_id->isBroadcast() || input.size(axis) == 1)
+      ? 0
+      : expr_evaluator_.evaluate(hir_alias_select->index()).as<int64_t>();
+
   expr_evaluator_.bind(hir_alias_select->out(), input.select(axis, index));
 }
 
diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp
index 906aec227c7..0f44b73ba9b 100644
--- a/csrc/host_ir/pass/stream_parallel_type.cpp
+++ b/csrc/host_ir/pass/stream_parallel_type.cpp
@@ -59,14 +59,6 @@ void validateStreamAxis(IterDomain* stream_axis, const TensorView* tv) {
       it_logical_stream_axis != tv->getLogicalDomain().end(),
       "Cannot stream parallelize on a split/merge axis ",
       stream_axis);
-
-  // Verify stream axis is an iteration or broadcast axis
-  NVF_CHECK(
-      stream_axis->getIterType() == IterType::Iteration ||
-          stream_axis->getIterType() == IterType::Broadcast,
-      "Stream axis ",
-      stream_axis,
-      " should be an iteration or broadcast axis.");
 }
 
 // Checks if two iteration domains are mapped in the ID model
@@ -371,108 +363,149 @@ std::list processForLoopBodies(
       // Lower to MM + RS algorithm
       if (did_to_stream && stream_to_did) {
-        NVF_ERROR(
-            body_expr->isA<LoadStoreOp>() &&
-                body_expr->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set,
-            "expected a set operation but got ",
-            body_expr);
-        NVF_ERROR(
-            body_expr->isA(),
-            "expected a Tv operation but got ",
-            body_expr);
-        NVF_ERROR(
-            params.offset_stream_indexing_by_rank == true,
-            "offset_stream_indexing_by_rank==false not supported for "
-            "ReduceScatter patterns");
-        auto* set_op = body_expr->as<LoadStoreOp>();
-        auto* input_tv = set_op->in()->as<TensorView>();
-        auto* output_tv = set_op->out()->as<TensorView>();
-        NVF_ERROR(
-            input_tv->axis(0)->isDeviceDim(),
-            "expected a sharded first axis on the input but got ",
-            input_tv);
-        NVF_ERROR(
-            output_tv->axis(0)->getParallelType() == ParallelType::Stream,
-            "expected a stream parallelized first axis on the output but got ",
-            output_tv);
-        NVF_ERROR(
-            input_tv->axis(1)->getParallelType() == ParallelType::Stream,
-            "expected a stream parallelized second axis on the input but got ",
-            input_tv);
-        NVF_ERROR(
-            output_tv->axis(1)->isDeviceDim(),
-            "expected a sharded second axis on the output but got ",
-            output_tv);
-        auto* is_sending_to_self =
-            IrBuilder::create<kir::Predicate>(eq(tensor_index, my_device_id));
-        auto if_sending_to_self =
-            IrBuilder::create<kir::IfThenElse>(is_sending_to_self);
-        auto [slicing_input, is_new] = tensor_slicing_cache.get(
-            input_tv,
-            /*dim*/
-            findStreamAxisIndex(input_tv, for_loop->iterDomain(), id_model),
-            /*index=*/tensor_index);
-        auto [slicing_output, is_new_] =
-            tensor_slicing_cache.get(output_tv, /*dim*/ 0, /*index=*/recv_peer);
-        auto* local_copy = IrBuilder::create<LoadStoreOp>(
-            LoadStoreOpType::Set, slicing_output->out(), slicing_input->out());
-        if_sending_to_self->thenBody().pushBack(local_copy);
-        auto recv = IrBuilder::create<P2PCommunication>(
-            P2PCommunicationType::RECV,
-            slicing_output->out(),
-            recv_peer,
-            communicator_backend);
-        auto send = IrBuilder::create<P2PCommunication>(
-            P2PCommunicationType::SEND,
-            slicing_input->out(),
-            tensor_index,
-            communicator_backend);
-        if (communicator_backend == CommunicatorBackend::kNccl) {
-          auto start_coalescing = IrBuilder::create<hir::StartCoalescing>();
-          auto end_coalescing = IrBuilder::create<hir::EndCoalescing>();
-          auto wait = IrBuilder::create<hir::Wait>(end_coalescing);
-
-          if_sending_to_self->elseBody().pushBack(start_coalescing);
-          if_sending_to_self->elseBody().pushBack(recv);
-          if_sending_to_self->elseBody().pushBack(send);
-          if_sending_to_self->elseBody().pushBack(end_coalescing);
-          if_sending_to_self->elseBody().pushBack(wait);
-        } else if (communicator_backend == CommunicatorBackend::kCuda) {
-          auto share_mem_handles = IrBuilder::create<hir::ShareMemHandles>(
-              std::vector<P2PCommunication*>({recv, send}));
-          auto wait_send = IrBuilder::create<hir::Wait>(send);
-          auto wait_recv = IrBuilder::create<hir::Wait>(recv);
-
-          if_sending_to_self->elseBody().pushBack(share_mem_handles);
-          switch (getP2pProtocol()) {
-            case P2pProtocol::Get: {
-              if_sending_to_self->elseBody().pushBack(send);
-              if_sending_to_self->elseBody().pushBack(recv);
-              break;
-            }
-            case P2pProtocol::Put: {
-              if_sending_to_self->elseBody().pushBack(recv);
-              if_sending_to_self->elseBody().pushBack(send);
-              break;
-            }
-          }
-          if_sending_to_self->elseBody().pushBack(wait_recv);
-          // Defer the wait on send to the loop epilogue under the same
-          // predicate
-          auto* deferred_wait_if = IrBuilder::create<kir::IfThenElse>(
-              if_sending_to_self->input(0)->as<kir::Predicate>());
-          deferred_wait_if->elseBody().pushBack(wait_send);
-          new_loop_body_epilogue.push_back(deferred_wait_if);
-        } else {
-          NVF_THROW(
-              "Unsupported communicator backend for lowering stream parallel "
-              "type into p2p: ",
-              communicator_backend);
-        }
-        new_loop_body.push_back(slicing_input);
-        new_loop_body.push_back(slicing_output);
-        new_loop_body.push_back(if_sending_to_self);
-      } else if (did_to_stream) {
+        if (params.offset_stream_indexing_by_rank) {
+          // Lower to MM + RS p2p based algorithm
+          NVF_ERROR(
+              body_expr->isA<LoadStoreOp>() &&
+                  body_expr->as<LoadStoreOp>()->opType() ==
+                      LoadStoreOpType::Set,
+              "expected a set operation but got ",
+              body_expr);
+          NVF_ERROR(
+              body_expr->isA(),
+              "expected a Tv operation but got ",
+              body_expr);
+          auto* set_op = body_expr->as<LoadStoreOp>();
+          auto* input_tv = set_op->in()->as<TensorView>();
+          auto* output_tv = set_op->out()->as<TensorView>();
+          NVF_ERROR(
+              input_tv->axis(0)->isDeviceDim(),
+              "expected a sharded first axis on the input but got ",
+              input_tv);
+          NVF_ERROR(
+              output_tv->axis(0)->getParallelType() == ParallelType::Stream,
+              "expected a stream parallelized first axis on the output but "
+              "got ",
+              output_tv);
+          NVF_ERROR(
+              input_tv->axis(1)->getParallelType() == ParallelType::Stream,
+              "expected a stream parallelized second axis on the input but "
+              "got ",
+              input_tv);
+          NVF_ERROR(
+              output_tv->axis(1)->isDeviceDim(),
+              "expected a sharded second axis on the output but got ",
+              output_tv);
+          auto* is_sending_to_self =
+              IrBuilder::create<kir::Predicate>(eq(tensor_index, my_device_id));
+          auto if_sending_to_self =
+              IrBuilder::create<kir::IfThenElse>(is_sending_to_self);
+          auto [slicing_input, is_new] = tensor_slicing_cache.get(
+              input_tv,
+              /*dim*/
+              findStreamAxisIndex(input_tv, for_loop->iterDomain(), id_model),
+              /*index=*/tensor_index);
+          auto [slicing_output, is_new_] = tensor_slicing_cache.get(
+              output_tv, /*dim*/ 0, /*index=*/recv_peer);
+          auto* local_copy = IrBuilder::create<LoadStoreOp>(
+              LoadStoreOpType::Set,
+              slicing_output->out(),
+              slicing_input->out());
+          if_sending_to_self->thenBody().pushBack(local_copy);
+          auto recv = IrBuilder::create<P2PCommunication>(
+              P2PCommunicationType::RECV,
+              slicing_output->out(),
+              recv_peer,
+              communicator_backend);
+          auto send = IrBuilder::create<P2PCommunication>(
+              P2PCommunicationType::SEND,
+              slicing_input->out(),
+              tensor_index,
+              communicator_backend);
+          if (communicator_backend == CommunicatorBackend::kNccl) {
+            auto start_coalescing = IrBuilder::create<hir::StartCoalescing>();
+            auto end_coalescing = IrBuilder::create<hir::EndCoalescing>();
+            auto wait = IrBuilder::create<hir::Wait>(end_coalescing);
+
+            if_sending_to_self->elseBody().pushBack(start_coalescing);
+            if_sending_to_self->elseBody().pushBack(recv);
+            if_sending_to_self->elseBody().pushBack(send);
+            if_sending_to_self->elseBody().pushBack(end_coalescing);
+            if_sending_to_self->elseBody().pushBack(wait);
+          } else if (communicator_backend == CommunicatorBackend::kCuda) {
+            auto share_mem_handles = IrBuilder::create<hir::ShareMemHandles>(
+                std::vector<P2PCommunication*>({recv, send}));
+            auto wait_send = IrBuilder::create<hir::Wait>(send);
+            auto wait_recv = IrBuilder::create<hir::Wait>(recv);
+
+            if_sending_to_self->elseBody().pushBack(share_mem_handles);
+            switch (getP2pProtocol()) {
+              case P2pProtocol::Get: {
+                if_sending_to_self->elseBody().pushBack(send);
+                if_sending_to_self->elseBody().pushBack(recv);
+                break;
+              }
+              case P2pProtocol::Put: {
+                if_sending_to_self->elseBody().pushBack(recv);
+                if_sending_to_self->elseBody().pushBack(send);
+                break;
+              }
+            }
+            if_sending_to_self->elseBody().pushBack(wait_recv);
+            // Defer the wait on send to the loop epilogue under the same
+            // predicate
+            auto* deferred_wait_if = IrBuilder::create<kir::IfThenElse>(
+                if_sending_to_self->input(0)->as<kir::Predicate>());
+            deferred_wait_if->elseBody().pushBack(wait_send);
+            new_loop_body_epilogue.push_back(deferred_wait_if);
+          } else {
+            NVF_THROW(
+                "Unsupported communicator backend for lowering stream parallel "
+                "type into p2p: ",
+                communicator_backend);
+          }
+          new_loop_body.push_back(slicing_input);
+          new_loop_body.push_back(slicing_output);
+          new_loop_body.push_back(if_sending_to_self);
+        } else {
+          // Lower to the MM+RS reduce-collective-based algorithm
+          NVF_ERROR(
+              body_expr->isA<ReductionOp>(),
+              "expected a reduction operation but got ",
+              body_expr);
+          NVF_ERROR(
+              body_expr->as<ReductionOp>()->getReductionOpType() ==
+                  BinaryOpType::Add,
+              "expected a reduce operation but got ",
+              body_expr);
+          auto* reduction_op = body_expr->as<ReductionOp>();
+          auto* input_tv = reduction_op->in()->as<TensorView>();
+          auto* output_tv = reduction_op->out()->as<TensorView>();
+          auto [slicing_input, is_new] = tensor_slicing_cache.get(
+              input_tv,
+              /*dim*/
+              findStreamAxisIndex(input_tv, for_loop->iterDomain(), id_model),
+              /*index=*/tensor_index);
+          auto [slicing_output, is_new_] = tensor_slicing_cache.get(
+              output_tv,
+              /*dim*/
+              findStreamAxisIndex(output_tv, for_loop->iterDomain(), id_model),
+              /*index=*/tensor_index);
+          auto reduce = IrBuilder::create<Communication>(
+              CommunicationType::Reduce,
+              slicing_output->out(),
+              slicing_input->out(),
+              input_tv->getDeviceMesh().vector(),
+              tensor_index,
+              c10d::ReduceOp::RedOpType::SUM,
+              communicator_backend);
+          auto wait = IrBuilder::create<hir::Wait>(reduce);
+          new_loop_body.push_back(slicing_input);
+          new_loop_body.push_back(slicing_output);
+          new_loop_body.push_back(reduce);
+          new_loop_body.push_back(wait);
+        }
+      } else if (did_to_stream && !stream_to_did) {
         // Lower to AG+MM algorithm if did_to_stream=true && stream_to_did=false
         //
         // We have a special handling for when an axis pass from DIDx to Stream
diff --git a/tests/cpp/test_host_ir_stream_lowering.cpp b/tests/cpp/test_host_ir_stream_lowering.cpp
index 17f10b43dfc..2fb4fffa820 100644
--- a/tests/cpp/test_host_ir_stream_lowering.cpp
+++ b/tests/cpp/test_host_ir_stream_lowering.cpp
@@ -438,24 +438,6 @@ TEST_F(HirLowerStreamTest, Matmul_N) {
       << "Output: " << output << " Expected: " << expected_output;
 }
 
-TEST_F(HirLowerStreamTest, Matmul_K) {
-  auto hic = std::make_unique<hir::HostIrContainer>();
-  FusionGuard fg(hic.get());
-  TensorView* a = makeContigTensor(2);
-  TensorView* b = makeContigTensor(2);
-  TensorView* c = matmul(a, b);
-  hic->addInput(a);
-  hic->addInput(b);
-  hic->addOutput(c);
-  hic->pushBackTopLevelExprs(c->definition());
-  a->setMemoryType(MemoryType::Global);
-  b->setMemoryType(MemoryType::Global);
-  c->setMemoryType(MemoryType::Global);
-  c->axis(-1)->parallelize(ParallelType::Stream);
-
-  EXPECT_ANY_THROW(hir_pass::StreamParallelType().runPass(hic.get()));
-}
-
 // We don's support PostOnStream because it does not support well pre-allocated
 // outputs. There is no strong motivation to support PostOnStream
 TEST_F(HirLowerStreamTest, DoNotSupportPostOnStream) {
diff --git a/tests/cpp/test_multidevice_stream_parallel_type.cpp b/tests/cpp/test_multidevice_stream_parallel_type.cpp
index 5bd25fd3056..821bbc66207 100644
--- a/tests/cpp/test_multidevice_stream_parallel_type.cpp
+++ b/tests/cpp/test_multidevice_stream_parallel_type.cpp
@@ -663,7 +663,7 @@ TEST_P(RSMatmulTest, ReduceScatterP2p) {
 
   MultiDeviceExecutorParams params;
   params.lower.communicator_backend = communicator_backend;
-  params.lower.offset_stream_indexing_by_rank = true;
+  params.lower.offset_stream_indexing_by_rank = true; // Will fail if false
   MultiDeviceExecutor executor(std::move(fusion), *communicator_, params);
 
   auto tensor_options =
@@ -684,6 +684,84 @@
       << "Output: " << t2 << " Expected: " << t2_ref;
 }
 
+// The difference between this test and the previous one is that this test
+// is resharding on the sum instead of the matmul.
+// TODO: support both true/false for params.lower.offset_stream_indexing_by_rank
+// in both fusions
+TEST_P(RSMatmulTest, ReduceScatterReduceBased) {
+  CommunicatorBackend communicator_backend = GetParam();
+  constexpr int64_t M = 64;
+  constexpr int64_t K = 64;
+  constexpr int64_t N = 64;
+  constexpr int64_t S = 4;
+  const int64_t D = communicator_->size();
+  if (M % (S * D) != 0) {
+    GTEST_SKIP() << "M must be a multiple of S * D, but got M = " << M
+                 << ", S = " << S << ", D = " << D;
+  }
+  if (K % D != 0) {
+    GTEST_SKIP() << "K must be a multiple of D, but got K = " << K
+                 << ", D = " << D;
+  }
+  if (communicator_backend == CommunicatorBackend::kCuda) {
+    GTEST_SKIP() << "CUDA backend is not supported for this test";
+  }
+
+  EnableOptionsGuard::getCurOptions().set(EnableOption::InsertReshardingAfter);
+
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  // Only the reduced dimension (D) is actually sharded, M is split logically
+  // for convenience
+  TensorView* A = makeContigTensor(4); // [DIDx(D), Stream(D), M/D, K/D]
+  TensorView* B = makeContigTensor(4); // [DIDx(D), 1, K/D, N]
+  TensorView* C_unreduced = matmul(A, B); // [DIDx(D), Stream(D), M/D, N]
+  TensorView* C = sum(C_unreduced, {0}); // [Stream(r(D)), DIDx(D), M/D, N]
+
+  fusion->addInput(A);
+  fusion->addInput(B);
+  fusion->addOutput(C);
+
+  auto mesh = DeviceMesh::createForNumDevices(D);
+  A->setDeviceMesh(mesh);
+  B->setDeviceMesh(mesh);
+  C_unreduced->setDeviceMesh(mesh);
+  C->setDeviceMesh(mesh);
+
+  A->axis(0)->parallelize(ParallelType::DIDx);
+  A->axis(1)->parallelize(ParallelType::Stream);
+  B->axis(0)->parallelize(ParallelType::DIDx);
+  C_unreduced->axis(1)->parallelize(ParallelType::Stream);
+  C_unreduced->axis(0)->parallelize(ParallelType::DIDx);
+  C->axis(1)->parallelize(ParallelType::DIDx);
+  C->axis(0)->parallelize(ParallelType::Stream);
+
+  MultiDeviceExecutorParams params;
+  params.lower.communicator_backend = communicator_backend;
+  params.lower.offset_stream_indexing_by_rank = false; // Will fail if true
+  MultiDeviceExecutor executor(std::move(fusion), *communicator_, params);
+
+  auto tensor_options =
+      at::TensorOptions().dtype(at::kFloat).device(communicator_->device());
+  auto A_unsharded = at::randn({D, D, M / D, K / D}, tensor_options);
+  auto B_unsharded = at::randn({D, 1, K / D, N}, tensor_options);
+  auto A_sharded = shardTensor1D(A_unsharded, /*axis=*/0, mesh);
+  auto B_sharded = shardTensor1D(B_unsharded, /*axis=*/0, mesh);
+
+  auto C_out =
+      executor.runWithInput({A_sharded, B_sharded})[0].as<at::Tensor>();
+
+  auto C_unreduced_unsharded =
+      at::matmul(A_unsharded, B_unsharded); // {D, D, M / D, N}
+  auto C_reduced_unsharded =
+      at::sum(C_unreduced_unsharded, {0}); // {D, M / D, N}
+  auto C_ref =
+      shardTensor1D(C_reduced_unsharded, /*axis=*/0, mesh); // {M / D, N}
+  EXPECT_TRUE(at::allclose(C_ref, C_out, 1e-1, 1e-1))
+      << "Output: " << C_out << " Expected: " << C_ref;
+}
+
 INSTANTIATE_TEST_SUITE_P(
     ,
     RSMatmulTest,