-
Notifications
You must be signed in to change notification settings - Fork 78
Add kernel based alltoallv and cuda backend for MoE dispatch and combine #5863
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
cf77bdb
66e7811
afd948d
dda9aa7
ba6612d
4693c53
8041c46
f1ce74c
a81a514
a0de605
74d18d1
6b994ba
47d710f
f39daf2
da52220
c089049
490200f
f148137
6f56706
ea5ad45
3828247
72232c5
c5add1b
374c8b3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| // clang-format off | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. | ||
| * All rights reserved. | ||
| * SPDX-License-Identifier: BSD-3-Clause | ||
| */ | ||
| // clang-format on | ||
|
|
||
| extern "C" __global__ void alltoallv_kernel( | ||
| const unsigned char* send, | ||
| const unsigned long long* recv_ptrs, | ||
| const long long* send_offsets, | ||
| const long long* send_sizes, | ||
| const long long* recv_offsets, | ||
| long long world_size, | ||
| long long elem_size, | ||
| long long max_send_bytes) { | ||
| const long long peer = static_cast<long long>(blockIdx.y); | ||
| if (peer >= world_size) { | ||
| return; | ||
| } | ||
| const long long bytes = send_sizes[peer] * elem_size; | ||
| if (bytes == 0) { | ||
| return; | ||
| } | ||
| const long long idx = | ||
| static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x; | ||
| if (idx >= bytes) { | ||
| return; | ||
| } | ||
| const long long send_byte_offset = send_offsets[peer] * elem_size + idx; | ||
| const long long recv_byte_offset = recv_offsets[peer] * elem_size + idx; | ||
| auto* dst = reinterpret_cast<unsigned char*>( | ||
| static_cast<unsigned long long>(recv_ptrs[peer])); | ||
| dst[recv_byte_offset] = send[send_byte_offset]; | ||
| } |
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |||||||||||||
| */ | ||||||||||||||
| // clang-format on | ||||||||||||||
| #include "multidevice/cuda_p2p.h" | ||||||||||||||
| #include "nvfuser_resources/alltoallv.h" | ||||||||||||||
| #include "nvfuser_resources/multicast.h" | ||||||||||||||
|
|
||||||||||||||
| #include "cuda_utils.h" | ||||||||||||||
|
|
@@ -34,6 +35,143 @@ P2pProtocol getP2pProtocol() { | |||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| namespace { | ||||||||||||||
| void launchAlltoallvKernel( | ||||||||||||||
| const void* send, | ||||||||||||||
| const uint64_t* recv_ptrs, | ||||||||||||||
| const int64_t* send_offsets, | ||||||||||||||
| const int64_t* send_sizes, | ||||||||||||||
| const int64_t* recv_offsets, | ||||||||||||||
| int64_t world_size, | ||||||||||||||
| int64_t elem_size, | ||||||||||||||
| int64_t max_send_bytes, | ||||||||||||||
| CUstream stream) { | ||||||||||||||
| static CUmodule module = nullptr; | ||||||||||||||
| static CUfunction kernel = nullptr; | ||||||||||||||
|
|
||||||||||||||
| if (module == nullptr) { | ||||||||||||||
| nvrtcProgram prog; | ||||||||||||||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram( | ||||||||||||||
| &prog, | ||||||||||||||
| nvfuser_resources::alltoallv_cu, | ||||||||||||||
| "alltoallv.cu", | ||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why nvrtc? Can't we simply
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question. I used nvrtc here to match nvFuser’s existing runtime-kernel pattern, to the best of my understanding, not because In our codebase, IIUC, these helper CUDA kernels are treated as runtime resources: At runtime, we compile/load that source with NVRTC + cuModuleLoadData, same style as other runtime kernels in If you’d prefer the static CUDA-launch route ( |
||||||||||||||
| 0, | ||||||||||||||
| nullptr, | ||||||||||||||
| nullptr)); | ||||||||||||||
|
|
||||||||||||||
| int major = 0; | ||||||||||||||
| int minor = 0; | ||||||||||||||
| int device = 0; | ||||||||||||||
| NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device)); | ||||||||||||||
| cudaDeviceProp prop; | ||||||||||||||
| NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device)); | ||||||||||||||
| major = prop.major; | ||||||||||||||
| minor = prop.minor; | ||||||||||||||
|
|
||||||||||||||
| std::string arch_arg = "--gpu-architecture=compute_" + | ||||||||||||||
| std::to_string(major) + std::to_string(minor); | ||||||||||||||
| std::vector<const char*> opts = {arch_arg.c_str(), "--std=c++17"}; | ||||||||||||||
| // NVRTC needs CUDA headers to compile alltoallv.cu. | ||||||||||||||
| opts.push_back("-I/usr/local/cuda/include"); | ||||||||||||||
| opts.push_back("-I/usr/local/cuda/include/cccl"); | ||||||||||||||
|
Comment on lines
+74
to
+75
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hardcoded CUDA include paths may break on non-standard installations
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| nvrtcResult res = nvrtcCompileProgram(prog, (int)opts.size(), opts.data()); | ||||||||||||||
| if (res != NVRTC_SUCCESS) { | ||||||||||||||
| size_t logSize; | ||||||||||||||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize)); | ||||||||||||||
| std::vector<char> log(logSize); | ||||||||||||||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data())); | ||||||||||||||
| NVF_ERROR(false, "Alltoallv kernel compilation failed:\n", log.data()); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| size_t ptxSize; | ||||||||||||||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize)); | ||||||||||||||
| std::vector<char> ptx(ptxSize); | ||||||||||||||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data())); | ||||||||||||||
| NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); | ||||||||||||||
|
|
||||||||||||||
| CUresult load_result = cuModuleLoadData(&module, ptx.data()); | ||||||||||||||
| if (load_result != CUDA_SUCCESS) { | ||||||||||||||
| constexpr size_t kLogSize = 8192; | ||||||||||||||
| char error_log[kLogSize]; | ||||||||||||||
| char info_log[kLogSize]; | ||||||||||||||
| CUjit_option options[] = { | ||||||||||||||
| CU_JIT_ERROR_LOG_BUFFER, | ||||||||||||||
| CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, | ||||||||||||||
| CU_JIT_INFO_LOG_BUFFER, | ||||||||||||||
| CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, | ||||||||||||||
| CU_JIT_LOG_VERBOSE}; | ||||||||||||||
| void* option_values[] = { | ||||||||||||||
| (void*)error_log, | ||||||||||||||
| (void*)kLogSize, | ||||||||||||||
| (void*)info_log, | ||||||||||||||
| (void*)kLogSize, | ||||||||||||||
| (void*)1}; | ||||||||||||||
| cuModuleLoadDataEx(&module, ptx.data(), 5, options, option_values); | ||||||||||||||
| NVF_ERROR( | ||||||||||||||
| false, | ||||||||||||||
| "Alltoallv kernel module load failed with error: ", | ||||||||||||||
| load_result, | ||||||||||||||
| "\nInfo Log:\n", | ||||||||||||||
| info_log, | ||||||||||||||
| "\nError Log:\n", | ||||||||||||||
| error_log); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| NVFUSER_CUDA_SAFE_CALL( | ||||||||||||||
| cuModuleGetFunction(&kernel, module, "alltoallv_kernel")); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| if (max_send_bytes == 0) { | ||||||||||||||
| return; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| constexpr int kThreads = 256; | ||||||||||||||
| const int64_t blocks_x = (max_send_bytes + kThreads - 1) / kThreads; | ||||||||||||||
| void* args_kernel[] = { | ||||||||||||||
| const_cast<void*>(static_cast<const void*>(&send)), | ||||||||||||||
| const_cast<void*>(static_cast<const void*>(&recv_ptrs)), | ||||||||||||||
| const_cast<void*>(static_cast<const void*>(&send_offsets)), | ||||||||||||||
| const_cast<void*>(static_cast<const void*>(&send_sizes)), | ||||||||||||||
| const_cast<void*>(static_cast<const void*>(&recv_offsets)), | ||||||||||||||
| &world_size, | ||||||||||||||
| &elem_size, | ||||||||||||||
| &max_send_bytes}; | ||||||||||||||
| NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel( | ||||||||||||||
| kernel, | ||||||||||||||
| blocks_x, | ||||||||||||||
| static_cast<unsigned int>(world_size), | ||||||||||||||
| 1, | ||||||||||||||
| kThreads, | ||||||||||||||
| 1, | ||||||||||||||
| 1, | ||||||||||||||
| 0, | ||||||||||||||
| stream, | ||||||||||||||
| args_kernel, | ||||||||||||||
| nullptr)); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::vector<uint8_t> serializeInt64Vector(const std::vector<int64_t>& values) { | ||||||||||||||
| std::vector<uint8_t> bytes(values.size() * sizeof(int64_t)); | ||||||||||||||
| std::memcpy(bytes.data(), values.data(), bytes.size()); | ||||||||||||||
| return bytes; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::vector<int64_t> deserializeInt64Vector(const std::vector<uint8_t>& bytes) { | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| bytes.size() % sizeof(int64_t) == 0, "Invalid int64 byte buffer size."); | ||||||||||||||
| const size_t count = bytes.size() / sizeof(int64_t); | ||||||||||||||
| std::vector<int64_t> values(count); | ||||||||||||||
| std::memcpy(values.data(), bytes.data(), bytes.size()); | ||||||||||||||
| return values; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::string alltoallvCountsKey(const std::string& tag, int64_t rank) { | ||||||||||||||
| return "nvfuser_alltoallv_counts_" + tag + "_" + std::to_string(rank); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::string alltoallvBarrierKey(const std::string& tag, int64_t rank) { | ||||||||||||||
| return "nvfuser_alltoallv_barrier_" + tag + "_" + std::to_string(rank); | ||||||||||||||
| } | ||||||||||||||
|
Comment on lines
+172
to
+174
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unused function - |
||||||||||||||
|
|
||||||||||||||
| void launchMulticastKernel( | ||||||||||||||
| void* dst, | ||||||||||||||
|
|
@@ -710,4 +848,181 @@ void waitWithCudaBackend( | |||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| AlltoallvMetadata prepareAlltoallvMetadata( | ||||||||||||||
| const at::Tensor& send_counts, | ||||||||||||||
| const std::string& tag) { | ||||||||||||||
| Communicator& comm = Communicator::getInstance(); | ||||||||||||||
| const int64_t world_size = comm.size(); | ||||||||||||||
| const int64_t my_rank = comm.deviceId(); | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| send_counts.is_cuda(), "alltoallv send_counts must be CUDA tensor."); | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| send_counts.dim() == 1 && send_counts.numel() == world_size, | ||||||||||||||
| "alltoallv send_counts must be 1D [R]."); | ||||||||||||||
|
|
||||||||||||||
| auto store = comm.getTcpStore(); | ||||||||||||||
| auto send_counts_cpu = send_counts.to(at::kCPU); | ||||||||||||||
| auto* send_ptr = send_counts_cpu.data_ptr<int64_t>(); | ||||||||||||||
| std::vector<int64_t> send_counts_vec(send_ptr, send_ptr + world_size); | ||||||||||||||
|
|
||||||||||||||
| store->set( | ||||||||||||||
| alltoallvCountsKey(tag, my_rank), serializeInt64Vector(send_counts_vec)); | ||||||||||||||
|
|
||||||||||||||
| std::vector<std::vector<int64_t>> counts_matrix(world_size); | ||||||||||||||
| for (int64_t rank = 0; rank < world_size; ++rank) { | ||||||||||||||
| auto bytes = store->get(alltoallvCountsKey(tag, rank)); | ||||||||||||||
| counts_matrix[rank] = deserializeInt64Vector(bytes); | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| (int64_t)counts_matrix[rank].size() == world_size, | ||||||||||||||
| "Invalid alltoallv counts size."); | ||||||||||||||
| } | ||||||||||||||
| comm.barrier(); | ||||||||||||||
| for (int64_t rank = 0; rank < world_size; ++rank) { | ||||||||||||||
| store->deleteKey(alltoallvCountsKey(tag, rank)); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::vector<int64_t> recv_counts_vec(world_size, 0); | ||||||||||||||
| for (int64_t sender = 0; sender < world_size; ++sender) { | ||||||||||||||
| recv_counts_vec[sender] = counts_matrix[sender][my_rank]; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::vector<int64_t> send_offsets_vec(world_size, 0); | ||||||||||||||
| int64_t prefix = 0; | ||||||||||||||
| for (int64_t rank = 0; rank < world_size; ++rank) { | ||||||||||||||
| send_offsets_vec[rank] = prefix; | ||||||||||||||
| prefix += send_counts_vec[rank]; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| std::vector<int64_t> recv_offsets_vec(world_size, 0); | ||||||||||||||
| for (int64_t peer = 0; peer < world_size; ++peer) { | ||||||||||||||
| int64_t offset = 0; | ||||||||||||||
| for (int64_t sender = 0; sender < my_rank; ++sender) { | ||||||||||||||
| offset += counts_matrix[sender][peer]; | ||||||||||||||
| } | ||||||||||||||
| recv_offsets_vec[peer] = offset; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| int64_t total_recv = 0; | ||||||||||||||
| for (auto value : recv_counts_vec) { | ||||||||||||||
| total_recv += value; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| int64_t max_recv = 0; | ||||||||||||||
| int64_t max_send_total = 0; | ||||||||||||||
| for (int64_t rank = 0; rank < world_size; ++rank) { | ||||||||||||||
| int64_t total = 0; | ||||||||||||||
| for (int64_t sender = 0; sender < world_size; ++sender) { | ||||||||||||||
| total += counts_matrix[sender][rank]; | ||||||||||||||
| } | ||||||||||||||
| if (total > max_recv) { | ||||||||||||||
| max_recv = total; | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| for (int64_t rank = 0; rank < world_size; ++rank) { | ||||||||||||||
| int64_t total = 0; | ||||||||||||||
| for (int64_t dest = 0; dest < world_size; ++dest) { | ||||||||||||||
| total += counts_matrix[rank][dest]; | ||||||||||||||
| } | ||||||||||||||
| if (total > max_send_total) { | ||||||||||||||
| max_send_total = total; | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| int64_t max_send = 0; | ||||||||||||||
| for (auto value : send_counts_vec) { | ||||||||||||||
| if (value > max_send) { | ||||||||||||||
| max_send = value; | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); | ||||||||||||||
| auto send_offsets_cpu = at::empty({world_size}, cpu_options); | ||||||||||||||
| std::memcpy( | ||||||||||||||
| send_offsets_cpu.data_ptr<int64_t>(), | ||||||||||||||
| send_offsets_vec.data(), | ||||||||||||||
| world_size * sizeof(int64_t)); | ||||||||||||||
| auto recv_offsets_cpu = at::empty({world_size}, cpu_options); | ||||||||||||||
| std::memcpy( | ||||||||||||||
| recv_offsets_cpu.data_ptr<int64_t>(), | ||||||||||||||
| recv_offsets_vec.data(), | ||||||||||||||
| world_size * sizeof(int64_t)); | ||||||||||||||
| auto recv_counts_cpu = at::empty({world_size}, cpu_options); | ||||||||||||||
| std::memcpy( | ||||||||||||||
| recv_counts_cpu.data_ptr<int64_t>(), | ||||||||||||||
| recv_counts_vec.data(), | ||||||||||||||
| world_size * sizeof(int64_t)); | ||||||||||||||
|
|
||||||||||||||
| AlltoallvMetadata metadata; | ||||||||||||||
| metadata.send_counts = send_counts; | ||||||||||||||
| metadata.recv_counts = recv_counts_cpu.to(send_counts.device()); | ||||||||||||||
| metadata.send_offsets = send_offsets_cpu.to(send_counts.device()); | ||||||||||||||
| metadata.recv_offsets = recv_offsets_cpu.to(send_counts.device()); | ||||||||||||||
| metadata.total_recv = total_recv; | ||||||||||||||
| metadata.max_recv = max_recv; | ||||||||||||||
| metadata.max_send_total = max_send_total; | ||||||||||||||
| metadata.max_send_bytes = max_send; | ||||||||||||||
| metadata.world_size = world_size; | ||||||||||||||
| return metadata; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| void alltoallvWithCudaBackend( | ||||||||||||||
| const at::Tensor& send, | ||||||||||||||
| const at::Tensor& recv, | ||||||||||||||
| const AlltoallvMetadata& metadata, | ||||||||||||||
| const std::vector<void*>& recv_ptrs, | ||||||||||||||
| CUstream stream) { | ||||||||||||||
| NVF_CHECK(send.is_cuda(), "alltoallv send must be CUDA."); | ||||||||||||||
| NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA."); | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| (int64_t)recv_ptrs.size() == metadata.world_size, | ||||||||||||||
| "recv_ptrs size must match world size."); | ||||||||||||||
|
|
||||||||||||||
| auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); | ||||||||||||||
| auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options); | ||||||||||||||
| auto* ptrs = recv_ptrs_cpu.data_ptr<int64_t>(); | ||||||||||||||
| for (int64_t rank = 0; rank < metadata.world_size; ++rank) { | ||||||||||||||
| ptrs[rank] = | ||||||||||||||
| static_cast<int64_t>(reinterpret_cast<uintptr_t>(recv_ptrs[rank])); | ||||||||||||||
| } | ||||||||||||||
| auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device()); | ||||||||||||||
|
|
||||||||||||||
| const int64_t elem_stride = | ||||||||||||||
| metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1; | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| metadata.max_send_total == 0 || | ||||||||||||||
| send.numel() % metadata.max_send_total == 0, | ||||||||||||||
| "alltoallv send numel must be divisible by max_send_total."); | ||||||||||||||
| NVF_CHECK( | ||||||||||||||
| metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0, | ||||||||||||||
| "alltoallv recv numel must be divisible by max_recv."); | ||||||||||||||
|
|
||||||||||||||
| auto send_offsets = metadata.send_offsets; | ||||||||||||||
| auto send_counts = metadata.send_counts; | ||||||||||||||
| auto recv_offsets = metadata.recv_offsets; | ||||||||||||||
| int64_t max_send_bytes = metadata.max_send_bytes; | ||||||||||||||
| if (elem_stride > 1) { | ||||||||||||||
| send_offsets = metadata.send_offsets * elem_stride; | ||||||||||||||
| send_counts = metadata.send_counts * elem_stride; | ||||||||||||||
| recv_offsets = metadata.recv_offsets * elem_stride; | ||||||||||||||
| max_send_bytes = metadata.max_send_bytes * elem_stride; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| launchAlltoallvKernel( | ||||||||||||||
| send.data_ptr(), | ||||||||||||||
| reinterpret_cast<const uint64_t*>(recv_ptrs_cuda.data_ptr<int64_t>()), | ||||||||||||||
| send_offsets.data_ptr<int64_t>(), | ||||||||||||||
| send_counts.data_ptr<int64_t>(), | ||||||||||||||
| recv_offsets.data_ptr<int64_t>(), | ||||||||||||||
| metadata.world_size, | ||||||||||||||
| send.element_size(), | ||||||||||||||
| max_send_bytes * send.element_size(), | ||||||||||||||
| stream); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| void alltoallvBarrier(const std::string& tag) { | ||||||||||||||
| Communicator& comm = Communicator::getInstance(); | ||||||||||||||
| comm.barrier(); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| } // namespace nvfuser | ||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
static variables without thread safety - multiple threads calling
launchAlltoallvKernelconcurrently could race on themodule == nullptrcheck during first initialization