Merging code from IFU branch. #8

Open
RichardChamberlain1 wants to merge 24 commits into ROCm:main from RichardChamberlain1:main

Conversation

@RichardChamberlain1

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist


# To build rocSHMEM with MPI disabled, please add this flag -DUSE_EXTERNAL_MPI=OFF
MPI_ROOT=$BUILD_DIR/ompi ../rocSHMEM/scripts/build_configs/gda_mlx5 --fresh \
-DUSE_IPC=ON \
Contributor

why are you reverting the README?

parser.add_argument("--verbose", action="store_true", help="Verbose build")
parser.add_argument("--enable_timer", action="store_true", help="Enable timer to debug time out in internode")
parser.add_argument("--rocm-disable-ctx", action="store_true", help="Disable workgroup context optimization in internode")
parser.add_argument("--disable-mpi", action="store_true", help="Disable MPI detection and configuration")
Contributor

disable-mpi should be kept.

for (int j = 0; j < kNumElemsPerRead; j += 2) {
float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
#ifdef USE_ROCM
#if defined(__gfx942__)
Contributor

These changes need to be reverted. It breaks for MI350.
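
(For illustration only: a minimal sketch of the guard pattern the reviewer is asking for, assuming the intent is to keep a portable fallback alongside the gfx942-only path so that other ROCm targets such as MI350-class GPUs still build. convert_fp8_fast() and convert_fp8_portable() are hypothetical helpers, not functions from this repository.)

// Hedged sketch: keep the gfx942 fast path behind its arch guard, but always
// leave a portable branch so every other target architecture still compiles.
__device__ void convert_fp8_fast(float2 v, unsigned char* out);      // hypothetical gfx942 hardware path
__device__ void convert_fp8_portable(float2 v, unsigned char* out);  // hypothetical generic fallback

__device__ void scale_and_pack(const float* fp32_values, float scale,
                               unsigned char* out, int kNumElemsPerRead) {
    for (int j = 0; j < kNumElemsPerRead; j += 2) {
        float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
#if defined(USE_ROCM) && defined(__gfx942__)
        convert_fp8_fast(fp32x2, out + j);      // arch-specific instructions only here
#else
        convert_fp8_portable(fp32x2, out + j);  // compiled for MI350 and all other targets
#endif
    }
}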

internode::shmem_ctx_schar_put_nbi_warp(ctx,
#endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank);
#if defined(ROCM_DISABLE_CTX)
Contributor

These changes also need to be reverted.

Comment on lines 750 to 761
// Assign bias pointers
/*auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
void* bias_ptrs[2] = {nullptr, nullptr};
for (int i = 0; i < 2; ++i)
if (bias_opts[i].has_value()) {
auto bias = bias_opts[i].value();
EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden);
bias_ptrs[i] = bias.data_ptr();
}
*/
Contributor

Let's remove it or comment that it might be needed for future work


I've added a comment to say it's not supported at this time.

Comment on lines +800 to +804
/*for (auto& to : {topk_weights, recv_topk_weights, bias_0, bias_1}) {
to.has_value() ? to->record_stream(comm_stream) : void();
if (allocate_on_comm_stream)
to.has_value() ? to->record_stream(compute_stream) : void();
}*/
Contributor

Let's remove it or comment that it might be needed for future work


Added a comment.

csrc/deep_ep.hpp Outdated
Comment on lines 166 to 167
//const std::optional<torch::Tensor>& bias_0,
//const std::optional<torch::Tensor>& bias_1,
Contributor

Let's remove it

namespace intranode {

void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
//void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
Contributor

Let's remove it


Done

if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
trap(); \
abort();\
Contributor

Why was that changed? As far as I remember, the abort() function was unavailable on the device side.


Trap was unrecognized during compilation.
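
(For reference, a minimal sketch of a device-side assert macro that selects an abort primitive per platform, assuming the USE_ROCM define used elsewhere in this PR; the EP_DEVICE_TRAP name is illustrative, not the repository's. HIP device code accepts abort(), while CUDA device code provides __trap().)

#include <cstdio>

// Hedged sketch, not the repository's actual macro: pick a per-platform abort
// primitive so the assertion compiles for both HIP and CUDA device code.
#ifdef USE_ROCM
#define EP_DEVICE_TRAP() abort()      // available in HIP device code
#else
#define EP_DEVICE_TRAP() __trap()     // CUDA device-side trap
#endif

#define EP_DEVICE_ASSERT(cond) \
do { \
    if (not (cond)) { \
        printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
        EP_DEVICE_TRAP(); \
    } \
} while (0)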

#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
EP_DEVICE_ASSERT(internode::shmem_wg_ctx_create(&ctx) == 0);
Contributor

Maybe there's something like INVALID_CTX to compare against, but not zero?

Comment on lines 40 to 42
//#pragma unroll
//for (int i = 0; i < kNumRanks; ++ i)
// per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i];
Contributor

Let's clean-up

}

void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) {
/*void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) {
Contributor

Let's remove old version

#include "exception.cuh"

#ifdef USE_ROCM
#define syncthreads() __syncthreads()
Contributor

Why can't we just use __syncthreads() everywhere? There's no custom functionality added behind this function, and the __ prefix would explicitly mark that we're using the runtime one.


I wondered about this too, but was just following how it's always been done and assumed there was some good reason for it. Probably just leftover from some debugging at some point?

Contributor

It seems like there's no point in wrapping that particular function. A wrapper is (was) necessary for some other calls, __shfl_sync for example, because those take a different number of arguments compared to the CUDA runtime, so a decorator is required. Let's revert to __syncthreads().
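
(To illustrate the point with a minimal sketch: a wrapper earns its keep only where the HIP and CUDA signatures actually differ, as with warp shuffles, where HIP's __shfl() takes no mask argument while CUDA's __shfl_sync() does. The warp_shfl_sync name below is illustrative, not the repository's.)

// Hedged sketch: a portability macro is justified for warp shuffles because the
// argument lists differ between runtimes, whereas __syncthreads() has the same
// signature on both and can simply be called directly.
#ifdef USE_ROCM
#define warp_shfl_sync(mask, var, src_lane) __shfl((var), (src_lane))
#else
#define warp_shfl_sync(mask, var, src_lane) __shfl_sync((mask), (var), (src_lane))
#endif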

Copilot AI left a comment

Pull request overview

This PR merges code from an IFU (Internal Feature Update) branch into the main branch. The changes are extensive and include:

Changes:

  • Major refactoring of low-latency communication APIs with support for failure simulation and shrink mode
  • Addition of configurable topk_idx data type (32-bit or 64-bit integers via TOPK_IDX_BITS compile flag)
  • Enhanced test infrastructure with argument parsers, pressure testing, and failure simulation capabilities
  • Improved error handling, diagnostic features, and memory management with explicit destroy() method
  • ROCm compatibility improvements and NVSHMEM optional compilation support

Reviewed changes

Copilot reviewed 23 out of 23 changed files in this pull request and generated 9 comments.

Summary per file:

  • tests/utils.py: Updated utility functions with improved FP8 casting, alignment helpers, and benchmark timing
  • tests/test_low_latency.py: Enhanced with failure simulation, shrink testing, LogFMT support, and CLI arguments
  • tests/test_intranode.py: Added argument parsing, fabric support, and improved test coverage
  • tests/test_internode.py: Enhanced with pressure testing, CLI arguments, and bias support
  • setup.py: Added NVSHMEM detection, TOPK_IDX_BITS configuration, and improved OpenMPI detection
  • deep_ep/__init__.py: Exported topk_idx_t type
  • deep_ep/utils.py: Added NVLink connection checking utility
  • deep_ep/buffer.py: Major API expansion with shrink mode, mask buffers, diagnostic stats, and explicit destruction
  • csrc/deep_ep.hpp: Updated Buffer class with new features and reorganized method signatures
  • csrc/deep_ep.cpp: Implemented new features including mask buffers, diagnostic stats, and GIL release
  • csrc/config.hpp: Fixed typo in ceil_div function name
  • csrc/kernels/*.cu/.cuh: Extensive kernel updates for new features and API changes


out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
This is useful for detecting and pre-cisely localizing slow anomalies.
Copilot AI Jan 31, 2026

There's a typo in the comment: "pre-cisely" should be "precisely" (the word is incorrectly split across a line break in the original).

Suggested change
This is useful for detecting and pre-cisely localizing slow anomalies.
This is useful for detecting and precisely localizing slow anomalies.

import nvidia.nvshmem as nvshmem # noqa: F401
except (ModuleNotFoundError, AttributeError, IndexError):
print(
'Warning: `NVSHMEM_DIR` is not specified, and the NVSHMEM module is not installed. All internode and low-latency features are disabled\n'
Copilot AI Jan 31, 2026

The warning text reads "internode and low-latency"; the suggested change inserts a serial comma: "internode, and low-latency".

Suggested change
'Warning: `NVSHMEM_DIR` is not specified, and the NVSHMEM module is not installed. All internode and low-latency features are disabled\n'
'Warning: `NVSHMEM_DIR` is not specified, and the NVSHMEM module is not installed. All internode, and low-latency features are disabled\n'

#endif
} else {
st_na_release(reinterpret_cast<int64_t *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
Copilot AI Jan 31, 2026

The call st_na_release(reinterpret_cast<int64_t *>(...)) has been changed to st_na_release(reinterpret_cast<int *>(...)), and the int64_t* signature appears to have been removed from utils.cuh. The change looks intentional, but it could cause a type mismatch if the underlying pointer type is actually int64_t. Verify that the pointer types are consistent.

#endif
}
#else //USE_ROCM
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank);
Copilot AI Jan 31, 2026

Line 239 appears to duplicate parameters in the function call. The line shows reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank); but this looks like it's incomplete or malformed in the diff. This could be a diff artifact or an actual code issue that needs investigation.

Suggested change
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank);
}

Comment on lines +10 to 12
//void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
void barrier(int **task_fifo_ptrs, int rank, int num_ranks, cudaStream_t stream, int head = 0);

Copilot AI Jan 31, 2026

The function signature change from void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) to void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream, int head = 0) reorders parameters and adds a default value. This is a breaking API change that could cause issues if there are external callers. Verify all call sites have been updated.

Suggested change
//void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
void barrier(int **task_fifo_ptrs, int rank, int num_ranks, cudaStream_t stream, int head = 0);
void barrier(int **task_fifo_ptrs, int rank, int num_ranks, cudaStream_t stream, int head = 0);
inline void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) {
barrier(task_fifo_ptrs, rank, num_ranks, stream, head);
}

Comment on lines +236 to +237
suppress_kineto_output=True,
num_kernels_per_period=2 if return_recv_hook else 1)
Copilot AI Jan 31, 2026

Keyword argument 'num_kernels_per_period' is not a supported parameter name of function bench_kineto.

Suggested change
suppress_kineto_output=True,
num_kernels_per_period=2 if return_recv_hook else 1)
suppress_kineto_output=True)

# Test combine
bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode}
Copilot AI Jan 31, 2026

This assignment to 'combine_args' is unnecessary as it is redefined before this value is used.

Suggested change
combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode}

try:
nvshmem_dir = importlib.util.find_spec("nvidia.nvshmem").submodule_search_locations[0]
nvshmem_host_lib = get_nvshmem_host_lib_name(nvshmem_dir)
import nvidia.nvshmem as nvshmem # noqa: F401
Copilot AI Jan 31, 2026

Import of 'nvshmem' is not used.

Suggested change
import nvidia.nvshmem as nvshmem # noqa: F401

import torch
import torch.distributed as dist
from typing import Optional
from typing import Optional, Union
Copilot AI Jan 31, 2026

Import of 'Union' is not used.

Suggested change
from typing import Optional, Union
from typing import Optional
