Add kernel based alltoallv and cuda backend for MoE dispatch and combine #5863

Open

samnordmann wants to merge 24 commits into main from dispatch_combine/stub_for_kernel

Conversation

@samnordmann (Collaborator)

@samnordmann samnordmann commented Jan 22, 2026

  • Add an alltoallv implementation using GPU-initiated comms (SM-driven NVLink), taking only GPU buffers, even for the alltoallv "metadata" such as splitSize. Available through the kCuda backend. Requires the recv buffer to be allocated as symmetric memory. (See the call-shape sketch below.)
  • Add a CUDA backend for dispatch and combine which avoids GPU->CPU sync (compared to the NCCL-backed version)
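As a rough illustration, the intended call shape on the CUDA path (names taken from the diff excerpts later in this thread; argument lists abbreviated, not the exact API):

// Sketch only: prepareAlltoallvMetadata, SymmetricTensor,
// alltoallvWithCudaBackend and alltoallvBarrier are the PR's names;
// the arguments here are simplified for illustration.
auto metadata = prepareAlltoallvMetadata(send_counts, /*tag=*/"example");
auto recv_sym = SymmetricTensor::allocate(
    {metadata.max_recv, hidden}, x.scalar_type(), x.device());
alltoallvWithCudaBackend(send_sym, recv_sym, metadata, recv_ptrs, stream);
alltoallvBarrier(/*tag=*/"example");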

@samnordmann (Collaborator, Author)

!test

@samnordmann changed the title from "add kernel based a2av and cuda backend for d/c" to "Add kernel based alltoallv and cuda backend for MoE dispatch and combine" on Jan 22, 2026
@greptile-apps (Contributor)

greptile-apps bot commented Jan 22, 2026

Greptile Summary

Adds GPU-initiated communication support for MoE dispatch/combine operations using a CUDA backend, avoiding GPU-to-CPU synchronization overhead compared to NCCL.

Major changes:

  • New alltoallv kernel implementation using SM-driven NVLink communication with GPU buffers
  • CUDA backend for dispatch/combine using symmetric memory allocation and IPC handles for direct peer-to-peer writes
  • Metadata exchange via TCPStore instead of NCCL alltoall for split sizes
  • Extended tests to cover both NCCL and CUDA backends with parameterized testing
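As a sketch of the TCPStore-based metadata exchange mentioned above (hypothetical helper, not the PR's code; c10d::Store::set/get are the real API):

#include <torch/csrc/distributed/c10d/Store.hpp>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Each rank publishes its per-peer send counts, then reads every
// peer's vector and picks out the entry destined for itself.
// Key naming and serialization are assumptions for illustration.
std::vector<int64_t> exchangeSendCounts(
    c10d::Store& store,
    const std::vector<int64_t>& send_counts, // one entry per peer
    int64_t my_rank,
    int64_t world_size) {
  auto key = [](int64_t r) { return "a2av_counts_" + std::to_string(r); };
  const auto* bytes = reinterpret_cast<const uint8_t*>(send_counts.data());
  store.set(
      key(my_rank),
      std::vector<uint8_t>(
          bytes, bytes + send_counts.size() * sizeof(int64_t)));
  std::vector<int64_t> recv_counts(world_size);
  for (int64_t r = 0; r < world_size; ++r) {
    std::vector<uint8_t> raw = store.get(key(r));
    // Rank r's count destined for my_rank sits at index my_rank.
    std::memcpy(
        &recv_counts[r],
        raw.data() + my_rank * sizeof(int64_t),
        sizeof(int64_t));
  }
  return recv_counts;
}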

Issues found:

  • Thread safety issue in static kernel module initialization (both launchAlltoallvKernel and launchMulticastKernel)
  • Hardcoded CUDA include paths (already flagged)
  • Unused alltoallvBarrierKey function (already flagged)

Confidence Score: 3/5

  • This PR has thread safety issues that could cause race conditions during concurrent kernel initialization
  • Score reflects critical thread safety bugs in static kernel module initialization that could cause runtime failures in multi-threaded environments, though the core alltoallv algorithm appears sound
  • Pay close attention to csrc/multidevice/cuda_p2p.cpp for thread safety in launchAlltoallvKernel and launchMulticastKernel functions

Important Files Changed

csrc/multidevice/alltoallv.cu: New CUDA kernel for the alltoallv communication pattern; implements byte-by-byte copy from the send buffer to remote recv buffers using peer pointers
csrc/multidevice/cuda_p2p.cpp: Implements CUDA backend for P2P, broadcast, allgather, and alltoallv operations using GPU-initiated communication; contains hardcoded CUDA paths and an unused function
csrc/multidevice/dispatch_combine.cpp: Refactored MoE dispatch/combine to support both NCCL and CUDA backends; the CUDA backend uses symmetric memory allocation and GPU-initiated alltoallv communication
tests/cpp/test_multidevice_dispatch_combine.cpp: Extended tests to cover both NCCL and CUDA backends for dispatch/combine operations with parameterized testing
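For intuition, a heavily simplified version of the copy pattern attributed to alltoallv.cu above (one grid row per peer, threads stride over that peer's bytes; the PR's actual kernel differs in details):

// Illustrative sketch, not the PR's kernel. Launch with
// grid = dim3(blocks_per_peer, world_size).
__global__ void alltoallv_sketch(
    const char* send,            // local send buffer
    char* const* recv_ptrs,      // peer-mapped recv buffers, one per rank
    const int64_t* send_offsets, // byte offset of each peer's slice locally
    const int64_t* send_counts,  // bytes destined for each peer
    const int64_t* dst_offsets)  // where this rank's slice lands on each peer
{
  const int64_t peer = blockIdx.y;
  const char* src = send + send_offsets[peer];
  char* dst = recv_ptrs[peer] + dst_offsets[peer];
  for (int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
       i < send_counts[peer];
       i += (int64_t)gridDim.x * blockDim.x) {
    dst[i] = src[i];
  }
}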

Sequence Diagram

sequenceDiagram
    participant R0 as Rank 0
    participant R1 as Rank 1
    participant Store as TCPStore
    participant GPU0 as GPU 0
    participant GPU1 as GPU 1

    Note over R0,R1: Dispatch Phase
    R0->>R0: Sort tokens by expert_id
    R1->>R1: Sort tokens by expert_id
    R0->>R0: Compute send_counts per rank
    R1->>R1: Compute send_counts per rank
    
    Note over R0,Store: CUDA Backend: Exchange metadata via TCPStore
    R0->>Store: Put send_counts
    R1->>Store: Put send_counts
    R0->>Store: Get all ranks' send_counts
    R1->>Store: Get all ranks' send_counts
    R0->>R0: Compute recv_counts/offsets
    R1->>R1: Compute recv_counts/offsets
    
    Note over GPU0,GPU1: Allocate symmetric memory for send/recv buffers
    R0->>GPU0: Allocate send_x_sym, recv_x_sym
    R1->>GPU1: Allocate send_x_sym, recv_x_sym
    R0->>R1: Exchange IPC handles
    R1->>R0: Exchange IPC handles
    
    Note over GPU0,GPU1: GPU-initiated alltoallv kernel
    GPU0->>GPU1: Write to remote recv_x_sym
    GPU1->>GPU0: Write to remote recv_x_sym
    
    Note over R0,R1: Barrier synchronization
    R0->>R1: Barrier
    R1->>R0: Barrier
    
    Note over R0,R1: Combine Phase
    R0->>R0: Sort by src_rank
    R1->>R1: Sort by src_rank
    GPU0->>GPU1: Alltoallv send_x back
    GPU1->>GPU0: Alltoallv send_x back
    R0->>R0: Scatter by src_idx
    R1->>R1: Scatter by src_idx

Last reviewed commit: 374c8b3

@greptile-apps (Contributor) left a comment

8 files reviewed, 1 comment

Comment on lines 36 to 38
if (!communicator_->isBackendAvailable(CommunicatorBackend::kNccl)) {
GTEST_SKIP() << "Backend " << backend << " not available.";
}
logic: checking wrong backend constant - should check backend parameter, not hardcoded kNccl

Suggested change
if (!communicator_->isBackendAvailable(CommunicatorBackend::kNccl)) {
GTEST_SKIP() << "Backend " << backend << " not available.";
}
if (!communicator_->isBackendAvailable(backend)) {
GTEST_SKIP() << "Backend " << backend << " not available.";
}

Base automatically changed from dispatch_combine/stub to main February 9, 2026 14:43
@github-actions

github-actions bot commented Feb 10, 2026

Review updated until commit 374c8b3

Description

  • Add kernel-based alltoallv implementation using GPU-initiated communications (SM-driven NVLink)
  • Implement CUDA backend for MoE dispatch and combine operations to avoid GPU->CPU synchronization
  • Support both NCCL and CUDA backends in dispatch/combine with symmetric memory allocation
  • Add comprehensive tests for alltoallv functionality and parameterized backend testing

Changes walkthrough

Relevant files

Enhancement

cuda_p2p.cpp: Implement kernel-based alltoallv with CUDA backend
csrc/multidevice/cuda_p2p.cpp (+315/-0)
  • Add launchAlltoallvKernel function with NVRTC compilation and kernel launching
  • Implement prepareAlltoallvMetadata for computing send/recv offsets and counts (see the offsets sketch below)
  • Add alltoallvWithCudaBackend using kernel-based GPU-initiated communications
  • Include serialization/deserialization helpers and barrier functions

dispatch_combine.cpp: Add CUDA backend support for MoE dispatch/combine
csrc/multidevice/dispatch_combine.cpp (+201/-58)
  • Extend doMoeDispatch to support both NCCL and CUDA backends
  • Add CUDA backend path using symmetric tensors and alltoallv kernel
  • Modify doMoeCombine to support CUDA backend with symmetric memory
  • Maintain backward compatibility with existing NCCL implementation

alltoallv.cu: Implement alltoallv CUDA kernel
csrc/multidevice/alltoallv.cu (+36/-0)
  • Implement alltoallv_kernel CUDA kernel for peer-to-peer data transfer
  • Handle byte-level copying with proper offset calculations
  • Support multi-dimensional kernel launches for world_size peers

cuda_p2p.h: Add alltoallv function declarations and metadata struct
csrc/multidevice/cuda_p2p.h (+26/-0)
  • Add AlltoallvMetadata struct for storing send/recv metadata
  • Declare prepareAlltoallvMetadata and alltoallvWithCudaBackend functions
  • Add alltoallvBarrier function declaration

Tests

test_multidevice_alltoallv.cpp: Add alltoallv CUDA backend tests
tests/cpp/test_multidevice_alltoallv.cpp (+80/-0)
  • Add new test file for alltoallv CUDA backend functionality
  • Test asymmetric alltoallv patterns with varying send counts
  • Verify data integrity across ranks using symmetric tensor setup

test_multidevice_dispatch_combine.cpp: Parameterize dispatch/combine tests for multiple backends
tests/cpp/test_multidevice_dispatch_combine.cpp (+32/-20)
  • Convert tests to parameterized tests supporting both NCCL and CUDA backends
  • Add backend availability checks for each test case
  • Test dispatch-only, combine-only, and full dispatch-combine workflows

Documentation

dispatch_combine.h: Update dispatch/combine documentation for CUDA backend
csrc/multidevice/dispatch_combine.h (+4/-4)
  • Update documentation to reflect CUDA backend support
  • Clarify backend parameter supports both CUDA and NCCL

Configuration changes

CMakeLists.txt: Add alltoallv test and kernel to build configuration
CMakeLists.txt (+2/-0)
  • Add test_multidevice_alltoallv.cpp to test build configuration
  • Include alltoallv.cu in runtime files for compilation
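The offsets computed in prepareAlltoallvMetadata reduce to an exclusive prefix sum over the exchanged counts; a minimal sketch of that step (hypothetical helper name, C++17):

#include <cstdint>
#include <numeric>
#include <vector>

// offsets[i] = counts[0] + ... + counts[i-1]; offsets[0] = 0.
std::vector<int64_t> countsToOffsets(const std::vector<int64_t>& counts) {
  std::vector<int64_t> offsets(counts.size(), 0);
  std::exclusive_scan(
      counts.begin(), counts.end(), offsets.begin(), int64_t{0});
  return offsets;
}
// Total received = offsets.back() + counts.back() (when non-empty).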

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Runtime Compilation Risk

The alltoallv kernel is compiled at runtime using NVRTC on first use (lines 52-122). This introduces potential startup latency and a risk of compilation failure at runtime. Consider adding fallback mechanisms or pre-compilation options. The error handling is comprehensive, but runtime compilation should be documented as a known limitation.

      nvrtcProgram prog;
      NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram(
          &prog,
          nvfuser_resources::alltoallv_cu,
          "alltoallv.cu",
          0,
          nullptr,
          nullptr));
    
      int major = 0;
      int minor = 0;
      int device = 0;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device));
      cudaDeviceProp prop;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device));
      major = prop.major;
      minor = prop.minor;
    
      std::string arch_arg = "--gpu-architecture=compute_" +
          std::to_string(major) + std::to_string(minor);
      std::vector<const char*> opts = {arch_arg.c_str(), "--std=c++17"};
      // NVRTC needs CUDA headers to compile alltoallv.cu.
      opts.push_back("-I/usr/local/cuda/include");
      opts.push_back("-I/usr/local/cuda/include/cccl");
    
      nvrtcResult res = nvrtcCompileProgram(prog, (int)opts.size(), opts.data());
      if (res != NVRTC_SUCCESS) {
        size_t logSize;
        NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize));
        std::vector<char> log(logSize);
        NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data()));
        NVF_ERROR(false, "Alltoallv kernel compilation failed:\n", log.data());
      }
    
      size_t ptxSize;
      NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize));
      std::vector<char> ptx(ptxSize);
      NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data()));
      NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog));
    
      CUresult load_result = cuModuleLoadData(&module, ptx.data());
      if (load_result != CUDA_SUCCESS) {
        constexpr size_t kLogSize = 8192;
        char error_log[kLogSize];
        char info_log[kLogSize];
        CUjit_option options[] = {
            CU_JIT_ERROR_LOG_BUFFER,
            CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
            CU_JIT_INFO_LOG_BUFFER,
            CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
            CU_JIT_LOG_VERBOSE};
        void* option_values[] = {
            (void*)error_log,
            (void*)kLogSize,
            (void*)info_log,
            (void*)kLogSize,
            (void*)1};
        cuModuleLoadDataEx(&module, ptx.data(), 5, options, option_values);
        NVF_ERROR(
            false,
            "Alltoallv kernel module load failed with error: ",
            load_result,
            "\nInfo Log:\n",
            info_log,
            "\nError Log:\n",
            error_log);
      }
    
      NVFUSER_CUDA_SAFE_CALL(
          cuModuleGetFunction(&kernel, module, "alltoallv_kernel"));
    }
Symmetric Memory Requirement

The PR requires recv buffers to be allocated as symmetric memory, which is a significant hardware/software limitation. This requirement should be documented more prominently in the API documentation and user-facing interfaces, as it may not be available on all systems. (A sketch of the underlying CUDA IPC flow follows the excerpt below.)

    std::string alltoallvCountsKey(const std::string& tag, int64_t rank) {
      return "nvfuser_alltoallv_counts_" + tag + "_" + std::to_string(rank);
    }
    
    std::string alltoallvBarrierKey(const std::string& tag, int64_t rank) {
      return "nvfuser_alltoallv_barrier_" + tag + "_" + std::to_string(rank);
    }
    
    void launchMulticastKernel(
        void* dst,
        const void* src,
        size_t size,
        CUstream stream) {
      static CUmodule module = nullptr;
      static CUfunction kernel = nullptr;
    
      if (module == nullptr) {
        nvrtcProgram prog;
        NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram(
            &prog,
            nvfuser_resources::multicast_cu,
            "multicast.cu",
            0,
            nullptr,
            nullptr));
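For background, the plain CUDA IPC flow that this kind of symmetric-memory setup builds on, roughly (standard CUDA runtime API; the PR's SymmetricTensor wraps equivalent steps, details may differ):

#include <cuda_runtime.h>

// Exporting rank: create a shareable handle for a device allocation.
void exportBuffer(size_t nbytes, void** buf, cudaIpcMemHandle_t* handle) {
  cudaMalloc(buf, nbytes);
  cudaIpcGetMemHandle(handle, *buf);
  // Ship *handle to peers out of band, e.g. through the TCPStore.
}

// Importing rank: map the peer's allocation into this process.
void* importBuffer(cudaIpcMemHandle_t handle) {
  void* peer_buf = nullptr;
  cudaIpcOpenMemHandle(&peer_buf, handle, cudaIpcMemLazyEnablePeerAccess);
  return peer_buf; // device pointer usable in local kernels
}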
Performance Validation Missing

The PR claims performance improvements from avoiding GPU->CPU sync compared to NCCL, but no quantitative performance data or benchmarks are provided. Consider adding performance measurements, or at least documenting the expected performance characteristics. (A minimal timing sketch follows the excerpt below.)

    auto metadata =
        prepareAlltoallvMetadata(n_tokens_to_rank, "moe_dispatch_counts");
    auto n_tokens_from_rank = metadata.recv_counts;
    const int64_t total_recv = metadata.total_recv;
    const int64_t max_recv = metadata.max_recv;
    
    // Allocate symmetric buffers for send/recv payloads.
    auto send_x_sym = SymmetricTensor::allocate(
        {metadata.max_send_total, hidden}, x.scalar_type(), x.device());
    send_x_sym.narrow(0, 0, num_tokens).copy_(send_x);
    auto send_topk_idx_sym = SymmetricTensor::allocate(
        {metadata.max_send_total, topk_idx.size(1)},
        topk_idx.scalar_type(),
        x.device());
    send_topk_idx_sym.narrow(0, 0, num_tokens).copy_(send_topk_idx);
    auto send_topk_weights_sym = SymmetricTensor::allocate(
        {metadata.max_send_total, topk_weights.size(1)},
        topk_weights.scalar_type(),
        x.device());
    send_topk_weights_sym.narrow(0, 0, num_tokens).copy_(send_topk_weights);
    auto send_src_idx_sym = SymmetricTensor::allocate(
        {metadata.max_send_total}, send_src_idx.scalar_type(), x.device());
    send_src_idx_sym.narrow(0, 0, num_tokens).copy_(send_src_idx);
    
    auto recv_x_sym = SymmetricTensor::allocate(
        {max_recv, hidden}, x.scalar_type(), x.device());
    auto recv_topk_idx_sym = SymmetricTensor::allocate(
        {max_recv, topk_idx.size(1)}, topk_idx.scalar_type(), x.device());
    auto recv_topk_weights_sym = SymmetricTensor::allocate(
        {max_recv, topk_weights.size(1)}, topk_weights.scalar_type(), x.device());
    auto recv_src_idx_sym = SymmetricTensor::allocate(
        {max_recv}, send_src_idx.scalar_type(), x.device());
    
    SymmetricTensor recv_x_handle(recv_x_sym);
    SymmetricTensor recv_topk_idx_handle(recv_topk_idx_sym);
    SymmetricTensor recv_topk_weights_handle(recv_topk_weights_sym);
    SymmetricTensor recv_src_idx_handle(recv_src_idx_sym);
    recv_x_handle.setupRemoteHandles("moe_dispatch_recv_x");
    recv_topk_idx_handle.setupRemoteHandles("moe_dispatch_recv_topk_idx");
    recv_topk_weights_handle.setupRemoteHandles("moe_dispatch_recv_topk_weights");
    recv_src_idx_handle.setupRemoteHandles("moe_dispatch_recv_src_idx");
    
    std::vector<void*> recv_x_ptrs(world_size);
    std::vector<void*> recv_topk_idx_ptrs(world_size);
    std::vector<void*> recv_topk_weights_ptrs(world_size);
    std::vector<void*> recv_src_idx_ptrs(world_size);
    for (int64_t rank = 0; rank < world_size; ++rank) {
      recv_x_ptrs[rank] = recv_x_handle.remoteTensor(rank).data_ptr();
      recv_topk_idx_ptrs[rank] =
          recv_topk_idx_handle.remoteTensor(rank).data_ptr();
      recv_topk_weights_ptrs[rank] =
          recv_topk_weights_handle.remoteTensor(rank).data_ptr();
      recv_src_idx_ptrs[rank] = recv_src_idx_handle.remoteTensor(rank).data_ptr();
    }
    
    auto stream =
        static_cast<CUstream>(at::cuda::getDefaultCUDAStream().stream());
    alltoallvWithCudaBackend(
        send_x_sym, recv_x_sym, metadata, recv_x_ptrs, stream);
    alltoallvWithCudaBackend(
        send_topk_idx_sym,
        recv_topk_idx_sym,
        metadata,
        recv_topk_idx_ptrs,
        stream);
    alltoallvWithCudaBackend(
        send_topk_weights_sym,
        recv_topk_weights_sym,
        metadata,
        recv_topk_weights_ptrs,
        stream);
    alltoallvWithCudaBackend(
        send_src_idx_sym, recv_src_idx_sym, metadata, recv_src_idx_ptrs, stream);
    alltoallvBarrier("moe_dispatch_counts");
    
    auto recv_x = recv_x_sym.narrow(0, 0, total_recv);
    auto recv_topk_idx = recv_topk_idx_sym.narrow(0, 0, total_recv);
    auto recv_topk_weights = recv_topk_weights_sym.narrow(0, 0, total_recv);
    auto recv_src_idx = recv_src_idx_sym.narrow(0, 0, total_recv);
    
    return DispatchResult{
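On the benchmarking point raised above, a minimal cudaEvent timing harness one could use (illustrative only; not part of the PR):

#include <cuda_runtime.h>

// Returns elapsed GPU time in milliseconds for work enqueued between
// the two events on `stream`.
float timeCollectiveMs(cudaStream_t stream) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start, stream);
  // ... enqueue alltoallvWithCudaBackend(...) on `stream` here ...
  cudaEventRecord(stop, stream);
  cudaEventSynchronize(stop);
  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
}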

@samnordmann (Collaborator, Author)

!test

@greptile-apps (Contributor) left a comment

3 files reviewed, 2 comments

    Comment on lines +74 to +75
    opts.push_back("-I/usr/local/cuda/include");
    opts.push_back("-I/usr/local/cuda/include/cccl");
    hardcoded CUDA include paths may break on non-standard installations

    Suggested change
    opts.push_back("-I/usr/local/cuda/include");
    opts.push_back("-I/usr/local/cuda/include/cccl");
    // Use the CUDA_HOME environment variable or CMake-detected paths
    // (requires <cstdlib>). Keep the strings alive in locals: pushing a
    // temporary's c_str() into opts would leave dangling pointers.
    const char* cuda_home_env = std::getenv("CUDA_HOME");
    std::string cuda_home = cuda_home_env ? cuda_home_env : "/usr/local/cuda";
    std::string cuda_inc = "-I" + cuda_home + "/include";
    std::string cccl_inc = "-I" + cuda_home + "/include/cccl";
    opts.push_back(cuda_inc.c_str());
    opts.push_back(cccl_inc.c_str());

    Comment on lines +172 to +174
    std::string alltoallvBarrierKey(const std::string& tag, int64_t rank) {
    return "nvfuser_alltoallv_barrier_" + tag + "_" + std::to_string(rank);
    }
    unused function - alltoallvBarrierKey is defined but never called

    @samnordmann samnordmann requested a review from nsarka February 10, 2026 19:21
@wujingyue (Collaborator) left a comment

    LGTM otherwise

    NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram(
    &prog,
    nvfuser_resources::alltoallv_cu,
    "alltoallv.cu",
@wujingyue commented Feb 11, 2026

    Why nvrtc? Can't we simply alltoallv_kernel<<<...>>>?

@samnordmann (Collaborator, Author)

    Good question. I used nvrtc here to match nvFuser’s existing runtime-kernel pattern, to the best of my understanding, not because alltoallv_kernel<<<...>>> is impossible.

    In our codebase, IIUC, these helper CUDA kernels are treated as runtime resources:
    CMakeLists.txt adds csrc/multidevice/alltoallv.cu to NVFUSER_RUNTIME_FILES, then stringifies it into nvfuser_resources/alltoallv.h.

    At runtime, we compile/load that source with NVRTC + cuModuleLoadData, same style as other runtime kernels in csrc/runtime/compiled_kernel.cpp, and similarly to the multicast helper in cuda_p2p.cpp.

    If you’d prefer the static CUDA-launch route (alltoallv_kernel<<<...>>>), I can switch to that — could you clarify the exact direction you want?
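For comparison, the static-launch route being discussed would look roughly like this (only possible if alltoallv.cu is compiled by nvcc into the binary; the argument list below is illustrative, not the kernel's actual signature):

// Hypothetical static launch; requires alltoallv.cu to be built by nvcc
// rather than shipped as an NVRTC string resource.
dim3 grid(/*blocks_per_peer=*/8, /*peers=*/world_size);
dim3 block(256);
alltoallv_kernel<<<grid, block, 0, stream>>>(
    send_ptr, recv_ptrs, send_offsets, send_counts, dst_offsets);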

@greptile-apps (Contributor) left a comment

8 files reviewed, 2 comments

    Comment on lines +48 to +49
    static CUmodule module = nullptr;
    static CUfunction kernel = nullptr;
    static variables without thread safety - multiple threads calling launchAlltoallvKernel concurrently could race on the module == nullptr check during first initialization

    Suggested change
    static CUmodule module = nullptr;
    static CUfunction kernel = nullptr;
    static std::once_flag init_flag;
    static CUmodule module = nullptr;
    static CUfunction kernel = nullptr;
    std::call_once(init_flag, [&]() {
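Spelled out, the thread-safe lazy-initialization pattern the suggestion points at (a sketch; the lambda body elides the NVRTC compile/load shown earlier):

#include <cuda.h>
#include <mutex>

void launchAlltoallvKernelSketch(/* launch args elided */) {
  static CUmodule module = nullptr;
  static CUfunction kernel = nullptr;
  static std::once_flag init_flag;
  // std::call_once guarantees exactly one thread runs the initializer;
  // concurrent callers block until it completes.
  std::call_once(init_flag, [&]() {
    // ... NVRTC compile, cuModuleLoadData(&module, ptx.data()),
    //     cuModuleGetFunction(&kernel, module, "alltoallv_kernel") ...
  });
  // module/kernel are now safe to use from any thread.
}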

@greptile-apps (Contributor)

    greptile-apps bot commented Feb 16, 2026

    Additional Comments (1)

    csrc/multidevice/cuda_p2p.cpp
    same thread safety issue as in launchAlltoallvKernel - static variables need synchronization for concurrent first-time initialization

@samnordmann (Collaborator, Author)

    !test
