Commits (24)
- cf77bdb first working dispatch and combine primitive for k=1 (samnordmann, Jan 21, 2026)
- 66e7811 add comments and cleanup (samnordmann, Jan 21, 2026)
- afd948d review (samnordmann, Jan 22, 2026)
- dda9aa7 add kernel based a2av and cuda backend for d/c (samnordmann, Jan 22, 2026)
- ba6612d minor comments (samnordmann, Jan 26, 2026)
- 4693c53 minor review (samnordmann, Jan 29, 2026)
- 8041c46 Merge branch 'main' of github.com:NVIDIA/Fuser into dispatch_combine/… (samnordmann, Jan 29, 2026)
- f1ce74c Merge branch 'main' of github.com:NVIDIA/Fuser into dispatch_combine/… (samnordmann, Feb 4, 2026)
- a81a514 renaming (samnordmann, Feb 4, 2026)
- a0de605 add back topk_weights (samnordmann, Feb 5, 2026)
- 74d18d1 harden tests (samnordmann, Feb 9, 2026)
- 6b994ba assume continuous expert-to-rank mapping and simplify API and impleme… (samnordmann, Feb 9, 2026)
- 47d710f simplify by enforcing 2D shapes (samnordmann, Feb 9, 2026)
- f39daf2 lint (samnordmann, Feb 9, 2026)
- da52220 remove combined_topk_weights (samnordmann, Feb 9, 2026)
- c089049 minor simplification (samnordmann, Feb 9, 2026)
- 490200f remove (in|out|send)_src_rank (samnordmann, Feb 9, 2026)
- f148137 Merge branch 'main' of github.com:NVIDIA/Fuser into dispatch_combine/… (samnordmann, Feb 9, 2026)
- 6f56706 Merge branch 'dispatch_combine/stub' into dispatch_combine/stub_for_k… (samnordmann, Feb 10, 2026)
- ea5ad45 Merge branch 'main' of github.com:NVIDIA/Fuser into dispatch_combine/… (samnordmann, Feb 10, 2026)
- 3828247 lint (samnordmann, Feb 10, 2026)
- 72232c5 Merge branch 'main' of github.com:NVIDIA/Fuser into dispatch_combine/… (samnordmann, Feb 16, 2026)
- c5add1b minor comment (samnordmann, Feb 16, 2026)
- 374c8b3 lint (samnordmann, Feb 16, 2026)
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -1144,6 +1144,7 @@ if(BUILD_TEST)
${NVFUSER_ROOT}/tests/cpp/multidevice.cpp
${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_alltoallv.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_dispatch_combine.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp
@@ -1393,6 +1394,7 @@ list(APPEND NVFUSER_RUNTIME_FILES
${NVFUSER_ROOT}/runtime/mbarrier.cu
${NVFUSER_ROOT}/runtime/memory.cu
${NVFUSER_ROOT}/runtime/multicast.cu
${NVFUSER_SRCS_DIR}/multidevice/alltoallv.cu
${NVFUSER_ROOT}/runtime/random_numbers.cu
${NVFUSER_ROOT}/runtime/tensor_memory.cu
${NVFUSER_ROOT}/runtime/tensor.cu
36 changes: 36 additions & 0 deletions csrc/multidevice/alltoallv.cu
@@ -0,0 +1,36 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on

extern "C" __global__ void alltoallv_kernel(
const unsigned char* send,
const unsigned long long* recv_ptrs,
const long long* send_offsets,
const long long* send_sizes,
const long long* recv_offsets,
long long world_size,
long long elem_size,
long long max_send_bytes) {
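  // Grid layout: blockIdx.y selects the destination peer and blockIdx.x
  // spans the byte range of the largest per-peer message (max_send_bytes
  // sizes the grid on the host side; the kernel body never reads it).
  // Each thread copies one byte of this rank's send buffer into the
  // selected peer's receive buffer at the exchanged offset.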
const long long peer = static_cast<long long>(blockIdx.y);
if (peer >= world_size) {
return;
}
const long long bytes = send_sizes[peer] * elem_size;
if (bytes == 0) {
return;
}
const long long idx =
static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
if (idx >= bytes) {
return;
}
const long long send_byte_offset = send_offsets[peer] * elem_size + idx;
const long long recv_byte_offset = recv_offsets[peer] * elem_size + idx;
auto* dst = reinterpret_cast<unsigned char*>(
static_cast<unsigned long long>(recv_ptrs[peer]));
dst[recv_byte_offset] = send[send_byte_offset];
}
315 changes: 315 additions & 0 deletions csrc/multidevice/cuda_p2p.cpp
@@ -6,6 +6,7 @@
*/
// clang-format on
#include "multidevice/cuda_p2p.h"
#include "nvfuser_resources/alltoallv.h"
#include "nvfuser_resources/multicast.h"

#include "cuda_utils.h"
@@ -34,6 +35,143 @@ P2pProtocol getP2pProtocol() {
}

namespace {
void launchAlltoallvKernel(
const void* send,
const uint64_t* recv_ptrs,
const int64_t* send_offsets,
const int64_t* send_sizes,
const int64_t* recv_offsets,
int64_t world_size,
int64_t elem_size,
int64_t max_send_bytes,
CUstream stream) {
static CUmodule module = nullptr;
static CUfunction kernel = nullptr;
Comment on lines +48 to +49 (Contributor):

Static variables without thread safety: multiple threads calling launchAlltoallvKernel concurrently could race on the `module == nullptr` check during first initialization.

Suggested change:
  static std::once_flag init_flag;
  static CUmodule module = nullptr;
  static CUfunction kernel = nullptr;
  std::call_once(init_flag, [&]() {
    // ... the existing NVRTC compile/load code moves inside the lambda ...
  });

if (module == nullptr) {
nvrtcProgram prog;
NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram(
&prog,
nvfuser_resources::alltoallv_cu,
"alltoallv.cu",
@wujingyue (Collaborator), Feb 11, 2026:

Why nvrtc? Can't we simply alltoallv_kernel<<<...>>>?

@samnordmann (Collaborator, Author) replied:

Good question. I used nvrtc here to match nvFuser's existing runtime-kernel pattern, to the best of my understanding, not because alltoallv_kernel<<<...>>> is impossible.

In our codebase, IIUC, these helper CUDA kernels are treated as runtime resources: CMakeLists.txt adds csrc/multidevice/alltoallv.cu to NVFUSER_RUNTIME_FILES, which stringifies it into nvfuser_resources/alltoallv.h. At runtime, we compile and load that source with NVRTC + cuModuleLoadData, in the same style as the other runtime kernels in csrc/runtime/compiled_kernel.cpp, and similarly to the multicast helper in cuda_p2p.cpp.

If you'd prefer the static CUDA-launch route (alltoallv_kernel<<<...>>>), I can switch to that — could you clarify the exact direction you want?

0,
nullptr,
nullptr));

int major = 0;
int minor = 0;
int device = 0;
NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device));
cudaDeviceProp prop;
NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device));
major = prop.major;
minor = prop.minor;

std::string arch_arg = "--gpu-architecture=compute_" +
std::to_string(major) + std::to_string(minor);
std::vector<const char*> opts = {arch_arg.c_str(), "--std=c++17"};
// NVRTC needs CUDA headers to compile alltoallv.cu.
opts.push_back("-I/usr/local/cuda/include");
opts.push_back("-I/usr/local/cuda/include/cccl");
Comment on lines +74 to +75 (Contributor):

Hardcoded CUDA include paths may break on non-standard installations.

Suggested change:
  // Use the CUDA_HOME environment variable (or CMake-detected paths) instead
  // of a hardcoded /usr/local/cuda. Keep the option strings in named
  // variables so they outlive `opts`; .c_str() on a temporary would dangle.
  std::string cuda_home =
      std::getenv("CUDA_HOME") ? std::getenv("CUDA_HOME") : "/usr/local/cuda";
  std::string cuda_include = "-I" + cuda_home + "/include";
  std::string cccl_include = "-I" + cuda_home + "/include/cccl";
  opts.push_back(cuda_include.c_str());
  opts.push_back(cccl_include.c_str());


nvrtcResult res = nvrtcCompileProgram(prog, (int)opts.size(), opts.data());
if (res != NVRTC_SUCCESS) {
size_t logSize;
NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize));
std::vector<char> log(logSize);
NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data()));
NVF_ERROR(false, "Alltoallv kernel compilation failed:\n", log.data());
}

size_t ptxSize;
NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize));
std::vector<char> ptx(ptxSize);
NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data()));
NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));

CUresult load_result = cuModuleLoadData(&module, ptx.data());
if (load_result != CUDA_SUCCESS) {
constexpr size_t kLogSize = 8192;
char error_log[kLogSize];
char info_log[kLogSize];
CUjit_option options[] = {
CU_JIT_ERROR_LOG_BUFFER,
CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
CU_JIT_INFO_LOG_BUFFER,
CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
CU_JIT_LOG_VERBOSE};
void* option_values[] = {
(void*)error_log,
(void*)kLogSize,
(void*)info_log,
(void*)kLogSize,
(void*)1};
cuModuleLoadDataEx(&module, ptx.data(), 5, options, option_values);
NVF_ERROR(
false,
"Alltoallv kernel module load failed with error: ",
load_result,
"\nInfo Log:\n",
info_log,
"\nError Log:\n",
error_log);
}

NVFUSER_CUDA_SAFE_CALL(
cuModuleGetFunction(&kernel, module, "alltoallv_kernel"));
}

if (max_send_bytes == 0) {
return;
}

constexpr int kThreads = 256;
const int64_t blocks_x = (max_send_bytes + kThreads - 1) / kThreads;
void* args_kernel[] = {
const_cast<void*>(static_cast<const void*>(&send)),
const_cast<void*>(static_cast<const void*>(&recv_ptrs)),
const_cast<void*>(static_cast<const void*>(&send_offsets)),
const_cast<void*>(static_cast<const void*>(&send_sizes)),
const_cast<void*>(static_cast<const void*>(&recv_offsets)),
&world_size,
&elem_size,
&max_send_bytes};
NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel(
kernel,
blocks_x,
static_cast<unsigned int>(world_size),
1,
kThreads,
1,
1,
0,
stream,
args_kernel,
nullptr));
}
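
For reference on the nvrtc discussion above, here is a minimal sketch of what the static-launch alternative could look like, assuming alltoallv.cu were compiled into the binary by nvcc rather than stringified as an NVRTC runtime resource (illustrative only, modulo int64_t vs. long long pointer casts; not what this PR does):

  // Hypothetical static launch; the grid mirrors launchAlltoallvKernel:
  // x spans the largest per-peer byte count, y indexes destination peers.
  constexpr int kThreads = 256;
  const int64_t blocks_x = (max_send_bytes + kThreads - 1) / kThreads;
  dim3 grid(
      static_cast<unsigned int>(blocks_x),
      static_cast<unsigned int>(world_size),
      1);
  alltoallv_kernel<<<grid, kThreads, 0, stream>>>(
      static_cast<const unsigned char*>(send),
      recv_ptrs, // device array of peer receive-buffer addresses
      send_offsets,
      send_sizes,
      recv_offsets,
      world_size,
      elem_size,
      max_send_bytes);
  NVFUSER_CUDA_RT_SAFE_CALL(cudaGetLastError());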

std::vector<uint8_t> serializeInt64Vector(const std::vector<int64_t>& values) {
std::vector<uint8_t> bytes(values.size() * sizeof(int64_t));
std::memcpy(bytes.data(), values.data(), bytes.size());
return bytes;
}

std::vector<int64_t> deserializeInt64Vector(const std::vector<uint8_t>& bytes) {
NVF_CHECK(
bytes.size() % sizeof(int64_t) == 0, "Invalid int64 byte buffer size.");
const size_t count = bytes.size() / sizeof(int64_t);
std::vector<int64_t> values(count);
std::memcpy(values.data(), bytes.data(), bytes.size());
return values;
}

std::string alltoallvCountsKey(const std::string& tag, int64_t rank) {
return "nvfuser_alltoallv_counts_" + tag + "_" + std::to_string(rank);
}

std::string alltoallvBarrierKey(const std::string& tag, int64_t rank) {
return "nvfuser_alltoallv_barrier_" + tag + "_" + std::to_string(rank);
}
Comment on lines +172 to +174 (Contributor):

Unused function: alltoallvBarrierKey is defined but never called.


void launchMulticastKernel(
void* dst,
Expand Down Expand Up @@ -710,4 +848,181 @@ void waitWithCudaBackend(
}
}

AlltoallvMetadata prepareAlltoallvMetadata(
const at::Tensor& send_counts,
const std::string& tag) {
Communicator& comm = Communicator::getInstance();
const int64_t world_size = comm.size();
const int64_t my_rank = comm.deviceId();
NVF_CHECK(
send_counts.is_cuda(), "alltoallv send_counts must be a CUDA tensor.");
NVF_CHECK(
send_counts.dim() == 1 && send_counts.numel() == world_size,
"alltoallv send_counts must be 1D [R].");

auto store = comm.getTcpStore();
auto send_counts_cpu = send_counts.to(at::kCPU);
auto* send_ptr = send_counts_cpu.data_ptr<int64_t>();
std::vector<int64_t> send_counts_vec(send_ptr, send_ptr + world_size);

store->set(
alltoallvCountsKey(tag, my_rank), serializeInt64Vector(send_counts_vec));

std::vector<std::vector<int64_t>> counts_matrix(world_size);
for (int64_t rank = 0; rank < world_size; ++rank) {
auto bytes = store->get(alltoallvCountsKey(tag, rank));
counts_matrix[rank] = deserializeInt64Vector(bytes);
NVF_CHECK(
(int64_t)counts_matrix[rank].size() == world_size,
"Invalid alltoallv counts size.");
}
comm.barrier();
for (int64_t rank = 0; rank < world_size; ++rank) {
store->deleteKey(alltoallvCountsKey(tag, rank));
}

std::vector<int64_t> recv_counts_vec(world_size, 0);
for (int64_t sender = 0; sender < world_size; ++sender) {
recv_counts_vec[sender] = counts_matrix[sender][my_rank];
}

std::vector<int64_t> send_offsets_vec(world_size, 0);
int64_t prefix = 0;
for (int64_t rank = 0; rank < world_size; ++rank) {
send_offsets_vec[rank] = prefix;
prefix += send_counts_vec[rank];
}

std::vector<int64_t> recv_offsets_vec(world_size, 0);
for (int64_t peer = 0; peer < world_size; ++peer) {
int64_t offset = 0;
for (int64_t sender = 0; sender < my_rank; ++sender) {
offset += counts_matrix[sender][peer];
}
recv_offsets_vec[peer] = offset;
}

int64_t total_recv = 0;
for (auto value : recv_counts_vec) {
total_recv += value;
}

int64_t max_recv = 0;
int64_t max_send_total = 0;
for (int64_t rank = 0; rank < world_size; ++rank) {
int64_t total = 0;
for (int64_t sender = 0; sender < world_size; ++sender) {
total += counts_matrix[sender][rank];
}
if (total > max_recv) {
max_recv = total;
}
}

for (int64_t rank = 0; rank < world_size; ++rank) {
int64_t total = 0;
for (int64_t dest = 0; dest < world_size; ++dest) {
total += counts_matrix[rank][dest];
}
if (total > max_send_total) {
max_send_total = total;
}
}

int64_t max_send = 0;
for (auto value : send_counts_vec) {
if (value > max_send) {
max_send = value;
}
}

auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
auto send_offsets_cpu = at::empty({world_size}, cpu_options);
std::memcpy(
send_offsets_cpu.data_ptr<int64_t>(),
send_offsets_vec.data(),
world_size * sizeof(int64_t));
auto recv_offsets_cpu = at::empty({world_size}, cpu_options);
std::memcpy(
recv_offsets_cpu.data_ptr<int64_t>(),
recv_offsets_vec.data(),
world_size * sizeof(int64_t));
auto recv_counts_cpu = at::empty({world_size}, cpu_options);
std::memcpy(
recv_counts_cpu.data_ptr<int64_t>(),
recv_counts_vec.data(),
world_size * sizeof(int64_t));

AlltoallvMetadata metadata;
metadata.send_counts = send_counts;
metadata.recv_counts = recv_counts_cpu.to(send_counts.device());
metadata.send_offsets = send_offsets_cpu.to(send_counts.device());
metadata.recv_offsets = recv_offsets_cpu.to(send_counts.device());
metadata.total_recv = total_recv;
metadata.max_recv = max_recv;
metadata.max_send_total = max_send_total;
metadata.max_send_bytes = max_send;
metadata.world_size = world_size;
return metadata;
}
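
To make the offset bookkeeping above concrete, a small worked example with assumed values: world_size = 3, my_rank = 1, and counts_matrix[sender][dest] holding how many elements `sender` sends to `dest`:

  // counts_matrix = {{2, 1, 0},
  //                  {3, 4, 5},
  //                  {0, 2, 6}};
  // send_counts_vec  = row 1         = {3, 4, 5}
  // recv_counts_vec  = column 1      = {1, 4, 2}, so total_recv = 7
  // send_offsets_vec = prefix sums   = {0, 3, 7}
  // recv_offsets_vec[peer] = sum of counts_matrix[sender][peer] for
  //     sender < my_rank             = {2, 1, 0}
  //     (where this rank's data starts in each peer's receive buffer)
  // max_recv       = max column sum  = 11 (rank 2's total receive)
  // max_send_total = max row sum     = 12 (rank 1's total send)
  // metadata.max_send_bytes = max(send_counts_vec) = 5; despite the name,
  //     it is an element count here and is scaled to bytes at launch time.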

void alltoallvWithCudaBackend(
const at::Tensor& send,
const at::Tensor& recv,
const AlltoallvMetadata& metadata,
const std::vector<void*>& recv_ptrs,
CUstream stream) {
NVF_CHECK(send.is_cuda(), "alltoallv send must be a CUDA tensor.");
NVF_CHECK(recv.is_cuda(), "alltoallv recv must be a CUDA tensor.");
NVF_CHECK(
(int64_t)recv_ptrs.size() == metadata.world_size,
"recv_ptrs size must match world size.");

auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options);
auto* ptrs = recv_ptrs_cpu.data_ptr<int64_t>();
for (int64_t rank = 0; rank < metadata.world_size; ++rank) {
ptrs[rank] =
static_cast<int64_t>(reinterpret_cast<uintptr_t>(recv_ptrs[rank]));
}
auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device());

const int64_t elem_stride =
metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1;
NVF_CHECK(
metadata.max_send_total == 0 ||
send.numel() % metadata.max_send_total == 0,
"alltoallv send numel must be divisible by max_send_total.");
NVF_CHECK(
metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0,
"alltoallv recv numel must be divisible by max_recv.");

auto send_offsets = metadata.send_offsets;
auto send_counts = metadata.send_counts;
auto recv_offsets = metadata.recv_offsets;
int64_t max_send_bytes = metadata.max_send_bytes;
if (elem_stride > 1) {
send_offsets = metadata.send_offsets * elem_stride;
send_counts = metadata.send_counts * elem_stride;
recv_offsets = metadata.recv_offsets * elem_stride;
max_send_bytes = metadata.max_send_bytes * elem_stride;
}

launchAlltoallvKernel(
send.data_ptr(),
reinterpret_cast<const uint64_t*>(recv_ptrs_cuda.data_ptr<int64_t>()),
send_offsets.data_ptr<int64_t>(),
send_counts.data_ptr<int64_t>(),
recv_offsets.data_ptr<int64_t>(),
metadata.world_size,
send.element_size(),
max_send_bytes * send.element_size(),
stream);
}
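
A short numeric illustration of the elem_stride scaling above, with assumed shapes: if the exchanged counts are row counts of a 2D tensor, e.g. send is [max_send_total, hidden] = [12, 128] floats, then send.numel() = 1536 and elem_stride = 1536 / 12 = 128. The row-based offsets and counts from prepareAlltoallvMetadata are scaled by 128 into element units here, and by element_size (4 bytes for float) at the kernel launch, so the kernel itself operates purely in bytes.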

void alltoallvBarrier(const std::string& tag) {
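  // Note: `tag` is currently unused; this reduces to a plain global barrier.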
Communicator& comm = Communicator::getInstance();
comm.barrier();
}

} // namespace nvfuser