From 4d04536aafb287ef5a748d7ce41e879251f95403 Mon Sep 17 00:00:00 2001 From: lzhang2 Date: Mon, 15 Sep 2025 15:35:30 +0800 Subject: [PATCH 01/25] support symm memory on XPU devices --- src/xccl/IpcExchange.hpp | 400 +++++++++++++++ src/xccl/XPUSymmetricMemory.cpp | 460 ++++++++++++++++++ src/xccl/XPUSymmetricMemory.hpp | 130 +++++ src/xccl/XPUSymmetricMemoryTypes.hpp | 8 + src/xccl/XPUSymmetricMemoryUtils.cpp | 76 +++ src/xccl/XPUSymmetricMemoryUtils.hpp | 89 ++++ src/xccl/ze_symbol.hpp | 254 ++++++++++ .../distributed/test_symmetric_memory_xccl.py | 85 ++++ 8 files changed, 1502 insertions(+) create mode 100644 src/xccl/IpcExchange.hpp create mode 100644 src/xccl/XPUSymmetricMemory.cpp create mode 100644 src/xccl/XPUSymmetricMemory.hpp create mode 100644 src/xccl/XPUSymmetricMemoryTypes.hpp create mode 100644 src/xccl/XPUSymmetricMemoryUtils.cpp create mode 100644 src/xccl/XPUSymmetricMemoryUtils.hpp create mode 100644 src/xccl/ze_symbol.hpp create mode 100644 test/xpu/distributed/test_symmetric_memory_xccl.py diff --git a/src/xccl/IpcExchange.hpp b/src/xccl/IpcExchange.hpp new file mode 100644 index 0000000000..e515cd6ce0 --- /dev/null +++ b/src/xccl/IpcExchange.hpp @@ -0,0 +1,400 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "xccl/ze_symbol.hpp" + +#include + +#include +#include +#include +#include +#include + +struct exchange_contents { + // first 4-byte is file descriptor for drmbuf or gem object + union { + ze_ipc_mem_handle_t ipc_handle; + int fd = -1; + }; + size_t offset = 0; + int pid = -1; +}; + +#define sysCheck(x) \ + if (x == -1) { \ + throw std::system_error(std::make_error_code(std::errc(errno))); \ + } + +// We can't inherit it from cmsghdr because flexible array member +struct exchange_fd { + char obscure[CMSG_LEN(sizeof(int)) - sizeof(int)]; + int fd; + + exchange_fd(int cmsg_level, int cmsg_type, int fd) : fd(fd) { + auto* cmsg = reinterpret_cast(obscure); + cmsg->cmsg_len = 
sizeof(exchange_fd); + cmsg->cmsg_level = cmsg_level; + cmsg->cmsg_type = cmsg_type; + } + + exchange_fd() : fd(-1) { + memset(obscure, 0, sizeof(obscure)); + }; +}; + +void un_send_fd(int sock, int fd, int rank, size_t offset) { + iovec iov[1]; + msghdr msg; + auto rank_offset = std::make_pair(rank, offset); + + iov[0].iov_base = &rank_offset; + iov[0].iov_len = sizeof(rank_offset); + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_name = nullptr; + msg.msg_namelen = 0; + + exchange_fd cmsg(SOL_SOCKET, SCM_RIGHTS, fd); + + msg.msg_control = &cmsg; + msg.msg_controllen = sizeof(exchange_fd); + sysCheck(sendmsg(sock, &msg, 0)); +} + +std::tuple un_recv_fd(int sock) { + iovec iov[1]; + msghdr msg; + std::pair rank_offset; + + iov[0].iov_base = &rank_offset; + iov[0].iov_len = sizeof(rank_offset); + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_name = nullptr; + msg.msg_namelen = 0; + + exchange_fd cmsg; + msg.msg_control = &cmsg; + msg.msg_controllen = sizeof(exchange_fd); + int n_recv = recvmsg(sock, &msg, 0); + sysCheck(n_recv); + // assert(n_recv == sizeof(int)); + + return std::make_tuple(cmsg.fd, rank_offset.first, rank_offset.second); +} + +int prepare_socket(const char* sockname) { + sockaddr_un un; + memset(&un, 0, sizeof(un)); + un.sun_family = AF_UNIX; + strcpy(un.sun_path, sockname); + + auto sock = socket(AF_UNIX, SOCK_STREAM, 0); + sysCheck(sock); + + int on = 1; + sysCheck(ioctl(sock, FIONBIO, &on)); + + auto size = offsetof(sockaddr_un, sun_path) + strlen(un.sun_path); + sysCheck(bind(sock, (sockaddr*)&un, size)); + + return sock; +} + +int server_listen(const char* sockname) { + unlink(sockname); + auto sock = prepare_socket(sockname); + sysCheck(listen(sock, 10)); + + return sock; +} + +int serv_accept(int listen_sock) { + sockaddr_un un; + + socklen_t len = sizeof(un); + auto accept_sock = accept(listen_sock, (sockaddr*)&un, &len); + sysCheck(accept_sock); + + return accept_sock; +} + +bool wait_for_socket_file(const char* path, int 
max_seconds = 10) {
+  struct stat buffer;
+  for (int i = 0; i < max_seconds * 10; ++i) {
+    if (stat(path, &buffer) == 0) {
+      return true;
+    }
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+  }
+  return false;
+}
+
+// Binds a client socket to `client`, then connects it to the unix socket at
+// `server`. Waits up to 10 s for the server socket file to appear and retries
+// connect() up to 50 times, 100 ms apart.
+// NOTE(review): both failure paths call exit(), which terminates the whole
+// process instead of reporting the error to the caller.
+int client_connect(const char* server, const char* client) {
+  if (!wait_for_socket_file(server, 10)) {
+    std::cerr << "Error: timeout waiting for server socket file: " << server
+              << std::endl;
+    exit(EXIT_FAILURE);
+  }
+  auto sock = prepare_socket(client);
+  sockaddr_un sun;
+  memset(&sun, 0, sizeof(sun));
+  sun.sun_family = AF_UNIX;
+  strcpy(sun.sun_path, server);
+  auto len = offsetof(sockaddr_un, sun_path) + strlen(server);
+  const int max_retries = 50;
+  int retry = 0;
+  int ret = -1;
+  while (retry < max_retries) {
+    ret = connect(sock, (sockaddr*)&sun, len);
+    if (ret == 0)
+      break;
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    retry++;
+  }
+  if (ret != 0) {
+    perror("connect failed");
+    exit(EXIT_FAILURE);
+  }
+
+  // sysCheck(connect(sock, (sockaddr*)&sun, len));
+  return sock;
+}
+
+// All-gathers one (fd, offset) pair per rank over unix domain sockets.
+// Each rank listens on a per-rank, per-user socket under /tmp, connects to
+// every peer's socket and sends `send_buf->fd` / `send_buf->offset`; the
+// received values are stored in `recv_buf[peer]`, and `recv_buf[rank]` is a
+// copy of `*send_buf`.
+// Throws std::system_error when a checked system call fails and
+// std::runtime_error when the current uid has no passwd entry.
+void un_allgather(
+    exchange_contents* send_buf,
+    exchange_contents recv_buf[],
+    int rank,
+    int world) {
+  const char* servername_prefix = "/tmp/open-peer-ipc-mem-server-rank_";
+  const char* clientname_prefix = "/tmp/open-peer-ipc-mem-client-rank_";
+  char server_name[64];
+  /* get username to make server_name unique */
+  auto uid = getuid();
+  auto pwd = getpwuid(uid);
+  // getpwuid() returns nullptr when the uid has no passwd entry (e.g. a
+  // container started with an arbitrary uid). Fail with an error instead of
+  // dereferencing a null pointer in the snprintf() calls below.
+  if (pwd == nullptr) {
+    throw std::runtime_error(
+        "un_allgather: no passwd entry found for the current uid");
+  }
+  snprintf(
+      server_name,
+      sizeof(server_name),
+      "%s%d_%s",
+      servername_prefix,
+      rank,
+      pwd->pw_name);
+  unlink(server_name);
+  auto s_listen = server_listen(server_name);
+
+  pollfd fdarray[world];
+  int recv_socks[world - 1];
+
+  for (auto& pollfd : fdarray)
+    pollfd.fd = -1;
+  std::fill(recv_socks, recv_socks + world - 1, -1);
+
+  // Closes every socket opened by this function. It is invoked from the
+  // guard's destructor (possibly during stack unwinding), so it must not
+  // throw: close() failures are deliberately ignored.
+  auto fd_guard = [&]() {
+    for (int i = 0, j = 0; i < world; ++i) {
+      if (i != rank && recv_socks[j] != -1)
+        (void)close(recv_socks[j++]);
+      if (fdarray[i].fd != -1)
+        (void)close(fdarray[i].fd);
+    }
+  };
+
+  struct
guard__ { + using F = decltype(fd_guard); + F f; + guard__(const F& f) : f(f) {} + ~guard__() { + f(); + } + } free_fd(fd_guard); + + // connect to all ranks + for (int i = 0; i < world; ++i) { + if (rank == i) { + fdarray[i].fd = s_listen; + fdarray[i].events = POLLIN; + fdarray[i].revents = 0; + } else { + char peer_name[64]; + char client_name[64]; + + snprintf( + client_name, + sizeof(client_name), + "%s%d-%d_%s", + clientname_prefix, + rank, + i, + pwd->pw_name); + unlink(client_name); + + snprintf( + peer_name, + sizeof(peer_name), + "%s%d_%s", + servername_prefix, + i, + pwd->pw_name); + fdarray[i].fd = client_connect(peer_name, client_name); + fdarray[i].events = POLLOUT; + fdarray[i].revents = 0; + } + } + + // std::future> future_fds[world -1]; + int slot = 0; + uint32_t send_progress = 1 << rank; + + while (slot < world - 1 || send_progress != (1 << world) - 1) { + sysCheck(ppoll(fdarray, world, nullptr, nullptr)); + + for (int i = 0; i < world; ++i) { + if (i == rank && (fdarray[i].revents & POLLIN)) { + // auto accept_sock = serv_accept(fdarray[i].fd); + // future_fds[slot ++] = std::async( + // std::launch::async, [=]() { + // struct sock_guard{ + // int sock; + // sock_guard(int sock) : sock(sock) {} + // ~guard_sock() {sysCheck(close(sock));} + // } release(accept_sock); + // auto ret = un_recv_fd(accept_sock); + // return ret;}); + recv_socks[slot++] = serv_accept(fdarray[i].fd); + } else if ( + (send_progress & (1 << i)) == 0 && fdarray[i].revents & POLLOUT) { + un_send_fd(fdarray[i].fd, send_buf->fd, rank, send_buf->offset); + send_progress |= 1 << i; + } + } + } + + for (int i = 0; i < world - 1; ++i) { + // future_fds[i].wait(); + // auto [fd, peer, offset] = future_fds[i].get(); + auto [fd, peer, offset] = un_recv_fd(recv_socks[i]); + recv_buf[peer].fd = fd; + recv_buf[peer].offset = offset; + } + + recv_buf[rank] = *send_buf; +} + +class IpcChannel { + public: + IpcChannel() { + initialized = false; + } + void init(sycl::queue& queue, 
uint32_t rank_in, uint32_t world_in) { + if (initialized) + return; + + if (!load_level_zero_library()) { + throw std::runtime_error("Failed to initialize Level Zero"); + } + + zeCheck_dynamic(zeInit_dynamic(0)); + int tmp_rank, tmp_world; + + tmp_world = world_in; + tmp_rank = rank_in; + + rank = tmp_rank; + world = tmp_world; + initialized = true; + } + void release(sycl::queue& queue) { + if (!initialized) + return; + try { + auto l0_ctx = sycl::get_native( + queue.get_context()); + for (int i = 0; i < world; i++) { + if (i != rank) { + zeCheck_dynamic(zeMemCloseIpcHandle_dynamic( + l0_ctx, (char*)buffers[i] - offsets[i])); + } + } + } catch (const std::exception& e) { + std::cerr << "Warning: Level Zero cleanup failed: " << e.what() + << std::endl; + } + sycl::free(buffers[rank], queue); + initialized = false; + } + + // buffer_size as element size + void exchange_peer_ipc_mem( + sycl::queue& queue, + void* ptr, + uint32_t rank_in, + uint32_t world_in) { + if (!initialized) + init(queue, rank_in, world_in); + if (!load_level_zero_library()) { + throw std::runtime_error("Level Zero not available"); + } + + // Step 1: Get base address of the pointer + sycl::context ctx = queue.get_context(); + auto l0_ctx = sycl::get_native(ctx); + + void* base_addr; + size_t base_size; + zeCheck_dynamic( + zeMemGetAddressRange_dynamic(l0_ctx, ptr, &base_addr, &base_size)); + + // Step 2: Get IPC mem handle from base address + alignas(64) exchange_contents send_buf; + alignas(64) exchange_contents recv_buf[world]; + + // fill in the exchange info + zeCheck_dynamic( + zeMemGetIpcHandle_dynamic(l0_ctx, base_addr, &send_buf.ipc_handle)); + send_buf.offset = (char*)ptr - (char*)base_addr; + + send_buf.pid = getpid(); + + // Step 3: Exchange the handles and offsets + memset(recv_buf, 0, sizeof(recv_buf)); + // Overkill if we don't really needs all peer's handles + un_allgather(&send_buf, recv_buf, rank, world); + for (uint32_t i = 0; i < world; i++) { + // Step 4: Prepare pid file 
descriptor of next process
+      auto* peer = recv_buf + i;
+      // Step 6: Open IPC handle of remote peer
+      auto l0_device = sycl::get_native(
+          queue.get_device());
+      void* peer_base;
+
+      zeCheck_dynamic(zeMemOpenIpcHandle_dynamic(
+          l0_ctx,
+          l0_device,
+          peer->ipc_handle,
+          ZE_IPC_MEMORY_FLAG_BIAS_CACHED,
+          &peer_base));
+
+      // The opened mapping points at the peer's allocation base; re-apply the
+      // peer's offset to reach the address it actually exported.
+      buffers[i] = (char*)peer_base + peer->offset;
+      offsets[i] = peer->offset;
+      // Record the handle that belongs to rank i. Storing send_buf.ipc_handle
+      // here would put this rank's own handle in every slot.
+      ipc_handle[i] = peer->ipc_handle;
+    }
+  }
+
+  bool initialized;
+  static constexpr uint32_t max_rank = 16;
+  void* buffers[max_rank];
+  void* sync_buffer[max_rank];
+  size_t offsets[max_rank];
+  ze_ipc_mem_handle_t ipc_handle[max_rank];
+  int rank, world;
+  int size_per_buffer;
+  int data_size_per_buffer;
+  int buffer_index;
+};
diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp
new file mode 100644
index 0000000000..d49d126122
--- /dev/null
+++ b/src/xccl/XPUSymmetricMemory.cpp
@@ -0,0 +1,460 @@
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+namespace c10d {
+namespace symmetric_memory {
+
+static StoreExchange storeExchange = StoreExchange("XPUSymmetricMemory");
+
+AllocationRef::AllocationRef(
+    void* ptr,
+    HandleType handle,
+    size_t block_size,
+    int device_idx,
+    bool local_allocation)
+    : ptr(ptr),
+      handle(handle),
+      block_size(block_size),
+      device_idx(device_idx),
+      local_allocation(local_allocation){}
+
+AllocationRef::~AllocationRef() {
+  if (is_finalizing()) {
+    return;
+  }
+  // Currently, we cannot free virtual memory exchanged from other device.
+ if (!local_allocation) { + return; + } + c10::Device local_device(c10::DeviceType::XPU, device_idx); + c10::DeviceGuard guard(local_device); + c10::xpu::syncStreamsOnDevice(); + auto stream = at::xpu::getCurrentXPUStream(); + sycl::free(ptr, stream); +} + +XPUSymmetricMemory::XPUSymmetricMemory( + std::vector> alloc_refs, + std::vector buffers, + std::vector signal_pads, + HandleType mc_handle, + void* mc_addr, + size_t buffer_size, + int local_device_idx, + int rank, + int world_size) + : alloc_refs_(std::move(alloc_refs)), + buffers_(std::move(buffers)), + signal_pads_(std::move(signal_pads)), + mc_handle_(mc_handle), + mc_addr_(mc_addr), + buffer_size_(buffer_size), + local_device_idx_(local_device_idx), + rank_(rank), + world_size_(world_size) { + const size_t arr_size = sizeof(void*) * world_size_; + buffers_dev_ = reinterpret_cast( + c10::xpu::XPUCachingAllocator::raw_alloc(arr_size)); + signal_pads_dev_ = reinterpret_cast( + c10::xpu::XPUCachingAllocator::raw_alloc(arr_size)); + + c10::Device local_device(c10::DeviceType::XPU, local_device_idx); + c10::DeviceGuard guard(local_device); + + at::xpu::getCurrentXPUStream().queue().memcpy( + buffers_dev_, buffers_.data(), arr_size); + at::xpu::getCurrentXPUStream().queue().memcpy( + signal_pads_dev_, signal_pads_.data(), arr_size); +} + +std::vector XPUSymmetricMemory::get_buffer_ptrs() { + return buffers_; +} + +std::vector XPUSymmetricMemory::get_signal_pad_ptrs() { + return signal_pads_; +} + +void** XPUSymmetricMemory::get_buffer_ptrs_dev() { + return buffers_dev_; +} + +void** XPUSymmetricMemory::get_signal_pad_ptrs_dev() { + return signal_pads_dev_; +} + +size_t XPUSymmetricMemory::get_buffer_size() { + return buffer_size_; +} + +size_t XPUSymmetricMemory::get_signal_pad_size() { + return signal_pad_size; +} + +bool XPUSymmetricMemory::has_multicast_support() { + return false; +} + +void* XPUSymmetricMemory::get_multicast_ptr() { + return nullptr; +} + +at::Tensor XPUSymmetricMemory::get_buffer( + int 
rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) { + const size_t numel = std::accumulate( + sizes.begin(), + sizes.end(), + static_cast(1), + std::multiplies()); + const auto element_size = c10::elementSize(dtype); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= buffer_size_, + "XPUSymmetricMemory::get_buffer: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + buffer_size_, + " bytes)"); + auto data_ptr = reinterpret_cast(buffers_[rank]) + + storage_offset * element_size; + // check the device of this device buffer + auto ptr_to_device_id = c10::xpu::get_device_idx_from_pointer(data_ptr); + auto device = c10::Device(c10::DeviceType::XPU, ptr_to_device_id); + auto options = at::TensorOptions().dtype(dtype).device(device); + + return at::for_blob(data_ptr, sizes) + .options(options) + .target_device(device) + .make_tensor(); +} + +void check_channel(int channel, int world_size) { + TORCH_CHECK( + channel >= 0, + "channel for barrier(), put_signal() and wait_signal() ", + "must be greater than 0 (got ", + channel, + ")"); + const size_t num_channels = signal_pad_size / sizeof(uint32_t) * world_size; + TORCH_CHECK( + static_cast(channel) < num_channels, + "The maximum supported channel for barrier(), put_signal() and wait_signal() is ", + num_channels - 1, + " (got ", + channel, + ")"); +} + +void XPUSymmetricMemory::barrier(int channel, size_t timeout_ms) { + check_channel(channel, world_size_); + + // Currently, we leverage oneCCL for barrier. Later, we may move to SYCL + // implementation. 
+  auto group = c10d::resolve_process_group(group_name_);
+  if (group == nullptr) {
+    TORCH_WARN(
+        "Process group '",
+        group_name_,
+        "' not found, please init process group first before calling SymmetricMemory");
+    throw std::runtime_error("Process group not found");
+  }
+  auto* xcclPg = dynamic_cast(
+      group->getBackend(c10::DeviceType::XPU).get());
+  // dynamic_cast yields nullptr when the XPU backend of this group is not the
+  // expected type; report that instead of dereferencing a null pointer below.
+  TORCH_CHECK(
+      xcclPg != nullptr,
+      "XPUSymmetricMemory::barrier: the XPU backend of process group '",
+      group_name_,
+      "' is not an XCCL process group");
+
+  c10::Device local_device(c10::DeviceType::XPU, local_device_idx_);
+  c10::DeviceGuard guard(local_device);
+
+  // One-element scratch tensor reused across calls on the same thread; it is
+  // re-created when the local device changes and zeroed otherwise.
+  static thread_local at::Tensor barrier_tensor;
+  if (!barrier_tensor.defined() || barrier_tensor.device() != local_device) {
+    barrier_tensor = at::zeros(
+        {1}, at::TensorOptions().device(local_device).dtype(at::kFloat));
+  } else {
+    barrier_tensor.zero_();
+  }
+
+  c10d::AllreduceOptions arOpts;
+  arOpts.asyncOp = false;
+  auto work =
+      xcclPg->allreduce_impl(barrier_tensor, "xccl:symm_mem_barrier", arOpts);
+
+  if (work) {
+    bool success = work->wait(std::chrono::milliseconds(timeout_ms));
+    TORCH_CHECK(
+        success,
+        "Barrier timeout after ",
+        timeout_ms,
+        " ms for group '",
+        group_name_,
+        "'");
+  }
+}
+
+// Not implemented for XPU: only logs an error, no signal is sent.
+void XPUSymmetricMemory::put_signal(
+    int dst_rank,
+    int channel,
+    size_t timeout_ms) {
+  LOG(ERROR) << "XPUSymmetricMemory::put_signal not supported";
+}
+
+// Not implemented for XPU: only logs an error and returns without waiting.
+void XPUSymmetricMemory::wait_signal(
+    int src_rank,
+    int channel,
+    size_t timeout_ms) {
+  LOG(ERROR) << "XPUSymmetricMemory::wait_signal not supported";
+}
+
+int XPUSymmetricMemory::get_rank() {
+  return rank_;
+}
+
+int XPUSymmetricMemory::get_world_size() {
+  return world_size_;
+}
+
+c10::Device XPUSymmetricMemory::get_device() {
+  return c10::Device(c10::DeviceType::XPU, local_device_idx_);
+}
+
+Block::Block(
+    c10::intrusive_ptr alloc_ref,
+    int device_idx,
+    size_t block_size,
+    size_t buffer_size,
+    size_t signal_pad_offset,
+    const std::optional& group_name)
+    : alloc_ref(std::move(alloc_ref)),
+      device_idx(device_idx),
+      block_size(block_size),
+      buffer_size(buffer_size),
+      signal_pad_offset(signal_pad_offset),
+
default_group_name(std::move(group_name)) {}
+
+// Allocates `size` bytes of device memory followed by a signal pad of
+// `signal_pad_size` bytes (the pad starts at `size` rounded up to 16 bytes),
+// zero-fills the whole block and registers it so that get_alloc_size(),
+// rendezvous() and free() can find it by pointer.
+// `group_name` becomes the block's default group for rendezvous().
+void* XPUSymmetricMemoryAllocator::alloc(
+    size_t size,
+    int device_idx,
+    const std::optional& group_name) {
+  size_t signal_pad_offset = at::round_up(size, 16UL);
+  size_t block_size = signal_pad_offset + signal_pad_size;
+
+  // Allocate on the requested device rather than on whichever device happens
+  // to be current; the same guard is used when the block is freed.
+  c10::Device local_device(c10::DeviceType::XPU, device_idx);
+  c10::DeviceGuard guard(local_device);
+
+  sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue();
+  void* ptr = sycl::malloc_device(block_size, current_queue);
+  // sycl::malloc_device reports failure by returning nullptr; do not memset
+  // or register a null block.
+  TORCH_CHECK(
+      ptr != nullptr,
+      "XPUSymmetricMemoryAllocator::alloc: failed to allocate ",
+      block_size,
+      " bytes of device memory");
+  current_queue.memset(ptr, 0, block_size);
+  auto alloc_ref =
+      c10::make_intrusive(ptr, ptr, block_size, device_idx, true);
+  auto block = c10::make_intrusive(
+      std::move(alloc_ref),
+      device_idx,
+      block_size,
+      size,
+      signal_pad_offset,
+      group_name);
+
+  {
+    std::unique_lock lock(mutex_);
+    ptr_to_block_.emplace(ptr, std::move(block));
+  }
+  return ptr;
+}
+
+// Drops the allocator's reference to the block registered for `ptr`.
+void XPUSymmetricMemoryAllocator::free(void* ptr) {
+  std::unique_lock lock(mutex_);
+  ptr_to_block_.erase(ptr);
+}
+
+// Returns the user-visible buffer size (without the signal pad) of a block
+// previously returned by alloc(); fails if `ptr` is unknown.
+size_t XPUSymmetricMemoryAllocator::get_alloc_size(void* ptr) {
+  auto block = find_block(ptr);
+  TORCH_CHECK(
+      block != nullptr,
+      "XPUSymmetricMemoryAllocator::get_alloc_size: input must be allocated ",
+      "via XPUSymmetricMemoryAllocator::alloc");
+  return block->buffer_size;
+}
+
+// Per-rank description of a block, exchanged through the store during
+// rendezvous.
+struct RendezvousRequest {
+  int device_idx;
+  int pid;
+  size_t block_size;
+  size_t buffer_size;
+  size_t signal_pad_offset;
+  bool has_multicast_support;
+};
+
+// Checks that every rank submitted a request and that all ranks agree on the
+// block size, buffer size and signal pad offset.
+void validate_rendezvous_requests(
+    const std::vector& reqs,
+    int world_size) {
+  TORCH_CHECK(reqs.size() == (size_t)world_size);
+
+  for (int r = 1; r < world_size; ++r) {
+    TORCH_CHECK(reqs[r].block_size == reqs[0].block_size);
+    TORCH_CHECK(reqs[r].buffer_size == reqs[0].buffer_size);
+    TORCH_CHECK(reqs[r].signal_pad_offset == reqs[0].signal_pad_offset);
+  }
+}
+
+c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous(
+    void* ptr,
+    const std::optional& group_name) {
+  auto block =
find_block(ptr); + if (block == nullptr) { + return nullptr; + } + + // The group_name passed to rendezvous() takes precedence over + // the default group_name specified during allocation. + std::string group_name_; + // Treat empty string and std::nullopt the same as empty string seems to be + // implicitly used that way + if (group_name.has_value() && group_name != "") { + group_name_ = *group_name; + } else { + if (!block->default_group_name.has_value()) { + TORCH_CHECK( + false, + "XPUSymmetricMemory::rendezvous: `group_name` is neither " + "specified during allocation nor passed to rendezvous()."); + } + group_name_ = *block->default_group_name; + } + + auto it = block->symm_mems.find(group_name_); + if (it != block->symm_mems.end()) { + return it->second; + } + + c10::Device local_device(c10::DeviceType::XPU, block->device_idx); + c10::DeviceGuard guard(local_device); + + // IpcChannel is used to do inter-process communication + IpcChannel ipc_channel; + auto group_info = get_group_info(group_name_); + auto store = group_info.store; + int rank = group_info.rank; + int world_size = group_info.world_size; + sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); + + auto local_req = RendezvousRequest{ + .device_idx = block->device_idx, + .pid = getpid(), + .block_size = block->block_size, + .buffer_size = block->buffer_size, + .signal_pad_offset = block->signal_pad_offset, + .has_multicast_support = false}; + auto reqs = storeExchange.all_gather(store, rank, world_size, local_req); + validate_rendezvous_requests(reqs, world_size); + + std::vector pids(world_size); + for (int r = 0; r < world_size; ++r) { + pids[r] = reqs[r].pid; + } + + // do IPC exchange for all peer ranks + ipc_channel.exchange_peer_ipc_mem(current_queue, ptr, rank, world_size); + + // no physical memory handle, so handles and buffers are both for virtual + // address + std::vector handles(world_size); + std::vector buffers(world_size, nullptr); + std::vector 
signal_pads(world_size, nullptr); + + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + handles[r] = block->alloc_ref->handle; + buffers[r] = ptr; + signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); + continue; + } else { + buffers[r] = ipc_channel.buffers[r]; + handles[r] = ipc_channel.buffers[r]; + signal_pads[r] = + (void*)((uintptr_t)buffers[r] + block->signal_pad_offset); + } + } + storeExchange.barrier(store, rank, world_size); + + HandleType mc_handle{}; + void* mc_addr = nullptr; + + std::vector> alloc_refs; + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + alloc_refs.emplace_back(block->alloc_ref); + continue; + } + alloc_refs.push_back(c10::make_intrusive( + buffers[r], handles[r], block->block_size, block->device_idx, false)); + } + + auto symm_mem = c10::make_intrusive( + std::move(alloc_refs), + std::move(buffers), + std::move(signal_pads), + mc_handle, + mc_addr, + block->buffer_size, + block->device_idx, + group_info.rank, + group_info.world_size); + symm_mem->set_group_name(group_name_); + block->symm_mems[group_name_] = symm_mem; + return symm_mem; +} + +bool XPUSymmetricMemoryAllocator::has_multicast_support(int device_idx) { + return false; +} + +c10::DeviceType XPUSymmetricMemoryAllocator::supported_device_type() { + return c10::DeviceType::XPU; +} + +std::string XPUSymmetricMemoryAllocator::name() { + return "XPU"; +} + +c10::intrusive_ptr XPUSymmetricMemoryAllocator::find_block(void* ptr) { + std::shared_lock lock(mutex_); + auto it = ptr_to_block_.find(ptr); + if (it == ptr_to_block_.end()) { + return nullptr; + } + return it->second; +} + +struct RegisterXPUSymmetricMemoryAllocator { + RegisterXPUSymmetricMemoryAllocator() { + auto allocator = c10::make_intrusive(); + // Query backend used for XPU + if (getSymmMemBackendXPU() == "XPU") { + // Direct set (static registration) + register_allocator(c10::DeviceType::XPU, allocator); + } else { + // Register availability in case `set_backend` is called 
dynamically + register_availability("XPU", allocator); + } + } +}; +static RegisterXPUSymmetricMemoryAllocator register_allocator_; + +} // namespace symmetric_memory +} // namespace c10d diff --git a/src/xccl/XPUSymmetricMemory.hpp b/src/xccl/XPUSymmetricMemory.hpp new file mode 100644 index 0000000000..2daac1114a --- /dev/null +++ b/src/xccl/XPUSymmetricMemory.hpp @@ -0,0 +1,130 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10d::symmetric_memory { + +// Resource wrapper that owns a (vaddr, allocation handle) pair. Upon +// destruction, it unmaps the vaddr and releases the allocation handle. +struct AllocationRef : public c10::intrusive_ptr_target { + void* ptr; + HandleType handle; + size_t block_size; + int device_idx; + bool local_allocation; + + AllocationRef( + void* ptr, + HandleType handle, + size_t block_size, + int device_idx, + bool local_allocation); + + ~AllocationRef(); +}; + +class XPUSymmetricMemory : public SymmetricMemory { + public: + XPUSymmetricMemory( + std::vector> alloc_refs, + std::vector buffers, + std::vector signal_pads, + HandleType mc_handle, + void* mc_addr, + size_t buffer_size, + int local_device_idx, + int rank, + int world_size); + + ~XPUSymmetricMemory() override{}; + + std::vector get_buffer_ptrs() override; + std::vector get_signal_pad_ptrs() override; + void** get_buffer_ptrs_dev() override; + void** get_signal_pad_ptrs_dev() override; + size_t get_buffer_size() override; + size_t get_signal_pad_size() override; + + bool has_multicast_support() override; + void* get_multicast_ptr() override; + + at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset); + + void barrier(int channel, size_t timeout_ms) override; + void put_signal(int dst_rank, int channel, size_t timeout_ms) override; + void wait_signal(int src_rank, int channel, size_t timeout_ms) override; + + int get_rank() override; + int get_world_size() override; + c10::Device 
get_device() override; + + void set_group_name(const std::string& group_name) { + group_name_ = group_name; + } + + private: + std::vector> alloc_refs_; + std::vector buffers_; + std::vector signal_pads_; + HandleType mc_handle_; + void* mc_addr_; + size_t buffer_size_; + int local_device_idx_; + int rank_; + int world_size_; + void** buffers_dev_; + void** signal_pads_dev_; + std::string group_name_; +}; + +struct Block : public c10::intrusive_ptr_target { + c10::intrusive_ptr alloc_ref; + int device_idx; + size_t block_size; + size_t buffer_size; + size_t signal_pad_offset; + std::optional default_group_name; + std::map> symm_mems; + + Block( + c10::intrusive_ptr alloc_ref, + int device_idx, + size_t block_size, + size_t buffer_size, + size_t signal_pad_offset, + const std::optional& group_name); +}; + +class XPUSymmetricMemoryAllocator : public SymmetricMemoryAllocator { + public: + void* alloc( + size_t size, + int device_idx, + const std::optional& group_name) override; + + void free(void* ptr) override; + size_t get_alloc_size(void* ptr) override; + c10::intrusive_ptr rendezvous( + void* ptr, + const std::optional& group_name) override; + bool has_multicast_support(int device_idx) override; + // void exchange_peer_ipc_mem(sycl::queue& queue, void* ptr); + c10::DeviceType supported_device_type() override; + std::string name() override; + + private: + c10::intrusive_ptr find_block(void* ptr); + + std::shared_mutex mutex_; + std::unordered_map> ptr_to_block_; +}; + +} // namespace c10d::symmetric_memory diff --git a/src/xccl/XPUSymmetricMemoryTypes.hpp b/src/xccl/XPUSymmetricMemoryTypes.hpp new file mode 100644 index 0000000000..4cab3b81f7 --- /dev/null +++ b/src/xccl/XPUSymmetricMemoryTypes.hpp @@ -0,0 +1,8 @@ +#pragma once + +namespace c10d::symmetric_memory { + +constexpr size_t signal_pad_size = 2048; +using HandleType = void*; + +} // namespace c10d::symmetric_memory diff --git a/src/xccl/XPUSymmetricMemoryUtils.cpp b/src/xccl/XPUSymmetricMemoryUtils.cpp 
new file mode 100644 index 0000000000..7130fe7b6a --- /dev/null +++ b/src/xccl/XPUSymmetricMemoryUtils.cpp @@ -0,0 +1,76 @@ +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace c10d::symmetric_memory { + +std::string getSymmMemBackendXPU() { + static auto val = c10::utils::get_env("TORCH_SYMMMEM"); + if (val.has_value()) { + TORCH_CHECK( + val.value() == "XPU", + "TORCH_SYMMMEM environment variable must be 'XPU'."); + return val.value(); + } + return "XPU"; +} + +bool device_has_multicast_support(int device_idx) { + return false; +} + +bool allow_overlapping_devices() { + return false; +} + +void map_block( + void** ptr, + ze_physical_mem_handle_t handle, + size_t size, + int device_idx) { + sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); + sycl::context sycl_ctx = current_queue.get_context(); + ze_context_handle_t ze_context = + sycl::get_native(sycl_ctx); + // 1. Reserve virtual address space + void* virtual_ptr = nullptr; + ze_result_t status = zeVirtualMemReserve( + ze_context, // context + nullptr, // let L0 pick virtual address + size, // size + &virtual_ptr // out: reserved address + ); + TORCH_CHECK(status == ZE_RESULT_SUCCESS, "zeVirtualMemReserve failed"); + + // 2. Map physical memory to virtual address + status = zeVirtualMemMap( + ze_context, + virtual_ptr, // virtual memory to map to + size, + handle, // physical memory handle + 0, // flags + ZE_MEMORY_ACCESS_ATTRIBUTE_READWRITE // ze_memory_access_attribute_t + ); + TORCH_CHECK(status == ZE_RESULT_SUCCESS, "zeVirtualMemMap failed"); + + // 3. Set access attributes + ze_memory_access_attribute_t access = ZE_MEMORY_ACCESS_ATTRIBUTE_READWRITE; + status = + zeVirtualMemSetAccessAttribute(ze_context, virtual_ptr, size, access); + TORCH_CHECK( + status == ZE_RESULT_SUCCESS, "zeVirtualMemSetAccessAttribute failed"); + + // 4. 
Return pointer + *ptr = virtual_ptr; +} + +} // namespace c10d::symmetric_memory diff --git a/src/xccl/XPUSymmetricMemoryUtils.hpp b/src/xccl/XPUSymmetricMemoryUtils.hpp new file mode 100644 index 0000000000..69189f45cf --- /dev/null +++ b/src/xccl/XPUSymmetricMemoryUtils.hpp @@ -0,0 +1,89 @@ +#pragma once +#include +#include +#include + +namespace c10d { +namespace symmetric_memory { + +std::string getSymmMemBackendXPU(); + +bool device_has_multicast_support(int device_idx); + +bool allow_overlapping_devices(); + +// A set of store-based exchange methods with a preset prefix typically type of +// the SymmetricMemory. Most used as static instances at respective +// SymmetricMemory implementation files. +class StoreExchange { + public: + StoreExchange(const std::string& store_prefix) + : store_prefix_(store_prefix) {} + + // Put template function in header file so that compiler can easily access it. + template + std::vector all_gather( + const c10::intrusive_ptr& store, + int rank, + int world_size, + T val) { + static_assert(std::is_trivially_copyable_v); + + std::vector peer_keys; + peer_keys.reserve(world_size); + for (int r = 0; r < world_size; ++r) { + std::ostringstream oss; + oss << store_prefix_ << "/" << seq_id_ << "/" << r; + peer_keys.push_back(oss.str()); + } + ++seq_id_; + + { + std::vector payload( + reinterpret_cast(&val), + reinterpret_cast(&val) + sizeof(T)); + store->set(peer_keys[rank], payload); + } + + std::vector peer_vals; + peer_vals.reserve(world_size); + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + peer_vals.push_back(val); + continue; + } + store->wait({peer_keys[r]}); + auto payload = store->get(peer_keys[r]); + TORCH_CHECK(payload.size() == sizeof(T)); + T peer_val{}; + std::memcpy(&peer_val, payload.data(), sizeof(T)); + peer_vals.push_back(peer_val); + } + return peer_vals; + } + + void barrier( + const c10::intrusive_ptr& store, + int rank, + int world_size) { + // TODO: implement an efficient one? 
+ all_gather(store, rank, world_size, 0); + } + + private: + const std::string store_prefix_; + size_t seq_id_ = 0; +}; + +// Returns a pointer of virtual address that is mapped to the physical memory +// held by the handle. +// todo: will follow such physical memory handle map with virtual address, +// when L0 provides physical handle exchange API and we have multicast support. +void map_block( + void** ptr, + ze_physical_mem_handle_t handle, + size_t size, + int device_idx); + +} // namespace symmetric_memory +} // namespace c10d diff --git a/src/xccl/ze_symbol.hpp b/src/xccl/ze_symbol.hpp new file mode 100644 index 0000000000..20af666811 --- /dev/null +++ b/src/xccl/ze_symbol.hpp @@ -0,0 +1,254 @@ +#pragma once + +#include +#include +#include +#include + +#define zeVirtualMemMap zeVirtualMemMap_original +#define zeVirtualMemReserve zeVirtualMemReserve_original +#define zeVirtualMemSetAccessAttribute zeVirtualMemSetAccessAttribute_original + +#include + +#undef zeVirtualMemMap +#undef zeVirtualMemReserve +#undef zeVirtualMemSetAccessAttribute + +typedef ze_result_t (*zeInit_t)(ze_init_flags_t flags); +typedef ze_result_t (*zeMemGetAddressRange_t)( + ze_context_handle_t hContext, + const void* ptr, + void** pBase, + size_t* pSize); +typedef ze_result_t (*zeMemGetIpcHandle_t)( + ze_context_handle_t hContext, + const void* ptr, + ze_ipc_mem_handle_t* pIpcHandle); +typedef ze_result_t (*zeMemOpenIpcHandle_t)( + ze_context_handle_t hContext, + ze_device_handle_t hDevice, + ze_ipc_mem_handle_t handle, + ze_ipc_memory_flags_t flags, + void** pptr); +typedef ze_result_t ( + *zeMemCloseIpcHandle_t)(ze_context_handle_t hContext, const void* ptr); +typedef ze_result_t (*zeVirtualMemMap_t)( + ze_context_handle_t hContext, + const void* ptr, + size_t size, + ze_physical_mem_handle_t hPhysicalMemory, + size_t offset, + ze_memory_access_attribute_t access); +typedef ze_result_t (*zeVirtualMemReserve_t)( + ze_context_handle_t hContext, + const void* pStart, + size_t size, + 
void** pptr); +typedef ze_result_t (*zeVirtualMemSetAccessAttribute_t)( + ze_context_handle_t hContext, + const void* ptr, + size_t size, + ze_memory_access_attribute_t access); + +bool load_level_zero_library(); +void unload_level_zero_library(); + +#define zeCheck_dynamic(x) \ + do { \ + if (!load_level_zero_library()) { \ + throw std::runtime_error("Level Zero library not available"); \ + } \ + ze_result_t result = (x); \ + if (result != ZE_RESULT_SUCCESS) { \ + auto e = zeException(result); \ + std::cout << "Throw " << e.what() << std::endl; \ + throw e; \ + } \ + } while (0) + +#define zeInit_dynamic(flags) zeInit_ptr(flags) +#define zeMemGetAddressRange_dynamic(ctx, ptr, base, size) \ + zeMemGetAddressRange_ptr(ctx, ptr, base, size) +#define zeMemGetIpcHandle_dynamic(ctx, ptr, handle) \ + zeMemGetIpcHandle_ptr(ctx, ptr, handle) +#define zeMemOpenIpcHandle_dynamic(ctx, dev, handle, flags, ptr) \ + zeMemOpenIpcHandle_ptr(ctx, dev, handle, flags, ptr) +#define zeMemCloseIpcHandle_dynamic(ctx, ptr) zeMemCloseIpcHandle_ptr(ctx, ptr) +#define zeVirtualMemMap_dynamic(ctx, ptr, size, phys_mem, offset, access) \ + zeVirtualMemMap_ptr(ctx, ptr, size, phys_mem, offset, access) +#define zeVirtualMemReserve_dynamic(ctx, start, size, ptr) \ + zeVirtualMemReserve_ptr(ctx, start, size, ptr) +#define zeVirtualMemSetAccessAttribute_dynamic(ctx, ptr, size, access) \ + zeVirtualMemSetAccessAttribute_ptr(ctx, ptr, size, access) + +// Exception handling class +class zeException : std::exception { + const char* zeResultToString(ze_result_t status) const { + static const std::unordered_map zeResultToStringMap{ + {ZE_RESULT_SUCCESS, "[Core] success"}, + {ZE_RESULT_NOT_READY, "[Core] synchronization primitive not signaled"}, + {ZE_RESULT_ERROR_UNINITIALIZED, + "[Validation] driver is not initialized"}, + {ZE_RESULT_ERROR_INVALID_NULL_POINTER, + "[Validation] pointer argument may not be nullptr"}, + {ZE_RESULT_ERROR_INVALID_NULL_HANDLE, + "[Validation] handle argument is not valid"}, + 
{ZE_RESULT_ERROR_INVALID_ENUMERATION, + "[Validation] enumerator argument is not valid"}, + {ZE_RESULT_ERROR_INVALID_SIZE, "[Validation] size argument is invalid"}, + {ZE_RESULT_ERROR_UNSUPPORTED_SIZE, + "[Validation] size argument is not supported by the device"}, + {ZE_RESULT_ERROR_UNSUPPORTED_ALIGNMENT, + "[Validation] alignment argument is not supported by the device"}, + {ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, + "[Validation] generic error code for unsupported features"}, + {ZE_RESULT_ERROR_INVALID_NATIVE_BINARY, + "[Validation] native binary is not supported by the device"}, + {ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY, + "[Core] insufficient host memory to satisfy call"}, + {ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY, + "[Core] insufficient device memory to satisfy call"}, + {ZE_RESULT_ERROR_DEVICE_LOST, + "[Core] device hung, reset, was removed, or driver update occurred"}, + {ZE_RESULT_ERROR_MODULE_BUILD_FAILURE, + "[Core] error occurred when building module, see build log for details"}, + {ZE_RESULT_ERROR_HANDLE_OBJECT_IN_USE, + "[Validation] object pointed to by handle still in-use by device"}, + }; + auto it = zeResultToStringMap.find(status); + if (it != zeResultToStringMap.end()) + return it->second; + else + return "Unknown Reason"; + } + + public: + zeException(ze_result_t ret) : result_(ret) {} + + ze_result_t result_; + + const char* what() const noexcept override { + return zeResultToString(result_); + } +}; + +#define zeCheck(x) \ + if (x != ZE_RESULT_SUCCESS) { \ + auto e = zeException(x); \ + std::cout << "Throw " << e.what() << std::endl; \ + throw e; \ + } + +static zeInit_t zeInit_ptr = nullptr; +static zeMemGetAddressRange_t zeMemGetAddressRange_ptr = nullptr; +static zeMemGetIpcHandle_t zeMemGetIpcHandle_ptr = nullptr; +static zeMemOpenIpcHandle_t zeMemOpenIpcHandle_ptr = nullptr; +static zeMemCloseIpcHandle_t zeMemCloseIpcHandle_ptr = nullptr; +static zeVirtualMemMap_t zeVirtualMemMap_ptr = nullptr; +static zeVirtualMemReserve_t zeVirtualMemReserve_ptr 
= nullptr; +static zeVirtualMemSetAccessAttribute_t zeVirtualMemSetAccessAttribute_ptr = + nullptr; + +static void* ze_handle = nullptr; + +inline bool load_level_zero_library() { + if (ze_handle != nullptr) { + return true; + } + const char* lib_names[] = {"libze_loader.so"}; + + for (const char* lib_name : lib_names) { + ze_handle = dlopen(lib_name, RTLD_LAZY); + if (ze_handle != nullptr) { + break; + } + } + + if (ze_handle == nullptr) { + std::cerr << "Failed to load Level Zero library: " << dlerror() + << std::endl; + return false; + } + + zeInit_ptr = (zeInit_t)dlsym(ze_handle, "zeInit"); + zeMemGetAddressRange_ptr = + (zeMemGetAddressRange_t)dlsym(ze_handle, "zeMemGetAddressRange"); + zeMemGetIpcHandle_ptr = + (zeMemGetIpcHandle_t)dlsym(ze_handle, "zeMemGetIpcHandle"); + zeMemOpenIpcHandle_ptr = + (zeMemOpenIpcHandle_t)dlsym(ze_handle, "zeMemOpenIpcHandle"); + zeMemCloseIpcHandle_ptr = + (zeMemCloseIpcHandle_t)dlsym(ze_handle, "zeMemCloseIpcHandle"); + zeVirtualMemMap_ptr = (zeVirtualMemMap_t)dlsym(ze_handle, "zeVirtualMemMap"); + zeVirtualMemReserve_ptr = + (zeVirtualMemReserve_t)dlsym(ze_handle, "zeVirtualMemReserve"); + zeVirtualMemSetAccessAttribute_ptr = (zeVirtualMemSetAccessAttribute_t)dlsym( + ze_handle, "zeVirtualMemSetAccessAttribute"); + + if (!zeInit_ptr || !zeMemGetAddressRange_ptr || !zeMemGetIpcHandle_ptr || + !zeMemOpenIpcHandle_ptr || !zeMemCloseIpcHandle_ptr || + !zeVirtualMemMap_ptr || !zeVirtualMemReserve_ptr || + !zeVirtualMemSetAccessAttribute_ptr) { + std::cerr << "Failed to load Level Zero API functions" << std::endl; + dlclose(ze_handle); + ze_handle = nullptr; + return false; + } + + return true; +} + +inline void unload_level_zero_library() { + if (ze_handle != nullptr) { + dlclose(ze_handle); + ze_handle = nullptr; + zeInit_ptr = nullptr; + zeMemGetAddressRange_ptr = nullptr; + zeMemGetIpcHandle_ptr = nullptr; + zeMemOpenIpcHandle_ptr = nullptr; + zeMemCloseIpcHandle_ptr = nullptr; + zeVirtualMemMap_ptr = nullptr; + 
zeVirtualMemReserve_ptr = nullptr; + zeVirtualMemSetAccessAttribute_ptr = nullptr; + } +} + +extern "C" { + +__attribute__((weak)) ze_result_t zeVirtualMemMap( + ze_context_handle_t hContext, + const void* ptr, + size_t size, + ze_physical_mem_handle_t hPhysicalMemory, + size_t offset, + ze_memory_access_attribute_t access) { + if (!load_level_zero_library() || !zeVirtualMemMap_ptr) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + return zeVirtualMemMap_ptr( + hContext, ptr, size, hPhysicalMemory, offset, access); +} + +__attribute__((weak)) ze_result_t zeVirtualMemReserve( + ze_context_handle_t hContext, + const void* pStart, + size_t size, + void** pptr) { + if (!load_level_zero_library() || !zeVirtualMemReserve_ptr) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + return zeVirtualMemReserve_ptr(hContext, pStart, size, pptr); +} + +__attribute__((weak)) ze_result_t zeVirtualMemSetAccessAttribute( + ze_context_handle_t hContext, + const void* ptr, + size_t size, + ze_memory_access_attribute_t access) { + if (!load_level_zero_library() || !zeVirtualMemSetAccessAttribute_ptr) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + return zeVirtualMemSetAccessAttribute_ptr(hContext, ptr, size, access); +} +} diff --git a/test/xpu/distributed/test_symmetric_memory_xccl.py b/test/xpu/distributed/test_symmetric_memory_xccl.py new file mode 100644 index 0000000000..37f5d3e6da --- /dev/null +++ b/test/xpu/distributed/test_symmetric_memory_xccl.py @@ -0,0 +1,85 @@ +import torch +import torch.distributed as dist +from test_c10d_xccl import init_multigpu_helper, requires_xccl +from torch.distributed._symmetric_memory import ( + _fused_all_gather_matmul_fallback, + _fused_matmul_reduce_scatter_fallback, +) + +from torch.testing._internal.common_distributed import MultiProcContinuousTest +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests +) + +@instantiate_parametrized_tests +class AsyncTPTest(MultiProcContinuousTest): + 
@property + def device(self) -> torch.device: + return torch.device("xpu", self.rank) + + def _init_process(self): + torch.xpu.set_device(self.device) + torch.manual_seed(42 + self.rank) + torch.use_deterministic_algorithms(True) + torch.set_deterministic_debug_mode("warn") + torch.utils.deterministic.fill_uninitialized_memory = True + + @requires_xccl() + @parametrize("gather_dim", [0, 1]) + def test_fused_all_gather_matmul(self, gather_dim: int) -> None: + self._init_process() + BATCH = 8 + M = 64 + N = 16 + K = 32 + group = dist.group.WORLD + rank = self.rank + + torch.manual_seed(42 + rank) + A_shard = torch.rand(BATCH, M // self.world_size, K, device="xpu") + Bs = [torch.rand(K, N, device="xpu") for _ in range(3)] + + ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback( + A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name + ) + ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_matmul( + A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name + ) + + self.assertEqual(ag_output_0, ag_output_1) + self.assertEqual(ag_output_0.stride(), ag_output_1.stride()) + for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1): + self.assertEqual(mm_output_0, mm_output_1) + self.assertEqual(mm_output_0.stride(), mm_output_1.stride()) + + @requires_xccl() + @parametrize("scatter_dim", [0, 1]) + def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None: + self._init_process() + + BATCH = 8 + M = 64 + N = 16 + K = 32 + group = dist.group.WORLD + rank = self.rank + + torch.manual_seed(42 + rank) + A = torch.rand(BATCH, M, K, device="xpu") + B = torch.rand(K, N, device="xpu") + + output_0 = _fused_matmul_reduce_scatter_fallback( + A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name + ) + output_1 = torch.ops.symm_mem.fused_matmul_reduce_scatter( + A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name + ) + + self.assertEqual(output_0, output_1) + self.assertEqual(output_0.stride(), 
output_1.stride()) + + +if __name__ == "__main__": + run_tests() From 9563c00a007ae5d80ffe74e6187fab9b3db0d8e4 Mon Sep 17 00:00:00 2001 From: Cherry Zhang Date: Fri, 27 Feb 2026 18:30:42 +0800 Subject: [PATCH 02/25] remove atomic_ref --- src/xccl/Signal.hpp | 152 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 src/xccl/Signal.hpp diff --git a/src/xccl/Signal.hpp b/src/xccl/Signal.hpp new file mode 100644 index 0000000000..a53f1f7c15 --- /dev/null +++ b/src/xccl/Signal.hpp @@ -0,0 +1,152 @@ +#pragma once + +#include + +#include +#include + +namespace c10d::symmetric_memory { + +using at::native::memory::get_alignment; + +// ============================================================================= +// Signal primitives using store/load + atomic_fence +// (sycl::atomic_ref is not supported, use explicit fence instead) +// ============================================================================= + +// Store value with release fence (for put_signal) +// Order: store first, then release fence to flush the store +inline void store_release(uint32_t* addr, uint32_t val) { + *addr = val; + sycl::atomic_fence(sycl::memory_order::release, sycl::memory_scope::system); +} + +// Load value with acquire fence (for get_signal/wait_signal) +// Order: acquire fence first, then load to see the latest value +inline uint32_t load_acquire(uint32_t* addr) { + sycl::atomic_fence(sycl::memory_order::acquire, sycl::memory_scope::system); + uint32_t val = *addr; + return val; +} + +inline size_t global_timer_ns() { + auto now = std::chrono::high_resolution_clock::now(); + return std::chrono::duration_cast( + now.time_since_epoch()) + .count(); +} + +constexpr size_t ns_per_ms = 1e6; + +// ============================================================================= +// Put signal: wait until addr == 0, then set to 1 (release semantics) +// ============================================================================= + +// 
Device-compatible version using iteration count +template +bool try_put_signal_device(uint32_t* addr, size_t max_iterations = 1000) { + size_t iterations = 0; + // Wait until the slot is free (value == 0) + while (load_acquire(addr) != 0) { + if (max_iterations != 0 && iterations++ > max_iterations) { + return false; + } + } + // Set signal to 1 with release semantics + store_release(addr, 1); + return true; +} + +// Host version using timeout +template +bool try_put_signal(uint32_t* addr, size_t timeout_ms) { + size_t deadline = global_timer_ns() + timeout_ms * ns_per_ms; + // Wait until the slot is free (value == 0) + while (load_acquire(addr) != 0) { + if (timeout_ms != 0 && global_timer_ns() > deadline) { + return false; + } + } + // Set signal to 1 with release semantics + store_release(addr, 1); + return true; +} + +// Blocking version +template +void put_signal(uint32_t* addr) { + // Wait until the slot is free (value == 0) + while (load_acquire(addr) != 0) + ; + // Set signal to 1 with release semantics + store_release(addr, 1); +} + +// ============================================================================= +// Wait signal: wait until addr == 1, then set to 0 (acquire semantics) +// ============================================================================= + +// Device-compatible version using iteration count +template +bool try_wait_signal_device(uint32_t* addr, size_t max_iterations = 1000) { + size_t iterations = 0; + // Wait until signal is set (value == 1) + while (load_acquire(addr) != 1) { + // Spin wait (no timeout check to avoid early exit) + continue; + } + // Clear signal to 0 with release semantics + store_release(addr, 0); + return true; +} + +// Host version using timeout +template +bool try_wait_signal(uint32_t* addr, size_t timeout_ms) { + size_t deadline = global_timer_ns() + timeout_ms * ns_per_ms; + // Wait until signal is set (value == 1) + while (load_acquire(addr) != 1) { + // Spin wait (no timeout check to avoid early exit) 
+ continue; + } + // Clear signal to 0 with release semantics + store_release(addr, 0); + return true; +} + +// Blocking version +template +void wait_signal(uint32_t* addr) { + // Wait until signal is set (value == 1) + while (load_acquire(addr) != 1) + ; + // Clear signal to 0 with release semantics + store_release(addr, 0); +} + +void barrier_impl_xpu( + uint32_t** signal_pads, + int channel, + int rank, + int world_size, + size_t timeout_ms, + at::xpu::XPUStream& stream); + +void put_signal_impl_xpu( + uint32_t** signal_pads, + int dst_rank, + int channel, + int rank, + int world_size, + size_t timeout_ms, + at::xpu::XPUStream& stream); + +void wait_signal_impl_xpu( + uint32_t** signal_pads, + int src_rank, + int channel, + int rank, + int world_size, + size_t timeout_ms, + at::xpu::XPUStream& stream); +} // namespace c10d::symmetric_memory From 3f68da3517cf947c83cc0c456d2970967a0c64c1 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Mon, 8 Dec 2025 14:42:06 +0800 Subject: [PATCH 03/25] symm both xpu and ishmem --- src/xccl/CMakeLists.txt | 2 + src/xccl/ISHMEMSymmetricMemory.cpp | 450 +++++++++++++++++++++++++++ src/xccl/IpcExchange.hpp | 4 +- src/xccl/Signal.cpp | 201 ++++++++++++ src/xccl/XPUSymmetricMemory.cpp | 136 +++++--- src/xccl/XPUSymmetricMemoryUtils.cpp | 219 ++++++++++--- src/xccl/XPUSymmetricMemoryUtils.hpp | 55 +++- 7 files changed, 957 insertions(+), 110 deletions(-) create mode 100644 src/xccl/ISHMEMSymmetricMemory.cpp create mode 100644 src/xccl/Signal.cpp diff --git a/src/xccl/CMakeLists.txt b/src/xccl/CMakeLists.txt index 74be949cc8..0bd675511d 100644 --- a/src/xccl/CMakeLists.txt +++ b/src/xccl/CMakeLists.txt @@ -11,9 +11,11 @@ file(GLOB xccl_h "*.hpp") file(GLOB xccl_cpp "*.cpp") list(REMOVE_ITEM xccl_cpp "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck_XPU.cpp") +list(REMOVE_ITEM xccl_cpp "${CMAKE_CURRENT_SOURCE_DIR}/Signal.cpp") list(APPEND ATen_XPU_XCCL_SRCS ${xccl_cpp}) list(APPEND ATen_XPU_SYCL_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck_XPU.cpp") 
+list(APPEND ATen_XPU_SYCL_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/Signal.cpp") set(ATen_XPU_XCCL_SRCS ${ATen_XPU_XCCL_SRCS} PARENT_SCOPE) set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE) diff --git a/src/xccl/ISHMEMSymmetricMemory.cpp b/src/xccl/ISHMEMSymmetricMemory.cpp new file mode 100644 index 0000000000..9ea6a67918 --- /dev/null +++ b/src/xccl/ISHMEMSymmetricMemory.cpp @@ -0,0 +1,450 @@ +#include +#include "XPUSymmetricMemoryUtils.hpp" + +#include +#include +#include +#include +// Include ISHMEM headers - directly link to static library +#include +#include + +namespace c10d { +namespace symmetric_memory { + +/* Start of ISHMEMSymmetricMemory implementation */ + +// XPU-specific constants for symmetric memory +// Intel Data Center GPU Max can support up to 8 GPUs in a single node +constexpr int max_xpu_p2p_domain_size = 8; +// Maximum number of channels (same as CUDA) +constexpr int xpu_symm_max_nblocks = 32; +// Signal pad size for XPU +constexpr size_t xpu_signal_pad_size = + xpu_symm_max_nblocks * max_xpu_p2p_domain_size * sizeof(uint32_t); + +static StoreExchange storeExchange = StoreExchange("ISHMEMSymmetricMemory"); + +struct ISHMEMAllocation { + void* ptr; + size_t buffer_size; + int device_idx; + + ISHMEMAllocation(void* ptr, size_t buffer_size, int device_idx) + : ptr(ptr), buffer_size(buffer_size), device_idx(device_idx) {} + + ~ISHMEMAllocation() { + // Avoid calling XPU functions after driver shutting down + if (is_finalizing()) { + return; + } + c10::OptionalDeviceGuard guard; + guard.reset_device(at::Device(at::DeviceType::XPU, device_idx)); + ishmem_free(ptr); + } +}; + +// A class to hold the base pointers and signal pad pointers for a group of +// peers. One `ISHMEMPeerAllocInfo` object can be shared by multiple +// `ISHMEMSymmetricMemory` objects when latter reside on the same allocation +// and rendezvous over the same group. (The `ISHMEMSymmetricMemory` objects may +// have different offsets compared to the base address.) 
+class ISHMEMPeerAllocInfo : public c10::intrusive_ptr_target { + public: + ISHMEMPeerAllocInfo( + ISHMEMAllocation* allocation, + const std::string& group_name) + : base_ptr_(allocation->ptr), buffer_size_(allocation->buffer_size) { + // For logging only + static int exchanged_n_times = 0; + + c10::OptionalDeviceGuard guard; + guard.reset_device(at::Device(at::DeviceType::XPU, allocation->device_idx)); + + auto global_rank = get_group_info("0").rank; + GroupInfo& group_info = get_group_info(group_name); + auto store = group_info.store; + rank_ = group_info.rank; + world_size_ = group_info.world_size; + // Exchange rank to global rank mapping for this group. + // If it is already available, skip the exchange. + if (group_info.rank_to_global_rank.empty()) { + group_info.rank_to_global_rank = + storeExchange.all_gather(store, rank_, world_size_, global_rank); + exchanged_n_times++; + if (rank_ == 0) { + LOG(INFO) << "[rank " << rank_ << ']' + << " rank_to_global_rank: " << group_info.rank_to_global_rank + << ", group_name: " << group_name + << ", exchanged_n_times: " << exchanged_n_times; + } + } + TORCH_INTERNAL_ASSERT(!group_info.rank_to_global_rank.empty()); + rank_to_global_rank_ = group_info.rank_to_global_rank; + + world_within_xpu_p2p_ = true; + for (int r = 0; r < world_size_; ++r) { + auto peer_ptr = ishmem_ptr(base_ptr_, rank_to_global_rank_[r]); + buffers_.push_back(peer_ptr); + // If a peer is over network, `ishmem_ptr` returns null + if (peer_ptr == nullptr) { + world_within_xpu_p2p_ = false; + } + } + + // TODO: use the same allocation for signal pad + void* signal_pad_ptr = ishmem_malloc(xpu_signal_pad_size); + TORCH_CHECK(signal_pad_ptr != nullptr, "ishmem_malloc failed"); + + // Use SYCL queue to initialize signal pad memory + auto& queue = at::xpu::getCurrentSYCLQueue(); + queue.memset(signal_pad_ptr, 0, xpu_signal_pad_size).wait(); + + for (int r = 0; r < world_size_; ++r) { + signal_pads_.push_back( + ishmem_ptr(signal_pad_ptr, 
rank_to_global_rank_[r])); + } + + const size_t arr_size = sizeof(void*) * world_size_; + buffers_dev_ = reinterpret_cast( + c10::xpu::XPUCachingAllocator::raw_alloc(arr_size)); + signal_pads_dev_ = reinterpret_cast( + c10::xpu::XPUCachingAllocator::raw_alloc(arr_size)); + + queue.memcpy(buffers_dev_, buffers_.data(), arr_size).wait(); + queue.memcpy(signal_pads_dev_, signal_pads_.data(), arr_size).wait(); + + rank_to_global_rank_dev_ = reinterpret_cast( + c10::xpu::XPUCachingAllocator::raw_alloc(sizeof(int) * world_size_)); + queue + .memcpy( + rank_to_global_rank_dev_, + rank_to_global_rank_.data(), + sizeof(int) * world_size_) + .wait(); + } + + private: + void* base_ptr_; + size_t buffer_size_; + int rank_; + int world_size_; + std::vector buffers_; + std::vector signal_pads_; + void** buffers_dev_; + void** signal_pads_dev_; + std::vector rank_to_global_rank_; + int* rank_to_global_rank_dev_; + // Whether the world is within XPU P2P only, not network + bool world_within_xpu_p2p_; + + friend class ISHMEMSymmetricMemory; +}; + +class ISHMEMSymmetricMemory : public SymmetricMemory { + public: + ISHMEMSymmetricMemory( + ISHMEMAllocation* allocation, + const std::string& group_name) + : device_idx_(allocation->device_idx), group_name_(group_name) { + // A handle stores two types of info: + // (i) allocation's base ptrs and base signal pads, ours and peers' + pai_ = c10::make_intrusive(allocation, group_name); + // (ii) offset of tensor compared to base ptr (in byte) + offset_ = 0; + } + + // Exact copy is not needed / supported + ISHMEMSymmetricMemory(const ISHMEMSymmetricMemory& other) = delete; + + // Copy with offset is allowed + // This is mostly a shallow copy that shares the pointer to + // `ISHMEMPeerAllocInfo` which has been created by `other` + ISHMEMSymmetricMemory(const ISHMEMSymmetricMemory& other, size_t offset) + : device_idx_(other.device_idx_), + group_name_(other.group_name_), + pai_(other.pai_) { + offset_ = offset; + } + + 
~ISHMEMSymmetricMemory() override{ + // TODO + }; + + std::vector get_buffer_ptrs() override { + return pai_->buffers_; + } + + std::vector get_signal_pad_ptrs() override { + return pai_->signal_pads_; + } + + void** get_buffer_ptrs_dev() override { + return pai_->buffers_dev_; + } + + void** get_signal_pad_ptrs_dev() override { + return pai_->signal_pads_dev_; + } + + size_t get_buffer_size() override { + return pai_->buffer_size_; + } + + size_t get_signal_pad_size() override { + return xpu_signal_pad_size; + }; + + bool has_multicast_support() override { + // ISHMEM does not have multicast support + return false; + } + + void* get_multicast_ptr() override { + // ISHMEM does not have multicast support + return nullptr; + } + + size_t get_offset() override { + return offset_; + } + + void barrier(int channel, size_t timeout_ms) override { + // Use ISHMEM barrier + ishmem_barrier_all(); + } + + void put_signal(int dst_rank, int channel, size_t timeout_ms) override { + // TODO: Implement signal mechanism for ISHMEM + // ISHMEM uses different signaling approach than NVSHMEM + } + + void wait_signal(int src_rank, int channel, size_t timeout_ms) override { + // TODO: Implement signal mechanism for ISHMEM + } + + int get_rank() override { + return pai_->rank_; + } + + int get_world_size() override { + return pai_->world_size_; + } + + c10::Device get_device() override { + return c10::Device(c10::DeviceType::XPU, device_idx_); + } + + const std::vector& get_rank_to_global_rank() override { + return pai_->rank_to_global_rank_; + }; + + int* get_rank_to_global_rank_dev() override { + return pai_->rank_to_global_rank_dev_; + }; + + bool world_within_direct_access() override { + return pai_->world_within_xpu_p2p_; + } + + private: + int device_idx_; + std::string group_name_; + c10::intrusive_ptr pai_; + size_t offset_{0}; // in byte +}; + +static void initialize_ishmem_with_store( + c10::intrusive_ptr store, + int rank, + int world_size, + int device_idx) { + static bool 
is_initialized = false; + if (is_initialized) { + return; + } + + c10::OptionalDeviceGuard guard; + guard.reset_device(at::Device(at::DeviceType::XPU, device_idx)); + + ishmemx_uniqueid_t unique_id; + if (rank == 0) { + // Root rank generates the unique ID + int ret = ishmemx_get_uniqueid(&unique_id); + TORCH_CHECK(ret == 0, "ishmemx_get_uniqueid failed with error: ", ret); + } + + auto unique_ids = + storeExchange.all_gather(store, rank, world_size, unique_id); + + // Initialize ISHMEM with attributes using unique ID + ishmemx_attr_t attr; + attr.initialize_runtime = false; // MPI/OpenSHMEM backend already initialized + attr.use_uid = true; + attr.nranks = world_size; + attr.uid = &unique_ids[0]; + + // ishmemx_init_attr returns void, not int + ishmemx_init_attr(&attr); + // Verify initialization succeeded by checking PE info + TORCH_CHECK( + ishmem_my_pe() == rank, + "ISHMEM initialization failed: rank mismatch, expected ", + rank, + " got ", + ishmem_my_pe()); + + is_initialized = true; + + // Print version + int major, minor; + ishmem_info_get_version(&major, &minor); + LOG(INFO) << "ISHMEM initialized with unique ID, version: " << major << '.' 
+ << minor << ", rank: " << rank << "/" << world_size; +} + +class ISHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { + public: + void* alloc( + size_t size, + int device_idx, + const std::optional& group_name) override { + // Note: group_name may be passed but is ignored for ISHMEM allocations + // ISHMEM uses group "0" for all allocations + c10::OptionalDeviceGuard guard; + guard.reset_device(at::Device(at::DeviceType::XPU, device_idx)); + + auto group_info = get_group_info("0"); + auto store = group_info.store; + int rank = group_info.rank; + int world_size = group_info.world_size; + + initialize_ishmem_with_store(store, rank, world_size, device_idx); + auto ptr = ishmem_malloc(size); + // If size is 0 (which is legal allocation request) we shouldn't error out + TORCH_CHECK(ptr != nullptr || size == 0, "ishmem_malloc failed"); + // TODO: thread safety + allocations_.try_emplace( + ptr, std::make_unique(ptr, size, device_idx)); + return ptr; + } + + void free(void* ptr) override { + // TODO: thread safety + allocations_.erase(ptr); + }; + + size_t get_alloc_size(void* ptr) override { + auto it = allocations_.find(ptr); + if (it == allocations_.end()) { + TORCH_CHECK( + false, ptr, " is not allocated with ISHMEMSymmetricMemoryAllocator"); + } + return it->second->buffer_size; + }; + + c10::intrusive_ptr rendezvous( + void* ptr, + const std::optional& group_name) override { + // Use WORLD group (name "0") if group_name is not provided + std::string actual_group_name; + if (group_name.has_value()) { + actual_group_name = *group_name; + } else { + // Default to group "0" (WORLD) for ISHMEM + actual_group_name = "0"; + } + + { + auto it = symm_mems_.find(std::make_tuple(ptr, actual_group_name)); + if (it != symm_mems_.end()) { + return it->second; + } + } + // In case of MemPool, tensor.storage().data_ptr() may not match + // exactly an allocation's base address. Thus we perform the search by + // testing if the former is within an allocation's range. 
+ auto alloc_it = std::find_if( + allocations_.begin(), allocations_.end(), [&](const auto& pair) { + auto& allocation = pair.second; + auto ptr_int = reinterpret_cast(ptr); + auto base_ptr = reinterpret_cast(allocation->ptr); + return ptr_int >= base_ptr && + ptr_int < base_ptr + allocation->buffer_size; + }); + TORCH_CHECK( + alloc_it != allocations_.end(), + "Pointer not within any SymmetricMemory allocation, " + "is the tensor allocated from SymmetricMemory?"); + + auto& allocation = alloc_it->second; + + // Search again using allocation base ptr (which is the key we use for + // caching, see below) + auto it = + symm_mems_.find(std::make_tuple(allocation->ptr, actual_group_name)); + c10::intrusive_ptr symm_mem; + if (it != symm_mems_.end()) { + // Base allocation has been rendezvoused + symm_mem = it->second; + } else { + // Create a new rendezvous + symm_mem = c10::make_intrusive( + allocation.get(), actual_group_name); + } + + // Cache rendezvous using allocation's base address as key + symm_mems_[std::make_tuple(allocation->ptr, actual_group_name)] = symm_mem; + + // TODO: change the `ptr` below to `tensor.data_ptr()` when adding support + // for user slice/view operations. For MemPool support, + // `tensor.storate().data_ptr()` is fine (today's `ptr`). + + // If the tensor's ptr happen to be the same as allocation ptr + if (ptr == allocation->ptr) { + return symm_mem; + } else { + // Return a copy of the SymmetricMemory with an offset. This is a + // "shallow" copy adjusting the offset field in the handle. 
+ return c10::make_intrusive( + *symm_mem, (uintptr_t)ptr - (uintptr_t)allocation->ptr); + } + }; + + bool has_multicast_support(int device_idx) override { + // ISHMEM does not have multicast support + return false; + }; + + c10::DeviceType supported_device_type() override { + return c10::DeviceType::XPU; + } + + std::string name() override { + return "ISHMEM"; + } + + private: + std::unordered_map> allocations_; + std::map< + std::tuple, + c10::intrusive_ptr> + symm_mems_; +}; + +struct RegisterISHMEMSymmetricMemoryAllocator { + RegisterISHMEMSymmetricMemoryAllocator() { + auto allocator = c10::make_intrusive(); + // Always register availability to support dynamic backend switching + register_availability("ISHMEM", allocator); + // If this is the preferred backend, also set it as default + if (getSymmMemBackendXPU() == "ISHMEM") { + register_allocator(c10::DeviceType::XPU, allocator); + } + } +}; + +static RegisterISHMEMSymmetricMemoryAllocator register_allocator_; + +} // namespace symmetric_memory +} // namespace c10d diff --git a/src/xccl/IpcExchange.hpp b/src/xccl/IpcExchange.hpp index e515cd6ce0..600aea92c5 100644 --- a/src/xccl/IpcExchange.hpp +++ b/src/xccl/IpcExchange.hpp @@ -287,9 +287,9 @@ void un_allgather( recv_buf[rank] = *send_buf; } -class IpcChannel { +class IpcChannels { public: - IpcChannel() { + IpcChannels() { initialized = false; } void init(sycl::queue& queue, uint32_t rank_in, uint32_t world_in) { diff --git a/src/xccl/Signal.cpp b/src/xccl/Signal.cpp new file mode 100644 index 0000000000..09ad623423 --- /dev/null +++ b/src/xccl/Signal.cpp @@ -0,0 +1,201 @@ +#include +#include +#include +#include + +namespace c10d::symmetric_memory { + +struct barrierKernel { + void operator()(sycl::nd_item<1> item) const { + auto thread_id = item.get_local_id(0); + + if (thread_id < world_size) { + auto target_rank = thread_id; + if (target_rank == rank) { + return; + } + auto put_success = try_put_signal_device( + signal_pads[target_rank] + world_size * 
channel + rank, 10000000); + // if (!put_success) { + // assert(0); + // } + + auto wait_success = try_wait_signal_device( + signal_pads[rank] + world_size * channel + target_rank, 10000000); + // if (!wait_success) { + // assert(0); + // } + } + } + + barrierKernel( + uint32_t** signal_pads, + int channel, + int rank, + int world_size, + size_t timeout_ms) + : signal_pads(signal_pads), + channel(channel), + rank(rank), + world_size(world_size), + timeout_ms(timeout_ms) {} + + private: + uint32_t** signal_pads; + int channel; + int rank; + int world_size; + size_t timeout_ms; +}; + +void barrier_impl_xpu( + uint32_t** signal_pads, + int channel, + int rank, + int world_size, + size_t timeout_ms, + at::xpu::XPUStream& stream) { + int64_t maxNumThreadsPerBlock = syclMaxWorkGroupSize(); + const size_t numThreadsPerBlock = + std::min(maxNumThreadsPerBlock, std::max(32, world_size)); + + if (!(numThreadsPerBlock > 0)) { + return; + } + int64_t numBlocks = 1; + auto global_range = numBlocks * numThreadsPerBlock; + auto local_range = numThreadsPerBlock; + + using Kernel = barrierKernel; + auto kfn = Kernel(signal_pads, channel, rank, world_size, timeout_ms); + + sycl_kernel_submit(global_range, local_range, stream.queue(), kfn); +} + +struct putSignalKernel { + void operator()(sycl::nd_item<1> item) const { + auto thread_id = item.get_local_id(0); + + if (thread_id == 0) { + auto put_success = try_put_signal_device( + signal_pads[dst_rank] + world_size * channel + rank, 10000000); + // if (!put_success) { + // assert(0); + // } + } + } + + putSignalKernel( + uint32_t** signal_pads, + int dst_rank, + int channel, + int rank, + int world_size, + size_t timeout_ms) + : signal_pads(signal_pads), + dst_rank(dst_rank), + channel(channel), + rank(rank), + world_size(world_size), + timeout_ms(timeout_ms) {} + + private: + uint32_t** signal_pads; + int dst_rank; + int channel; + int rank; + int world_size; + size_t timeout_ms; +}; + +void put_signal_impl_xpu( + uint32_t** 
signal_pads, + int dst_rank, + int channel, + int rank, + int world_size, + size_t timeout_ms, + at::xpu::XPUStream& stream) { + int64_t maxNumThreadsPerBlock = syclMaxWorkGroupSize(); + const size_t numThreadsPerBlock = std::min(maxNumThreadsPerBlock, 32); + + if (!(numThreadsPerBlock > 0)) { + return; + } + + int64_t numBlocks = 1; + auto global_range = numBlocks * numThreadsPerBlock; + auto local_range = numThreadsPerBlock; + + using Kernel = putSignalKernel; + auto kfn = + Kernel(signal_pads, dst_rank, channel, rank, world_size, timeout_ms); + + sycl_kernel_submit(global_range, local_range, stream.queue(), kfn); +} + +struct waitSignalKernel { + void operator()(sycl::nd_item<1> item) const { + auto thread_id = item.get_local_id(0); + + if (thread_id == 0) { + auto wait_success = try_wait_signal_device( + signal_pads[rank] + world_size * channel + src_rank, 10000000); + // if (!wait_success) { + // assert(0); + // } + + sycl::atomic_fence(sycl::memory_order_seq_cst, sycl::memory_scope_system); + } + } + + waitSignalKernel( + uint32_t** signal_pads, + int src_rank, + int channel, + int rank, + int world_size, + size_t timeout_ms) + : signal_pads(signal_pads), + src_rank(src_rank), + channel(channel), + rank(rank), + world_size(world_size), + timeout_ms(timeout_ms) {} + + private: + uint32_t** signal_pads; + int src_rank; + int channel; + int rank; + int world_size; + size_t timeout_ms; +}; + +void wait_signal_impl_xpu( + uint32_t** signal_pads, + int src_rank, + int channel, + int rank, + int world_size, + size_t timeout_ms, + at::xpu::XPUStream& stream) { + int64_t maxNumThreadsPerBlock = syclMaxWorkGroupSize(); + const size_t numThreadsPerBlock = std::min(maxNumThreadsPerBlock, 32); + + if (!(numThreadsPerBlock > 0)) { + return; + } + + int64_t numBlocks = 1; + auto global_range = numBlocks * numThreadsPerBlock; + auto local_range = numThreadsPerBlock; + + using Kernel = waitSignalKernel; + auto kfn = + Kernel(signal_pads, src_rank, channel, rank, world_size, 
timeout_ms); + + sycl_kernel_submit(global_range, local_range, stream.queue(), kfn); +} + +} // namespace c10d::symmetric_memory diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index d49d126122..9eadcc4fdd 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -28,7 +29,7 @@ AllocationRef::AllocationRef( handle(handle), block_size(block_size), device_idx(device_idx), - local_allocation(local_allocation){} + local_allocation(local_allocation) {} AllocationRef::~AllocationRef() { if (is_finalizing()) { @@ -36,7 +37,7 @@ AllocationRef::~AllocationRef() { } // Currently, we cannot free virtual memory exchanged from other device. if (!local_allocation) { - return; + return; } c10::Device local_device(c10::DeviceType::XPU, device_idx); c10::DeviceGuard guard(local_device); @@ -163,59 +164,98 @@ void check_channel(int channel, int world_size) { void XPUSymmetricMemory::barrier(int channel, size_t timeout_ms) { check_channel(channel, world_size_); - // Currently, we leverage oneCCL for barrier. Later, we may move to SYCL - // implementation. 
- auto group = c10d::resolve_process_group(group_name_); - if (group == nullptr) { - TORCH_WARN( - "Process group '", - group_name_, - "' not found, please init process group first before calling SymmetricMemory"); - throw std::runtime_error("Process group not found"); - } - auto* xcclPg = dynamic_cast( - group->getBackend(c10::DeviceType::XPU).get()); - c10::Device local_device(c10::DeviceType::XPU, local_device_idx_); c10::DeviceGuard guard(local_device); + auto stream = at::xpu::getCurrentXPUStream(); - static thread_local at::Tensor barrier_tensor; - if (!barrier_tensor.defined() || barrier_tensor.device() != local_device) { - barrier_tensor = at::zeros( - {1}, at::TensorOptions().device(local_device).dtype(at::kFloat)); - } else { - barrier_tensor.zero_(); - } - - c10d::AllreduceOptions arOpts; - arOpts.asyncOp = false; - auto work = - xcclPg->allreduce_impl(barrier_tensor, "xccl:symm_mem_barrier", arOpts); - - if (work) { - bool success = work->wait(std::chrono::milliseconds(timeout_ms)); - TORCH_CHECK( - success, - "Barrier timeout after ", - timeout_ms, - " ms for group '", - group_name_, - "'"); - } + barrier_impl_xpu( + reinterpret_cast(signal_pads_dev_), + channel, + rank_, + world_size_, + timeout_ms, + stream); + // // Currently, we leverage oneCCL for barrier. Later, we may move to SYCL + // // implementation. 
+ // auto group = c10d::resolve_process_group(group_name_); + // if (group == nullptr) { + // TORCH_WARN( + // "Process group '", + // group_name_, + // "' not found, please init process group first before calling + // SymmetricMemory"); + // throw std::runtime_error("Process group not found"); + // } + // auto* xcclPg = dynamic_cast( + // group->getBackend(c10::DeviceType::XPU).get()); + + // c10::Device local_device(c10::DeviceType::XPU, local_device_idx_); + // c10::DeviceGuard guard(local_device); + + // static thread_local at::Tensor barrier_tensor; + // if (!barrier_tensor.defined() || barrier_tensor.device() != local_device) { + // barrier_tensor = at::zeros( + // {1}, at::TensorOptions().device(local_device).dtype(at::kFloat)); + // } else { + // barrier_tensor.zero_(); + // } + + // c10d::AllreduceOptions arOpts; + // arOpts.asyncOp = false; + // auto work = + // xcclPg->allreduce_impl(barrier_tensor, "xccl:symm_mem_barrier", + // arOpts); + + // if (work) { + // bool success = work->wait(std::chrono::milliseconds(timeout_ms)); + // TORCH_CHECK( + // success, + // "Barrier timeout after ", + // timeout_ms, + // " ms for group '", + // group_name_, + // "'"); + // } } void XPUSymmetricMemory::put_signal( int dst_rank, int channel, size_t timeout_ms) { - LOG(ERROR) << "XPUSymmetricMemory::put_signal not supported"; + check_channel(channel, world_size_); + + c10::Device local_device(c10::DeviceType::XPU, local_device_idx_); + c10::DeviceGuard guard(local_device); + auto stream = at::xpu::getCurrentXPUStream(); + + put_signal_impl_xpu( + reinterpret_cast(signal_pads_dev_), + dst_rank, + channel, + rank_, + world_size_, + timeout_ms, + stream); } void XPUSymmetricMemory::wait_signal( int src_rank, int channel, size_t timeout_ms) { - LOG(ERROR) << "XPUSymmetricMemory::wait_signal not supported"; + check_channel(channel, world_size_); + + c10::Device local_device(c10::DeviceType::XPU, local_device_idx_); + c10::DeviceGuard guard(local_device); + auto stream = 
at::xpu::getCurrentXPUStream(); + + wait_signal_impl_xpu( + reinterpret_cast(signal_pads_dev_), + src_rank, + channel, + rank_, + world_size_, + timeout_ms, + stream); } int XPUSymmetricMemory::get_rank() { @@ -254,8 +294,8 @@ void* XPUSymmetricMemoryAllocator::alloc( sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); void* ptr = sycl::malloc_device(block_size, current_queue); current_queue.memset(ptr, 0, block_size); - auto alloc_ref = - c10::make_intrusive(ptr, ptr, block_size, device_idx, true); + auto alloc_ref = c10::make_intrusive( + ptr, ptr, block_size, device_idx, true); auto block = c10::make_intrusive( std::move(alloc_ref), device_idx, @@ -345,8 +385,8 @@ c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( c10::Device local_device(c10::DeviceType::XPU, block->device_idx); c10::DeviceGuard guard(local_device); - // IpcChannel is used to do inter-process communication - IpcChannel ipc_channel; + // IpcChannels is used to do inter-process communication + IpcChannels ipc_channel; auto group_info = get_group_info(group_name_); auto store = group_info.store; int rank = group_info.rank; @@ -444,13 +484,11 @@ c10::intrusive_ptr XPUSymmetricMemoryAllocator::find_block(void* ptr) { struct RegisterXPUSymmetricMemoryAllocator { RegisterXPUSymmetricMemoryAllocator() { auto allocator = c10::make_intrusive(); - // Query backend used for XPU + // Always register availability to support dynamic backend switching + register_availability("XPU", allocator); + // If this is the preferred backend, also set it as default if (getSymmMemBackendXPU() == "XPU") { - // Direct set (static registration) register_allocator(c10::DeviceType::XPU, allocator); - } else { - // Register availability in case `set_backend` is called dynamically - register_availability("XPU", allocator); } } }; diff --git a/src/xccl/XPUSymmetricMemoryUtils.cpp b/src/xccl/XPUSymmetricMemoryUtils.cpp index 7130fe7b6a..78b3892155 100644 --- a/src/xccl/XPUSymmetricMemoryUtils.cpp +++ 
b/src/xccl/XPUSymmetricMemoryUtils.cpp @@ -1,3 +1,5 @@ +#include + #include #include #include @@ -5,72 +7,191 @@ #include -#include -#include -#include -#include -#include - namespace c10d::symmetric_memory { +// Query environment variable to get the backend used for XPU Symmetric Memory. std::string getSymmMemBackendXPU() { + // TORCH_SYMMMEM environment variable can be used to indicate the preferred + // backend. static auto val = c10::utils::get_env("TORCH_SYMMMEM"); if (val.has_value()) { TORCH_CHECK( - val.value() == "XPU", - "TORCH_SYMMMEM environment variable must be 'XPU'."); + val.value() == "XPU" || val.value() == "ISHMEM", + "TORCH_SYMMMEM environment variable must be one of 'XPU', 'ISHMEM'.") return val.value(); } return "XPU"; } -bool device_has_multicast_support(int device_idx) { - return false; +IpcChannel::IpcChannel() + : socket_name_(get_socket_name(getpid())), + socket_(socket(AF_UNIX, SOCK_DGRAM, 0)) { + // On success, a file descriptor for the new socket is returned. + // On error, -1 is returned, and errno is set to indicate the error. + TORCH_CHECK( + socket_ != -1, "Failed to create socket: ", c10::utils::str_error(errno)); + + struct sockaddr_un addr = {.sun_family = AF_UNIX}; + std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path); + + TORCH_CHECK( + bind(socket_, (struct sockaddr*)&addr, SUN_LEN(&addr)) == 0, + "Failed to bind socket: ", + c10::utils::str_error(errno)); } -bool allow_overlapping_devices() { - return false; +IpcChannel::~IpcChannel() { + close(socket_); + unlink(socket_name_.c_str()); } -void map_block( - void** ptr, - ze_physical_mem_handle_t handle, - size_t size, - int device_idx) { - sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); - sycl::context sycl_ctx = current_queue.get_context(); - ze_context_handle_t ze_context = - sycl::get_native(sycl_ctx); - // 1. 
Reserve virtual address space - void* virtual_ptr = nullptr; - ze_result_t status = zeVirtualMemReserve( - ze_context, // context - nullptr, // let L0 pick virtual address - size, // size - &virtual_ptr // out: reserved address - ); - TORCH_CHECK(status == ZE_RESULT_SUCCESS, "zeVirtualMemReserve failed"); - - // 2. Map physical memory to virtual address - status = zeVirtualMemMap( - ze_context, - virtual_ptr, // virtual memory to map to - size, - handle, // physical memory handle - 0, // flags - ZE_MEMORY_ACCESS_ATTRIBUTE_READWRITE // ze_memory_access_attribute_t - ); - TORCH_CHECK(status == ZE_RESULT_SUCCESS, "zeVirtualMemMap failed"); - - // 3. Set access attributes - ze_memory_access_attribute_t access = ZE_MEMORY_ACCESS_ATTRIBUTE_READWRITE; - status = - zeVirtualMemSetAccessAttribute(ze_context, virtual_ptr, size, access); +void IpcChannel::send_fd(int dst_pid, int fd) { + // Because file descriptors are process-local kernel objects, and we can’t + // pass them via normal socket payloads (like write() or send()). Unix domain + // sockets provide a mechanism to pass actual FDs via sendmsg()/recvmsg(). 
+ // Define destination socket address + struct sockaddr_un addr = {.sun_family = AF_UNIX}; + auto socket_name = get_socket_name(dst_pid); + std::copy(socket_name.begin(), socket_name.end(), addr.sun_path); + + // Prepare data to send + // Data being sent is "fd", the value of fd will be sent as auxiliary data + // (control message) + struct iovec io = {.iov_base = (void*)("fd"), .iov_len = 2}; + + // Prepare control message data buffer and zero it out + // NOLINTNEXTLINE(*array*) + char cbuf[CMSG_SPACE(sizeof(int))]; + memset(cbuf, 0, sizeof(cbuf)); + + // Create message header + struct msghdr msg { + // destination socket address and size of it + // message content in msg_iov and number of such structs (1 in our case) + // auxiliary data with the value of fd and size of it + .msg_name = (void*)&addr, .msg_namelen = sizeof(struct sockaddr_un), + .msg_iov = &io, .msg_iovlen = 1, .msg_control = cbuf, + .msg_controllen = sizeof(cbuf) + }; + + // This points to the first control message header + // With SCM_RIGHTS we let the kernel know that we are passing file + // descriptors. 
+ auto cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_len = CMSG_LEN(sizeof(int)); + // Specify socket level message + cmsg->cmsg_level = SOL_SOCKET; + // SCM_RIGHTS is the type used to pass file descriptors + cmsg->cmsg_type = SCM_RIGHTS; + + if (fd != -1) { + std::copy( + reinterpret_cast(&fd), + reinterpret_cast(&fd) + sizeof(fd), + reinterpret_cast(CMSG_DATA(cmsg))); + } else { + msg.msg_controllen = 0; + } + + // Finally send the message TORCH_CHECK( - status == ZE_RESULT_SUCCESS, "zeVirtualMemSetAccessAttribute failed"); + sendmsg(socket_, &msg, 0) > 0, + "Failed to send fd: ", + c10::utils::str_error(errno)); +} + +int IpcChannel::recv_fd() { + // Prepare buffer for regular message "fd" + // NOLINTNEXTLINE(*array*) + char buf[2]; + memset(&buf, 0, sizeof(buf)); + struct iovec io = {.iov_base = (void*)buf, .iov_len = sizeof(buf)}; + + // Prepare buffer for control message and zero it out + // NOLINTNEXTLINE(*array*) + char cbuf[CMSG_SPACE(sizeof(int))]; + memset(cbuf, 0, sizeof(cbuf)); - // 4. 
Return pointer - *ptr = virtual_ptr; + // Define socket address to receive on: family AF_UNIX means unix domain + // socket + struct sockaddr_un addr = {.sun_family = AF_UNIX}; + std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path); + + // Prepare message header + struct msghdr msg = { + .msg_name = (void*)&addr, + .msg_namelen = sizeof(struct sockaddr_un), + .msg_iov = &io, + .msg_iovlen = 1, + .msg_control = cbuf, + .msg_controllen = sizeof(cbuf)}; + + // Receive message on socket_ + TORCH_CHECK( + recvmsg(socket_, &msg, 0) > 0, + "Failed to receive fd: ", + c10::utils::str_error(errno)); + + if (msg.msg_controllen == 0) { + return -1; + } + + // Extract control message and validate its content + auto cmsg = CMSG_FIRSTHDR(&msg); + TORCH_CHECK(cmsg != nullptr); + TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int))); + TORCH_CHECK(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS); + return *reinterpret_cast(CMSG_DATA(cmsg)); +} + +std::vector IpcChannel::all_gather_fds( + int rank, + const std::vector& pids, + int fd) { + int world_size = static_cast(pids.size()); + std::vector fds(pids.size()); + fds[rank] = fd; + + int dst_rank = (rank + 1) % world_size; + for (int step = 1; step < world_size; ++step) { + int src_rank = (rank + world_size - step) % world_size; + send_fd(pids[dst_rank], fd); + fd = recv_fd(); + fds[src_rank] = fd; + } + return fds; +} + +int IpcChannel::broadcast_fds( + int rank, + int src_rank, + const std::vector& pids, + int fd) { + int world_size = static_cast(pids.size()); + + if (rank == src_rank) { + for (int dst_rank = 0; dst_rank < world_size; ++dst_rank) { + if (dst_rank == rank) { + continue; + } + send_fd(pids[dst_rank], fd); + } + return fd; + } + return recv_fd(); +} + +std::string IpcChannel::get_socket_name(int pid) { + const char* tmp_dir = "/tmp"; + for (const char* env_var : {"TMPDIR", "TMP", "TEMP", "TEMPDIR"}) { + if (const char* path = getenv(env_var)) { + tmp_dir = path; + break; + } + } + 
std::ostringstream oss; + oss << tmp_dir << "/symm_mem-" << pid; + return oss.str(); } } // namespace c10d::symmetric_memory diff --git a/src/xccl/XPUSymmetricMemoryUtils.hpp b/src/xccl/XPUSymmetricMemoryUtils.hpp index 69189f45cf..36204d2f7d 100644 --- a/src/xccl/XPUSymmetricMemoryUtils.hpp +++ b/src/xccl/XPUSymmetricMemoryUtils.hpp @@ -1,20 +1,55 @@ #pragma once + #include #include #include +// A set of store-based exchange methods with a preset prefix typically type of +// the SymmetricMemory. Most used as static instances at respective +// SymmetricMemory implementation files. +#include +#include +#include +#include + +#include + +#include namespace c10d { namespace symmetric_memory { std::string getSymmMemBackendXPU(); -bool device_has_multicast_support(int device_idx); +// bool device_has_multicast_support(int device_idx); -bool allow_overlapping_devices(); +// bool allow_overlapping_devices(); + +class IpcChannel { + public: + IpcChannel(); + ~IpcChannel(); + + void send_fd(int dst_pid, int fd); + int recv_fd(); + + std::vector all_gather_fds( + int rank, + const std::vector& pids, + int fd); + + int broadcast_fds( + int rank, + int src_rank, + const std::vector& pids, + int fd); + + private: + static std::string get_socket_name(int pid); + + std::string socket_name_; + int socket_; +}; -// A set of store-based exchange methods with a preset prefix typically type of -// the SymmetricMemory. Most used as static instances at respective -// SymmetricMemory implementation files. class StoreExchange { public: StoreExchange(const std::string& store_prefix) @@ -79,11 +114,11 @@ class StoreExchange { // held by the handle. // todo: will follow such physical memory handle map with virtual address, // when L0 provides physical handle exchange API and we have multicast support. 
-void map_block( - void** ptr, - ze_physical_mem_handle_t handle, - size_t size, - int device_idx); +// void map_block( +// void** ptr, +// ze_physical_mem_handle_t handle, +// size_t size, +// int device_idx); } // namespace symmetric_memory } // namespace c10d From ac8f99b4fb79e968d23dc76f5fbed53cd85b516a Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Wed, 10 Dec 2025 16:32:44 +0800 Subject: [PATCH 04/25] unify IpcChannel and rm dynamic ze --- src/xccl/ISHMEMSymmetricMemory.cpp | 41 +++-- src/xccl/XPUSymmetricMemory.cpp | 58 ++++-- src/xccl/XPUSymmetricMemory.hpp | 13 ++ src/xccl/XPUSymmetricMemoryUtils.cpp | 33 +++- src/xccl/ze_symbol.hpp | 254 --------------------------- 5 files changed, 112 insertions(+), 287 deletions(-) delete mode 100644 src/xccl/ze_symbol.hpp diff --git a/src/xccl/ISHMEMSymmetricMemory.cpp b/src/xccl/ISHMEMSymmetricMemory.cpp index 9ea6a67918..2a20e0bf28 100644 --- a/src/xccl/ISHMEMSymmetricMemory.cpp +++ b/src/xccl/ISHMEMSymmetricMemory.cpp @@ -5,16 +5,13 @@ #include #include #include -// Include ISHMEM headers - directly link to static library + #include #include namespace c10d { namespace symmetric_memory { -/* Start of ISHMEMSymmetricMemory implementation */ - -// XPU-specific constants for symmetric memory // Intel Data Center GPU Max can support up to 8 GPUs in a single node constexpr int max_xpu_p2p_domain_size = 8; // Maximum number of channels (same as CUDA) @@ -34,7 +31,6 @@ struct ISHMEMAllocation { : ptr(ptr), buffer_size(buffer_size), device_idx(device_idx) {} ~ISHMEMAllocation() { - // Avoid calling XPU functions after driver shutting down if (is_finalizing()) { return; } @@ -96,7 +92,6 @@ class ISHMEMPeerAllocInfo : public c10::intrusive_ptr_target { void* signal_pad_ptr = ishmem_malloc(xpu_signal_pad_size); TORCH_CHECK(signal_pad_ptr != nullptr, "ishmem_malloc failed"); - // Use SYCL queue to initialize signal pad memory auto& queue = at::xpu::getCurrentSYCLQueue(); queue.memset(signal_pad_ptr, 0, 
xpu_signal_pad_size).wait(); @@ -196,12 +191,10 @@ class ISHMEMSymmetricMemory : public SymmetricMemory { }; bool has_multicast_support() override { - // ISHMEM does not have multicast support return false; } void* get_multicast_ptr() override { - // ISHMEM does not have multicast support return nullptr; } @@ -210,13 +203,11 @@ class ISHMEMSymmetricMemory : public SymmetricMemory { } void barrier(int channel, size_t timeout_ms) override { - // Use ISHMEM barrier ishmem_barrier_all(); } void put_signal(int dst_rank, int channel, size_t timeout_ms) override { // TODO: Implement signal mechanism for ISHMEM - // ISHMEM uses different signaling approach than NVSHMEM } void wait_signal(int src_rank, int channel, size_t timeout_ms) override { @@ -263,26 +254,44 @@ static void initialize_ishmem_with_store( if (is_initialized) { return; } - c10::OptionalDeviceGuard guard; guard.reset_device(at::Device(at::DeviceType::XPU, device_idx)); + // Generate unique ID - ONLY rank 0 should generate it ishmemx_uniqueid_t unique_id; + memset( + &unique_id, 0, sizeof(unique_id)); // Zero-initialize for non-root ranks + if (rank == 0) { + LOG(INFO) << "[ISHMEM Init] Rank 0 generating unique ID"; // Root rank generates the unique ID int ret = ishmemx_get_uniqueid(&unique_id); TORCH_CHECK(ret == 0, "ishmemx_get_uniqueid failed with error: ", ret); + LOG(INFO) << "[ISHMEM Init] Rank 0 unique ID generated"; } - auto unique_ids = - storeExchange.all_gather(store, rank, world_size, unique_id); + // All-gather to distribute rank 0's unique_id to all ranks + // This creates a vector where all elements should contain rank 0's unique_id + std::vector unique_ids; + LOG(INFO) << "[ISHMEM Init] Rank " << rank + << " about to all_gather unique_id"; + try { + unique_ids = storeExchange.all_gather(store, rank, world_size, unique_id); + LOG(INFO) << "[ISHMEM Init] Rank " << rank + << " all_gather completed, received " << unique_ids.size() + << " unique_ids"; + } catch (const std::exception& e) { + 
LOG(ERROR) << "[ISHMEM Init] Rank " << rank + << " all_gather failed: " << e.what(); + throw; + } - // Initialize ISHMEM with attributes using unique ID + // Initialize ISHMEM with attributes using unique ID from rank 0 ishmemx_attr_t attr; - attr.initialize_runtime = false; // MPI/OpenSHMEM backend already initialized + attr.initialize_runtime = false; attr.use_uid = true; attr.nranks = world_size; - attr.uid = &unique_ids[0]; + attr.uid = &unique_ids[0]; // Use rank 0's unique_id (first element) // ishmemx_init_attr returns void, not int ishmemx_init_attr(&attr); diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index 9eadcc4fdd..7bbf2407b4 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -1,6 +1,4 @@ -#include #include -#include #include #include @@ -385,8 +383,6 @@ c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( c10::Device local_device(c10::DeviceType::XPU, block->device_idx); c10::DeviceGuard guard(local_device); - // IpcChannels is used to do inter-process communication - IpcChannels ipc_channel; auto group_info = get_group_info(group_name_); auto store = group_info.store; int rank = group_info.rank; @@ -408,27 +404,61 @@ c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( pids[r] = reqs[r].pid; } - // do IPC exchange for all peer ranks - ipc_channel.exchange_peer_ipc_mem(current_queue, ptr, rank, world_size); + // Step 1: Get base address and offset + sycl::context ctx = current_queue.get_context(); + auto l0_ctx = sycl::get_native(ctx); + auto l0_device = sycl::get_native( + current_queue.get_device()); - // no physical memory handle, so handles and buffers are both for virtual - // address + void* base_addr; + size_t base_size; + ZE_CHECK(zeMemGetAddressRange(l0_ctx, ptr, &base_addr, &base_size)); + size_t offset = (char*)ptr - (char*)base_addr; + + // Step 2: Get IPC mem handle from base address + ze_ipc_mem_handle_t local_ipc_handle; + 
ZE_CHECK(zeMemGetIpcHandle(l0_ctx, base_addr, &local_ipc_handle)); + + // Step 3: Extract fd from IPC handle (ze_ipc_mem_handle_t's first field is + // fd) + int local_fd = *reinterpret_cast(&local_ipc_handle); + + // Step 4: Exchange offsets via store + auto offsets = storeExchange.all_gather(store, rank, world_size, offset); + + // Step 5: Exchange fds via IpcChannel (uses Unix domain socket + SCM_RIGHTS) + IpcChannel ipc_channel; + auto fds = ipc_channel.all_gather_fds(rank, pids, local_fd); + + // Step 6: Reconstruct remote IPC handles and open them std::vector handles(world_size); std::vector buffers(world_size, nullptr); std::vector signal_pads(world_size, nullptr); for (int r = 0; r < world_size; ++r) { if (r == rank) { - handles[r] = block->alloc_ref->handle; + handles[r] = base_addr; // Store base address as handle buffers[r] = ptr; signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); continue; - } else { - buffers[r] = ipc_channel.buffers[r]; - handles[r] = ipc_channel.buffers[r]; - signal_pads[r] = - (void*)((uintptr_t)buffers[r] + block->signal_pad_offset); } + + // Reconstruct remote IPC handle by setting the fd field + ze_ipc_mem_handle_t remote_ipc_handle = local_ipc_handle; // Copy structure + *reinterpret_cast(&remote_ipc_handle) = fds[r]; // Set remote fd + + // Open IPC handle to get remote base address + void* remote_base; + ZE_CHECK(zeMemOpenIpcHandle( + l0_ctx, + l0_device, + remote_ipc_handle, + ZE_IPC_MEMORY_FLAG_BIAS_CACHED, + &remote_base)); + + handles[r] = remote_base; // Store remote base address as handle + buffers[r] = (char*)remote_base + offsets[r]; + signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset); } storeExchange.barrier(store, rank, world_size); diff --git a/src/xccl/XPUSymmetricMemory.hpp b/src/xccl/XPUSymmetricMemory.hpp index 2daac1114a..4137a033d1 100644 --- a/src/xccl/XPUSymmetricMemory.hpp +++ b/src/xccl/XPUSymmetricMemory.hpp @@ -1,11 +1,24 @@ #pragma once +#include +#include + 
#include #include #include #include #include +#define ZE_CHECK(call) \ + do { \ + ze_result_t result = (call); \ + TORCH_CHECK( \ + result == ZE_RESULT_SUCCESS, \ + "Level Zero error: ", \ + #call, \ + " returned ", \ + result); \ + } while (0) namespace c10d::symmetric_memory { // Resource wrapper that owns a (vaddr, allocation handle) pair. Upon diff --git a/src/xccl/XPUSymmetricMemoryUtils.cpp b/src/xccl/XPUSymmetricMemoryUtils.cpp index 78b3892155..f518166361 100644 --- a/src/xccl/XPUSymmetricMemoryUtils.cpp +++ b/src/xccl/XPUSymmetricMemoryUtils.cpp @@ -93,10 +93,37 @@ void IpcChannel::send_fd(int dst_pid, int fd) { msg.msg_controllen = 0; } - // Finally send the message + // Retry sending with exponential backoff (wait for destination socket to be + // ready) + const int max_retries = 100; + int retry = 0; + ssize_t result = -1; + + while (retry < max_retries) { + result = sendmsg(socket_, &msg, 0); + if (result > 0) { + return; // Success + } + + // Check if error is because destination doesn't exist yet + if (errno == ENOENT || errno == ECONNREFUSED) { + // Exponential backoff: 1ms, 2ms, 4ms, ..., up to 100ms + int sleep_ms = std::min(1 << retry, 100); + usleep(sleep_ms * 1000); + retry++; + continue; + } + + // Other errors should fail immediately + break; + } + + // Finally check if we succeeded or report error TORCH_CHECK( - sendmsg(socket_, &msg, 0) > 0, - "Failed to send fd: ", + result > 0, + "Failed to send fd after ", + retry, + " retries: ", c10::utils::str_error(errno)); } diff --git a/src/xccl/ze_symbol.hpp b/src/xccl/ze_symbol.hpp deleted file mode 100644 index 20af666811..0000000000 --- a/src/xccl/ze_symbol.hpp +++ /dev/null @@ -1,254 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -#define zeVirtualMemMap zeVirtualMemMap_original -#define zeVirtualMemReserve zeVirtualMemReserve_original -#define zeVirtualMemSetAccessAttribute zeVirtualMemSetAccessAttribute_original - -#include - -#undef zeVirtualMemMap -#undef 
zeVirtualMemReserve -#undef zeVirtualMemSetAccessAttribute - -typedef ze_result_t (*zeInit_t)(ze_init_flags_t flags); -typedef ze_result_t (*zeMemGetAddressRange_t)( - ze_context_handle_t hContext, - const void* ptr, - void** pBase, - size_t* pSize); -typedef ze_result_t (*zeMemGetIpcHandle_t)( - ze_context_handle_t hContext, - const void* ptr, - ze_ipc_mem_handle_t* pIpcHandle); -typedef ze_result_t (*zeMemOpenIpcHandle_t)( - ze_context_handle_t hContext, - ze_device_handle_t hDevice, - ze_ipc_mem_handle_t handle, - ze_ipc_memory_flags_t flags, - void** pptr); -typedef ze_result_t ( - *zeMemCloseIpcHandle_t)(ze_context_handle_t hContext, const void* ptr); -typedef ze_result_t (*zeVirtualMemMap_t)( - ze_context_handle_t hContext, - const void* ptr, - size_t size, - ze_physical_mem_handle_t hPhysicalMemory, - size_t offset, - ze_memory_access_attribute_t access); -typedef ze_result_t (*zeVirtualMemReserve_t)( - ze_context_handle_t hContext, - const void* pStart, - size_t size, - void** pptr); -typedef ze_result_t (*zeVirtualMemSetAccessAttribute_t)( - ze_context_handle_t hContext, - const void* ptr, - size_t size, - ze_memory_access_attribute_t access); - -bool load_level_zero_library(); -void unload_level_zero_library(); - -#define zeCheck_dynamic(x) \ - do { \ - if (!load_level_zero_library()) { \ - throw std::runtime_error("Level Zero library not available"); \ - } \ - ze_result_t result = (x); \ - if (result != ZE_RESULT_SUCCESS) { \ - auto e = zeException(result); \ - std::cout << "Throw " << e.what() << std::endl; \ - throw e; \ - } \ - } while (0) - -#define zeInit_dynamic(flags) zeInit_ptr(flags) -#define zeMemGetAddressRange_dynamic(ctx, ptr, base, size) \ - zeMemGetAddressRange_ptr(ctx, ptr, base, size) -#define zeMemGetIpcHandle_dynamic(ctx, ptr, handle) \ - zeMemGetIpcHandle_ptr(ctx, ptr, handle) -#define zeMemOpenIpcHandle_dynamic(ctx, dev, handle, flags, ptr) \ - zeMemOpenIpcHandle_ptr(ctx, dev, handle, flags, ptr) -#define 
zeMemCloseIpcHandle_dynamic(ctx, ptr) zeMemCloseIpcHandle_ptr(ctx, ptr) -#define zeVirtualMemMap_dynamic(ctx, ptr, size, phys_mem, offset, access) \ - zeVirtualMemMap_ptr(ctx, ptr, size, phys_mem, offset, access) -#define zeVirtualMemReserve_dynamic(ctx, start, size, ptr) \ - zeVirtualMemReserve_ptr(ctx, start, size, ptr) -#define zeVirtualMemSetAccessAttribute_dynamic(ctx, ptr, size, access) \ - zeVirtualMemSetAccessAttribute_ptr(ctx, ptr, size, access) - -// Exception handling class -class zeException : std::exception { - const char* zeResultToString(ze_result_t status) const { - static const std::unordered_map zeResultToStringMap{ - {ZE_RESULT_SUCCESS, "[Core] success"}, - {ZE_RESULT_NOT_READY, "[Core] synchronization primitive not signaled"}, - {ZE_RESULT_ERROR_UNINITIALIZED, - "[Validation] driver is not initialized"}, - {ZE_RESULT_ERROR_INVALID_NULL_POINTER, - "[Validation] pointer argument may not be nullptr"}, - {ZE_RESULT_ERROR_INVALID_NULL_HANDLE, - "[Validation] handle argument is not valid"}, - {ZE_RESULT_ERROR_INVALID_ENUMERATION, - "[Validation] enumerator argument is not valid"}, - {ZE_RESULT_ERROR_INVALID_SIZE, "[Validation] size argument is invalid"}, - {ZE_RESULT_ERROR_UNSUPPORTED_SIZE, - "[Validation] size argument is not supported by the device"}, - {ZE_RESULT_ERROR_UNSUPPORTED_ALIGNMENT, - "[Validation] alignment argument is not supported by the device"}, - {ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, - "[Validation] generic error code for unsupported features"}, - {ZE_RESULT_ERROR_INVALID_NATIVE_BINARY, - "[Validation] native binary is not supported by the device"}, - {ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY, - "[Core] insufficient host memory to satisfy call"}, - {ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY, - "[Core] insufficient device memory to satisfy call"}, - {ZE_RESULT_ERROR_DEVICE_LOST, - "[Core] device hung, reset, was removed, or driver update occurred"}, - {ZE_RESULT_ERROR_MODULE_BUILD_FAILURE, - "[Core] error occurred when building module, see build 
log for details"}, - {ZE_RESULT_ERROR_HANDLE_OBJECT_IN_USE, - "[Validation] object pointed to by handle still in-use by device"}, - }; - auto it = zeResultToStringMap.find(status); - if (it != zeResultToStringMap.end()) - return it->second; - else - return "Unknown Reason"; - } - - public: - zeException(ze_result_t ret) : result_(ret) {} - - ze_result_t result_; - - const char* what() const noexcept override { - return zeResultToString(result_); - } -}; - -#define zeCheck(x) \ - if (x != ZE_RESULT_SUCCESS) { \ - auto e = zeException(x); \ - std::cout << "Throw " << e.what() << std::endl; \ - throw e; \ - } - -static zeInit_t zeInit_ptr = nullptr; -static zeMemGetAddressRange_t zeMemGetAddressRange_ptr = nullptr; -static zeMemGetIpcHandle_t zeMemGetIpcHandle_ptr = nullptr; -static zeMemOpenIpcHandle_t zeMemOpenIpcHandle_ptr = nullptr; -static zeMemCloseIpcHandle_t zeMemCloseIpcHandle_ptr = nullptr; -static zeVirtualMemMap_t zeVirtualMemMap_ptr = nullptr; -static zeVirtualMemReserve_t zeVirtualMemReserve_ptr = nullptr; -static zeVirtualMemSetAccessAttribute_t zeVirtualMemSetAccessAttribute_ptr = - nullptr; - -static void* ze_handle = nullptr; - -inline bool load_level_zero_library() { - if (ze_handle != nullptr) { - return true; - } - const char* lib_names[] = {"libze_loader.so"}; - - for (const char* lib_name : lib_names) { - ze_handle = dlopen(lib_name, RTLD_LAZY); - if (ze_handle != nullptr) { - break; - } - } - - if (ze_handle == nullptr) { - std::cerr << "Failed to load Level Zero library: " << dlerror() - << std::endl; - return false; - } - - zeInit_ptr = (zeInit_t)dlsym(ze_handle, "zeInit"); - zeMemGetAddressRange_ptr = - (zeMemGetAddressRange_t)dlsym(ze_handle, "zeMemGetAddressRange"); - zeMemGetIpcHandle_ptr = - (zeMemGetIpcHandle_t)dlsym(ze_handle, "zeMemGetIpcHandle"); - zeMemOpenIpcHandle_ptr = - (zeMemOpenIpcHandle_t)dlsym(ze_handle, "zeMemOpenIpcHandle"); - zeMemCloseIpcHandle_ptr = - (zeMemCloseIpcHandle_t)dlsym(ze_handle, "zeMemCloseIpcHandle"); - 
zeVirtualMemMap_ptr = (zeVirtualMemMap_t)dlsym(ze_handle, "zeVirtualMemMap"); - zeVirtualMemReserve_ptr = - (zeVirtualMemReserve_t)dlsym(ze_handle, "zeVirtualMemReserve"); - zeVirtualMemSetAccessAttribute_ptr = (zeVirtualMemSetAccessAttribute_t)dlsym( - ze_handle, "zeVirtualMemSetAccessAttribute"); - - if (!zeInit_ptr || !zeMemGetAddressRange_ptr || !zeMemGetIpcHandle_ptr || - !zeMemOpenIpcHandle_ptr || !zeMemCloseIpcHandle_ptr || - !zeVirtualMemMap_ptr || !zeVirtualMemReserve_ptr || - !zeVirtualMemSetAccessAttribute_ptr) { - std::cerr << "Failed to load Level Zero API functions" << std::endl; - dlclose(ze_handle); - ze_handle = nullptr; - return false; - } - - return true; -} - -inline void unload_level_zero_library() { - if (ze_handle != nullptr) { - dlclose(ze_handle); - ze_handle = nullptr; - zeInit_ptr = nullptr; - zeMemGetAddressRange_ptr = nullptr; - zeMemGetIpcHandle_ptr = nullptr; - zeMemOpenIpcHandle_ptr = nullptr; - zeMemCloseIpcHandle_ptr = nullptr; - zeVirtualMemMap_ptr = nullptr; - zeVirtualMemReserve_ptr = nullptr; - zeVirtualMemSetAccessAttribute_ptr = nullptr; - } -} - -extern "C" { - -__attribute__((weak)) ze_result_t zeVirtualMemMap( - ze_context_handle_t hContext, - const void* ptr, - size_t size, - ze_physical_mem_handle_t hPhysicalMemory, - size_t offset, - ze_memory_access_attribute_t access) { - if (!load_level_zero_library() || !zeVirtualMemMap_ptr) { - return ZE_RESULT_ERROR_UNINITIALIZED; - } - return zeVirtualMemMap_ptr( - hContext, ptr, size, hPhysicalMemory, offset, access); -} - -__attribute__((weak)) ze_result_t zeVirtualMemReserve( - ze_context_handle_t hContext, - const void* pStart, - size_t size, - void** pptr) { - if (!load_level_zero_library() || !zeVirtualMemReserve_ptr) { - return ZE_RESULT_ERROR_UNINITIALIZED; - } - return zeVirtualMemReserve_ptr(hContext, pStart, size, pptr); -} - -__attribute__((weak)) ze_result_t zeVirtualMemSetAccessAttribute( - ze_context_handle_t hContext, - const void* ptr, - size_t size, - 
ze_memory_access_attribute_t access) { - if (!load_level_zero_library() || !zeVirtualMemSetAccessAttribute_ptr) { - return ZE_RESULT_ERROR_UNINITIALIZED; - } - return zeVirtualMemSetAccessAttribute_ptr(hContext, ptr, size, access); -} -} From a18c6d50eecfa8151234707b78295411729f9467 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Thu, 11 Dec 2025 15:49:36 +0800 Subject: [PATCH 05/25] add miss file --- CMakeLists.txt | 10 ++++++ cmake/ISHMEM.cmake | 24 +++++++++++++ cmake/Modules/FindISHMEM.cmake | 65 ++++++++++++++++++++++++++++++++++ src/BuildOnLinux.cmake | 6 ++++ 4 files changed, 105 insertions(+) create mode 100644 cmake/ISHMEM.cmake create mode 100644 cmake/Modules/FindISHMEM.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index 8f6fa8d6d5..ecd23ef8cf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,6 +64,16 @@ if(USE_XCCL) caffe2_update_option(USE_C10D_XCCL OFF) update_caffe2_macros_file() endif() + if(USE_ISHMEM) + include(${TORCH_XPU_OPS_ROOT}/cmake/ISHMEM.cmake) + if(NOT PYTORCH_FOUND_ISHMEM) + message(WARNING "ISHMEM not found, disabling ISHMEM support") + caffe2_update_option(USE_ISHMEM OFF) + update_caffe2_macros_file() + else() + message(STATUS "ISHMEM support enabled") + endif() + endif() endif() set(USE_SYCLTLA ON) diff --git a/cmake/ISHMEM.cmake b/cmake/ISHMEM.cmake new file mode 100644 index 0000000000..434935ee2c --- /dev/null +++ b/cmake/ISHMEM.cmake @@ -0,0 +1,24 @@ +if(NOT __ISHMEM_INCLUDED) + set(__ISHMEM_INCLUDED TRUE) + + # ISHMEM_ROOT, ISHMEM_LIBRARY_DIR, ISHMEM_INCLUDE_DIR are handled by FindISHMEM.cmake. 
+ find_package(ISHMEM REQUIRED) + if(NOT ISHMEM_FOUND) + set(PYTORCH_FOUND_ISHMEM FALSE) + message(WARNING "${ISHMEM_NOT_FOUND_MESSAGE}") + return() + endif() + + set(PYTORCH_FOUND_ISHMEM TRUE) + add_library(torch::ishmem INTERFACE IMPORTED) + set_property( + TARGET torch::ishmem PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${ISHMEM_INCLUDE_DIR}) + set_property( + TARGET torch::ishmem PROPERTY INTERFACE_LINK_LIBRARIES + ${ISHMEM_LIBRARY}) + + message(STATUS "Found Intel SHMEM: ${ISHMEM_ROOT}") + message(STATUS " ISHMEM include dir: ${ISHMEM_INCLUDE_DIR}") + message(STATUS " ISHMEM library: ${ISHMEM_LIBRARY}") +endif() diff --git a/cmake/Modules/FindISHMEM.cmake b/cmake/Modules/FindISHMEM.cmake new file mode 100644 index 0000000000..96c7656b85 --- /dev/null +++ b/cmake/Modules/FindISHMEM.cmake @@ -0,0 +1,65 @@ +# This will define the following variables: +# ISHMEM_FOUND : True if the system has the ISHMEM library. +# ISHMEM_INCLUDE_DIR : Include directories needed to use ISHMEM. +# ISHMEM_LIBRARY_DIR : The path to the ISHMEM library. +# ISHMEM_LIBRARY : ISHMEM library fullname. + +include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake) + +if(NOT CMAKE_SYSTEM_NAME MATCHES "Linux") + set(ISHMEM_FOUND False) + set(ISHMEM_NOT_FOUND_MESSAGE "Intel SHMEM library is only supported on Linux!") + return() +endif() + +set(ISHMEM_ROOT $ENV{ISHMEM_ROOT}) + +if(NOT ISHMEM_ROOT) + set(ISHMEM_FOUND False) + set(ISHMEM_NOT_FOUND_MESSAGE "ISHMEM_ROOT environment variable not set. Please set it to your ISHMEM installation directory.") + return() +endif() + +# Find include path from binary. +find_path( + ISHMEM_INCLUDE_DIR + NAMES ishmem.h + HINTS ${ISHMEM_ROOT}/include + NO_DEFAULT_PATH +) + +# Find library directory from binary. +find_path( + ISHMEM_LIBRARY_DIR + NAMES libishmem.a + HINTS ${ISHMEM_ROOT}/lib + NO_DEFAULT_PATH +) + +# Find ISHMEM library fullname (static library). 
+find_library( + ISHMEM_LIBRARY + NAMES ishmem + HINTS ${ISHMEM_LIBRARY_DIR} + NO_DEFAULT_PATH +) + +if((NOT ISHMEM_INCLUDE_DIR) OR (NOT ISHMEM_LIBRARY_DIR) OR (NOT ISHMEM_LIBRARY)) + set(ISHMEM_FOUND False) + set(ISHMEM_NOT_FOUND_MESSAGE "Intel SHMEM library not found! Please set ISHMEM_ROOT environment variable.") + return() +endif() + +SET(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH} + "${ISHMEM_INCLUDE_DIR}") +SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} + "${ISHMEM_LIBRARY_DIR}") + +find_package_handle_standard_args( + ISHMEM + FOUND_VAR ISHMEM_FOUND + REQUIRED_VARS ISHMEM_INCLUDE_DIR ISHMEM_LIBRARY_DIR ISHMEM_LIBRARY + REASON_FAILURE_MESSAGE "${ISHMEM_NOT_FOUND_MESSAGE}" +) + +mark_as_advanced(ISHMEM_INCLUDE_DIR ISHMEM_LIBRARY_DIR ISHMEM_LIBRARY) diff --git a/src/BuildOnLinux.cmake b/src/BuildOnLinux.cmake index af0e90c8e6..ae0ab02248 100644 --- a/src/BuildOnLinux.cmake +++ b/src/BuildOnLinux.cmake @@ -24,6 +24,9 @@ macro(setup_common_libraries) target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL) target_link_libraries(torch_xpu_ops PUBLIC torch::xccl) target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only) + if(USE_ISHMEM AND PYTORCH_FOUND_ISHMEM) + target_link_libraries(torch_xpu_ops PUBLIC torch::ishmem) + endif() endif() if(USE_SYCLTLA) @@ -57,6 +60,9 @@ else() target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL) target_link_libraries(torch_xpu_ops PUBLIC torch::xccl) target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only) + if(USE_ISHMEM AND PYTORCH_FOUND_ISHMEM) + target_link_libraries(torch_xpu_ops PUBLIC torch::ishmem) + endif() endif() if(USE_SYCLTLA) From c86f32f65a6e4a8300b854d3a76f6a0f7e488983 Mon Sep 17 00:00:00 2001 From: lzhang2 Date: Tue, 20 Jan 2026 16:25:31 +0800 Subject: [PATCH 06/25] revert barrier implementations --- src/xccl/XPUSymmetricMemory.cpp | 100 ++++++++++++++++---------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/src/xccl/XPUSymmetricMemory.cpp 
b/src/xccl/XPUSymmetricMemory.cpp index 7bbf2407b4..21c27cf309 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -164,56 +164,56 @@ void XPUSymmetricMemory::barrier(int channel, size_t timeout_ms) { c10::Device local_device(c10::DeviceType::XPU, local_device_idx_); c10::DeviceGuard guard(local_device); - auto stream = at::xpu::getCurrentXPUStream(); - - barrier_impl_xpu( - reinterpret_cast(signal_pads_dev_), - channel, - rank_, - world_size_, - timeout_ms, - stream); - // // Currently, we leverage oneCCL for barrier. Later, we may move to SYCL - // // implementation. - // auto group = c10d::resolve_process_group(group_name_); - // if (group == nullptr) { - // TORCH_WARN( - // "Process group '", - // group_name_, - // "' not found, please init process group first before calling - // SymmetricMemory"); - // throw std::runtime_error("Process group not found"); - // } - // auto* xcclPg = dynamic_cast( - // group->getBackend(c10::DeviceType::XPU).get()); - - // c10::Device local_device(c10::DeviceType::XPU, local_device_idx_); - // c10::DeviceGuard guard(local_device); - - // static thread_local at::Tensor barrier_tensor; - // if (!barrier_tensor.defined() || barrier_tensor.device() != local_device) { - // barrier_tensor = at::zeros( - // {1}, at::TensorOptions().device(local_device).dtype(at::kFloat)); - // } else { - // barrier_tensor.zero_(); - // } - - // c10d::AllreduceOptions arOpts; - // arOpts.asyncOp = false; - // auto work = - // xcclPg->allreduce_impl(barrier_tensor, "xccl:symm_mem_barrier", - // arOpts); - - // if (work) { - // bool success = work->wait(std::chrono::milliseconds(timeout_ms)); - // TORCH_CHECK( - // success, - // "Barrier timeout after ", - // timeout_ms, - // " ms for group '", - // group_name_, - // "'"); - // } +// auto stream = at::xpu::getCurrentXPUStream(); + +// barrier_impl_xpu( +// reinterpret_cast(signal_pads_dev_), +// channel, +// rank_, +// world_size_, +// timeout_ms, +// stream); + // 
Currently, we leverage oneCCL for barrier. Later, we may move to SYCL + // implementation. + auto group = c10d::resolve_process_group(group_name_); + if (group == nullptr) { + TORCH_WARN( + "Process group '", + group_name_, + "' not found, please init process group first before calling + SymmetricMemory"); + throw std::runtime_error("Process group not found"); + } + auto* xcclPg = dynamic_cast( + group->getBackend(c10::DeviceType::XPU).get()); + + c10::Device local_device(c10::DeviceType::XPU, local_device_idx_); + c10::DeviceGuard guard(local_device); + + static thread_local at::Tensor barrier_tensor; + if (!barrier_tensor.defined() || barrier_tensor.device() != local_device) { + barrier_tensor = at::zeros( + {1}, at::TensorOptions().device(local_device).dtype(at::kFloat)); + } else { + barrier_tensor.zero_(); + } + + c10d::AllreduceOptions arOpts; + arOpts.asyncOp = false; + auto work = + xcclPg->allreduce_impl(barrier_tensor, "xccl:symm_mem_barrier", + arOpts); + + if (work) { + bool success = work->wait(std::chrono::milliseconds(timeout_ms)); + TORCH_CHECK( + success, + "Barrier timeout after ", + timeout_ms, + " ms for group '", + group_name_, + "'"); + } } void XPUSymmetricMemory::put_signal( From 6bf4775e1fbe57c85321da213955ac65b5514f8c Mon Sep 17 00:00:00 2001 From: hanchao Date: Fri, 15 May 2026 14:43:15 +0800 Subject: [PATCH 07/25] rm ishmem related file --- src/xccl/ISHMEMSymmetricMemory.cpp | 459 ----------------------------- src/xccl/IpcExchange.hpp | 400 ------------------------- 2 files changed, 859 deletions(-) delete mode 100644 src/xccl/ISHMEMSymmetricMemory.cpp delete mode 100644 src/xccl/IpcExchange.hpp diff --git a/src/xccl/ISHMEMSymmetricMemory.cpp b/src/xccl/ISHMEMSymmetricMemory.cpp deleted file mode 100644 index 2a20e0bf28..0000000000 --- a/src/xccl/ISHMEMSymmetricMemory.cpp +++ /dev/null @@ -1,459 +0,0 @@ -#include -#include "XPUSymmetricMemoryUtils.hpp" - -#include -#include -#include -#include - -#include -#include - -namespace c10d { 
-namespace symmetric_memory { - -// Intel Data Center GPU Max can support up to 8 GPUs in a single node -constexpr int max_xpu_p2p_domain_size = 8; -// Maximum number of channels (same as CUDA) -constexpr int xpu_symm_max_nblocks = 32; -// Signal pad size for XPU -constexpr size_t xpu_signal_pad_size = - xpu_symm_max_nblocks * max_xpu_p2p_domain_size * sizeof(uint32_t); - -static StoreExchange storeExchange = StoreExchange("ISHMEMSymmetricMemory"); - -struct ISHMEMAllocation { - void* ptr; - size_t buffer_size; - int device_idx; - - ISHMEMAllocation(void* ptr, size_t buffer_size, int device_idx) - : ptr(ptr), buffer_size(buffer_size), device_idx(device_idx) {} - - ~ISHMEMAllocation() { - if (is_finalizing()) { - return; - } - c10::OptionalDeviceGuard guard; - guard.reset_device(at::Device(at::DeviceType::XPU, device_idx)); - ishmem_free(ptr); - } -}; - -// A class to hold the base pointers and signal pad pointers for a group of -// peers. One `ISHMEMPeerAllocInfo` object can be shared by multiple -// `ISHMEMSymmetricMemory` objects when latter reside on the same allocation -// and rendezvous over the same group. (The `ISHMEMSymmetricMemory` objects may -// have different offsets compared to the base address.) -class ISHMEMPeerAllocInfo : public c10::intrusive_ptr_target { - public: - ISHMEMPeerAllocInfo( - ISHMEMAllocation* allocation, - const std::string& group_name) - : base_ptr_(allocation->ptr), buffer_size_(allocation->buffer_size) { - // For logging only - static int exchanged_n_times = 0; - - c10::OptionalDeviceGuard guard; - guard.reset_device(at::Device(at::DeviceType::XPU, allocation->device_idx)); - - auto global_rank = get_group_info("0").rank; - GroupInfo& group_info = get_group_info(group_name); - auto store = group_info.store; - rank_ = group_info.rank; - world_size_ = group_info.world_size; - // Exchange rank to global rank mapping for this group. - // If it is already available, skip the exchange. 
- if (group_info.rank_to_global_rank.empty()) { - group_info.rank_to_global_rank = - storeExchange.all_gather(store, rank_, world_size_, global_rank); - exchanged_n_times++; - if (rank_ == 0) { - LOG(INFO) << "[rank " << rank_ << ']' - << " rank_to_global_rank: " << group_info.rank_to_global_rank - << ", group_name: " << group_name - << ", exchanged_n_times: " << exchanged_n_times; - } - } - TORCH_INTERNAL_ASSERT(!group_info.rank_to_global_rank.empty()); - rank_to_global_rank_ = group_info.rank_to_global_rank; - - world_within_xpu_p2p_ = true; - for (int r = 0; r < world_size_; ++r) { - auto peer_ptr = ishmem_ptr(base_ptr_, rank_to_global_rank_[r]); - buffers_.push_back(peer_ptr); - // If a peer is over network, `ishmem_ptr` returns null - if (peer_ptr == nullptr) { - world_within_xpu_p2p_ = false; - } - } - - // TODO: use the same allocation for signal pad - void* signal_pad_ptr = ishmem_malloc(xpu_signal_pad_size); - TORCH_CHECK(signal_pad_ptr != nullptr, "ishmem_malloc failed"); - - auto& queue = at::xpu::getCurrentSYCLQueue(); - queue.memset(signal_pad_ptr, 0, xpu_signal_pad_size).wait(); - - for (int r = 0; r < world_size_; ++r) { - signal_pads_.push_back( - ishmem_ptr(signal_pad_ptr, rank_to_global_rank_[r])); - } - - const size_t arr_size = sizeof(void*) * world_size_; - buffers_dev_ = reinterpret_cast( - c10::xpu::XPUCachingAllocator::raw_alloc(arr_size)); - signal_pads_dev_ = reinterpret_cast( - c10::xpu::XPUCachingAllocator::raw_alloc(arr_size)); - - queue.memcpy(buffers_dev_, buffers_.data(), arr_size).wait(); - queue.memcpy(signal_pads_dev_, signal_pads_.data(), arr_size).wait(); - - rank_to_global_rank_dev_ = reinterpret_cast( - c10::xpu::XPUCachingAllocator::raw_alloc(sizeof(int) * world_size_)); - queue - .memcpy( - rank_to_global_rank_dev_, - rank_to_global_rank_.data(), - sizeof(int) * world_size_) - .wait(); - } - - private: - void* base_ptr_; - size_t buffer_size_; - int rank_; - int world_size_; - std::vector buffers_; - std::vector 
signal_pads_; - void** buffers_dev_; - void** signal_pads_dev_; - std::vector rank_to_global_rank_; - int* rank_to_global_rank_dev_; - // Whether the world is within XPU P2P only, not network - bool world_within_xpu_p2p_; - - friend class ISHMEMSymmetricMemory; -}; - -class ISHMEMSymmetricMemory : public SymmetricMemory { - public: - ISHMEMSymmetricMemory( - ISHMEMAllocation* allocation, - const std::string& group_name) - : device_idx_(allocation->device_idx), group_name_(group_name) { - // A handle stores two types of info: - // (i) allocation's base ptrs and base signal pads, ours and peers' - pai_ = c10::make_intrusive(allocation, group_name); - // (ii) offset of tensor compared to base ptr (in byte) - offset_ = 0; - } - - // Exact copy is not needed / supported - ISHMEMSymmetricMemory(const ISHMEMSymmetricMemory& other) = delete; - - // Copy with offset is allowed - // This is mostly a shallow copy that shares the pointer to - // `ISHMEMPeerAllocInfo` which has been created by `other` - ISHMEMSymmetricMemory(const ISHMEMSymmetricMemory& other, size_t offset) - : device_idx_(other.device_idx_), - group_name_(other.group_name_), - pai_(other.pai_) { - offset_ = offset; - } - - ~ISHMEMSymmetricMemory() override{ - // TODO - }; - - std::vector get_buffer_ptrs() override { - return pai_->buffers_; - } - - std::vector get_signal_pad_ptrs() override { - return pai_->signal_pads_; - } - - void** get_buffer_ptrs_dev() override { - return pai_->buffers_dev_; - } - - void** get_signal_pad_ptrs_dev() override { - return pai_->signal_pads_dev_; - } - - size_t get_buffer_size() override { - return pai_->buffer_size_; - } - - size_t get_signal_pad_size() override { - return xpu_signal_pad_size; - }; - - bool has_multicast_support() override { - return false; - } - - void* get_multicast_ptr() override { - return nullptr; - } - - size_t get_offset() override { - return offset_; - } - - void barrier(int channel, size_t timeout_ms) override { - ishmem_barrier_all(); - } - - void 
put_signal(int dst_rank, int channel, size_t timeout_ms) override { - // TODO: Implement signal mechanism for ISHMEM - } - - void wait_signal(int src_rank, int channel, size_t timeout_ms) override { - // TODO: Implement signal mechanism for ISHMEM - } - - int get_rank() override { - return pai_->rank_; - } - - int get_world_size() override { - return pai_->world_size_; - } - - c10::Device get_device() override { - return c10::Device(c10::DeviceType::XPU, device_idx_); - } - - const std::vector& get_rank_to_global_rank() override { - return pai_->rank_to_global_rank_; - }; - - int* get_rank_to_global_rank_dev() override { - return pai_->rank_to_global_rank_dev_; - }; - - bool world_within_direct_access() override { - return pai_->world_within_xpu_p2p_; - } - - private: - int device_idx_; - std::string group_name_; - c10::intrusive_ptr pai_; - size_t offset_{0}; // in byte -}; - -static void initialize_ishmem_with_store( - c10::intrusive_ptr store, - int rank, - int world_size, - int device_idx) { - static bool is_initialized = false; - if (is_initialized) { - return; - } - c10::OptionalDeviceGuard guard; - guard.reset_device(at::Device(at::DeviceType::XPU, device_idx)); - - // Generate unique ID - ONLY rank 0 should generate it - ishmemx_uniqueid_t unique_id; - memset( - &unique_id, 0, sizeof(unique_id)); // Zero-initialize for non-root ranks - - if (rank == 0) { - LOG(INFO) << "[ISHMEM Init] Rank 0 generating unique ID"; - // Root rank generates the unique ID - int ret = ishmemx_get_uniqueid(&unique_id); - TORCH_CHECK(ret == 0, "ishmemx_get_uniqueid failed with error: ", ret); - LOG(INFO) << "[ISHMEM Init] Rank 0 unique ID generated"; - } - - // All-gather to distribute rank 0's unique_id to all ranks - // This creates a vector where all elements should contain rank 0's unique_id - std::vector unique_ids; - LOG(INFO) << "[ISHMEM Init] Rank " << rank - << " about to all_gather unique_id"; - try { - unique_ids = storeExchange.all_gather(store, rank, world_size, 
unique_id); - LOG(INFO) << "[ISHMEM Init] Rank " << rank - << " all_gather completed, received " << unique_ids.size() - << " unique_ids"; - } catch (const std::exception& e) { - LOG(ERROR) << "[ISHMEM Init] Rank " << rank - << " all_gather failed: " << e.what(); - throw; - } - - // Initialize ISHMEM with attributes using unique ID from rank 0 - ishmemx_attr_t attr; - attr.initialize_runtime = false; - attr.use_uid = true; - attr.nranks = world_size; - attr.uid = &unique_ids[0]; // Use rank 0's unique_id (first element) - - // ishmemx_init_attr returns void, not int - ishmemx_init_attr(&attr); - // Verify initialization succeeded by checking PE info - TORCH_CHECK( - ishmem_my_pe() == rank, - "ISHMEM initialization failed: rank mismatch, expected ", - rank, - " got ", - ishmem_my_pe()); - - is_initialized = true; - - // Print version - int major, minor; - ishmem_info_get_version(&major, &minor); - LOG(INFO) << "ISHMEM initialized with unique ID, version: " << major << '.' - << minor << ", rank: " << rank << "/" << world_size; -} - -class ISHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { - public: - void* alloc( - size_t size, - int device_idx, - const std::optional& group_name) override { - // Note: group_name may be passed but is ignored for ISHMEM allocations - // ISHMEM uses group "0" for all allocations - c10::OptionalDeviceGuard guard; - guard.reset_device(at::Device(at::DeviceType::XPU, device_idx)); - - auto group_info = get_group_info("0"); - auto store = group_info.store; - int rank = group_info.rank; - int world_size = group_info.world_size; - - initialize_ishmem_with_store(store, rank, world_size, device_idx); - auto ptr = ishmem_malloc(size); - // If size is 0 (which is legal allocation request) we shouldn't error out - TORCH_CHECK(ptr != nullptr || size == 0, "ishmem_malloc failed"); - // TODO: thread safety - allocations_.try_emplace( - ptr, std::make_unique(ptr, size, device_idx)); - return ptr; - } - - void free(void* ptr) override { 
- // TODO: thread safety - allocations_.erase(ptr); - }; - - size_t get_alloc_size(void* ptr) override { - auto it = allocations_.find(ptr); - if (it == allocations_.end()) { - TORCH_CHECK( - false, ptr, " is not allocated with ISHMEMSymmetricMemoryAllocator"); - } - return it->second->buffer_size; - }; - - c10::intrusive_ptr rendezvous( - void* ptr, - const std::optional& group_name) override { - // Use WORLD group (name "0") if group_name is not provided - std::string actual_group_name; - if (group_name.has_value()) { - actual_group_name = *group_name; - } else { - // Default to group "0" (WORLD) for ISHMEM - actual_group_name = "0"; - } - - { - auto it = symm_mems_.find(std::make_tuple(ptr, actual_group_name)); - if (it != symm_mems_.end()) { - return it->second; - } - } - // In case of MemPool, tensor.storage().data_ptr() may not match - // exactly an allocation's base address. Thus we perform the search by - // testing if the former is within an allocation's range. - auto alloc_it = std::find_if( - allocations_.begin(), allocations_.end(), [&](const auto& pair) { - auto& allocation = pair.second; - auto ptr_int = reinterpret_cast(ptr); - auto base_ptr = reinterpret_cast(allocation->ptr); - return ptr_int >= base_ptr && - ptr_int < base_ptr + allocation->buffer_size; - }); - TORCH_CHECK( - alloc_it != allocations_.end(), - "Pointer not within any SymmetricMemory allocation, " - "is the tensor allocated from SymmetricMemory?"); - - auto& allocation = alloc_it->second; - - // Search again using allocation base ptr (which is the key we use for - // caching, see below) - auto it = - symm_mems_.find(std::make_tuple(allocation->ptr, actual_group_name)); - c10::intrusive_ptr symm_mem; - if (it != symm_mems_.end()) { - // Base allocation has been rendezvoused - symm_mem = it->second; - } else { - // Create a new rendezvous - symm_mem = c10::make_intrusive( - allocation.get(), actual_group_name); - } - - // Cache rendezvous using allocation's base address as key - 
symm_mems_[std::make_tuple(allocation->ptr, actual_group_name)] = symm_mem; - - // TODO: change the `ptr` below to `tensor.data_ptr()` when adding support - // for user slice/view operations. For MemPool support, - // `tensor.storate().data_ptr()` is fine (today's `ptr`). - - // If the tensor's ptr happen to be the same as allocation ptr - if (ptr == allocation->ptr) { - return symm_mem; - } else { - // Return a copy of the SymmetricMemory with an offset. This is a - // "shallow" copy adjusting the offset field in the handle. - return c10::make_intrusive( - *symm_mem, (uintptr_t)ptr - (uintptr_t)allocation->ptr); - } - }; - - bool has_multicast_support(int device_idx) override { - // ISHMEM does not have multicast support - return false; - }; - - c10::DeviceType supported_device_type() override { - return c10::DeviceType::XPU; - } - - std::string name() override { - return "ISHMEM"; - } - - private: - std::unordered_map> allocations_; - std::map< - std::tuple, - c10::intrusive_ptr> - symm_mems_; -}; - -struct RegisterISHMEMSymmetricMemoryAllocator { - RegisterISHMEMSymmetricMemoryAllocator() { - auto allocator = c10::make_intrusive(); - // Always register availability to support dynamic backend switching - register_availability("ISHMEM", allocator); - // If this is the preferred backend, also set it as default - if (getSymmMemBackendXPU() == "ISHMEM") { - register_allocator(c10::DeviceType::XPU, allocator); - } - } -}; - -static RegisterISHMEMSymmetricMemoryAllocator register_allocator_; - -} // namespace symmetric_memory -} // namespace c10d diff --git a/src/xccl/IpcExchange.hpp b/src/xccl/IpcExchange.hpp deleted file mode 100644 index 600aea92c5..0000000000 --- a/src/xccl/IpcExchange.hpp +++ /dev/null @@ -1,400 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "xccl/ze_symbol.hpp" - -#include - -#include -#include -#include -#include -#include - -struct exchange_contents { - // first 4-byte is 
file descriptor for drmbuf or gem object - union { - ze_ipc_mem_handle_t ipc_handle; - int fd = -1; - }; - size_t offset = 0; - int pid = -1; -}; - -#define sysCheck(x) \ - if (x == -1) { \ - throw std::system_error(std::make_error_code(std::errc(errno))); \ - } - -// We can't inherit it from cmsghdr because flexible array member -struct exchange_fd { - char obscure[CMSG_LEN(sizeof(int)) - sizeof(int)]; - int fd; - - exchange_fd(int cmsg_level, int cmsg_type, int fd) : fd(fd) { - auto* cmsg = reinterpret_cast(obscure); - cmsg->cmsg_len = sizeof(exchange_fd); - cmsg->cmsg_level = cmsg_level; - cmsg->cmsg_type = cmsg_type; - } - - exchange_fd() : fd(-1) { - memset(obscure, 0, sizeof(obscure)); - }; -}; - -void un_send_fd(int sock, int fd, int rank, size_t offset) { - iovec iov[1]; - msghdr msg; - auto rank_offset = std::make_pair(rank, offset); - - iov[0].iov_base = &rank_offset; - iov[0].iov_len = sizeof(rank_offset); - msg.msg_iov = iov; - msg.msg_iovlen = 1; - msg.msg_name = nullptr; - msg.msg_namelen = 0; - - exchange_fd cmsg(SOL_SOCKET, SCM_RIGHTS, fd); - - msg.msg_control = &cmsg; - msg.msg_controllen = sizeof(exchange_fd); - sysCheck(sendmsg(sock, &msg, 0)); -} - -std::tuple un_recv_fd(int sock) { - iovec iov[1]; - msghdr msg; - std::pair rank_offset; - - iov[0].iov_base = &rank_offset; - iov[0].iov_len = sizeof(rank_offset); - msg.msg_iov = iov; - msg.msg_iovlen = 1; - msg.msg_name = nullptr; - msg.msg_namelen = 0; - - exchange_fd cmsg; - msg.msg_control = &cmsg; - msg.msg_controllen = sizeof(exchange_fd); - int n_recv = recvmsg(sock, &msg, 0); - sysCheck(n_recv); - // assert(n_recv == sizeof(int)); - - return std::make_tuple(cmsg.fd, rank_offset.first, rank_offset.second); -} - -int prepare_socket(const char* sockname) { - sockaddr_un un; - memset(&un, 0, sizeof(un)); - un.sun_family = AF_UNIX; - strcpy(un.sun_path, sockname); - - auto sock = socket(AF_UNIX, SOCK_STREAM, 0); - sysCheck(sock); - - int on = 1; - sysCheck(ioctl(sock, FIONBIO, &on)); - - auto 
size = offsetof(sockaddr_un, sun_path) + strlen(un.sun_path); - sysCheck(bind(sock, (sockaddr*)&un, size)); - - return sock; -} - -int server_listen(const char* sockname) { - unlink(sockname); - auto sock = prepare_socket(sockname); - sysCheck(listen(sock, 10)); - - return sock; -} - -int serv_accept(int listen_sock) { - sockaddr_un un; - - socklen_t len = sizeof(un); - auto accept_sock = accept(listen_sock, (sockaddr*)&un, &len); - sysCheck(accept_sock); - - return accept_sock; -} - -bool wait_for_socket_file(const char* path, int max_seconds = 10) { - struct stat buffer; - for (int i = 0; i < max_seconds * 10; ++i) { - if (stat(path, &buffer) == 0) { - return true; - } - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - return false; -} - -int client_connect(const char* server, const char* client) { - if (!wait_for_socket_file(server, 10)) { - std::cerr << "Error: timeout waiting for server socket file: " << server - << std::endl; - exit(EXIT_FAILURE); - } - auto sock = prepare_socket(client); - sockaddr_un sun; - memset(&sun, 0, sizeof(sun)); - sun.sun_family = AF_UNIX; - strcpy(sun.sun_path, server); - auto len = offsetof(sockaddr_un, sun_path) + strlen(server); - const int max_retries = 50; - int retry = 0; - int ret = -1; - while (retry < max_retries) { - ret = connect(sock, (sockaddr*)&sun, len); - if (ret == 0) - break; - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - retry++; - } - if (ret != 0) { - perror("connect failed"); - exit(EXIT_FAILURE); - } - - // sysCheck(connect(sock, (sockaddr*)&sun, len)); - return sock; -} - -void un_allgather( - exchange_contents* send_buf, - exchange_contents recv_buf[], - int rank, - int world) { - const char* servername_prefix = "/tmp/open-peer-ipc-mem-server-rank_"; - const char* clientname_prefix = "/tmp/open-peer-ipc-mem-client-rank_"; - char server_name[64]; - /* get username to make server_name unique */ - auto uid = getuid(); - auto pwd = getpwuid(uid); - snprintf( - server_name, - 
sizeof(server_name), - "%s%d_%s", - servername_prefix, - rank, - pwd->pw_name); - unlink(server_name); - auto s_listen = server_listen(server_name); - - pollfd fdarray[world]; - int recv_socks[world - 1]; - - for (auto& pollfd : fdarray) - pollfd.fd = -1; - std::fill(recv_socks, recv_socks + world - 1, -1); - - auto fd_guard = [&]() { - for (int i = 0, j = 0; i < world; ++i) { - if (i != rank && recv_socks[j] != -1) - sysCheck(close(recv_socks[j++])); - if (fdarray[i].fd != -1) - sysCheck(close(fdarray[i].fd)); - } - }; - - struct guard__ { - using F = decltype(fd_guard); - F f; - guard__(const F& f) : f(f) {} - ~guard__() { - f(); - } - } free_fd(fd_guard); - - // connect to all ranks - for (int i = 0; i < world; ++i) { - if (rank == i) { - fdarray[i].fd = s_listen; - fdarray[i].events = POLLIN; - fdarray[i].revents = 0; - } else { - char peer_name[64]; - char client_name[64]; - - snprintf( - client_name, - sizeof(client_name), - "%s%d-%d_%s", - clientname_prefix, - rank, - i, - pwd->pw_name); - unlink(client_name); - - snprintf( - peer_name, - sizeof(peer_name), - "%s%d_%s", - servername_prefix, - i, - pwd->pw_name); - fdarray[i].fd = client_connect(peer_name, client_name); - fdarray[i].events = POLLOUT; - fdarray[i].revents = 0; - } - } - - // std::future> future_fds[world -1]; - int slot = 0; - uint32_t send_progress = 1 << rank; - - while (slot < world - 1 || send_progress != (1 << world) - 1) { - sysCheck(ppoll(fdarray, world, nullptr, nullptr)); - - for (int i = 0; i < world; ++i) { - if (i == rank && (fdarray[i].revents & POLLIN)) { - // auto accept_sock = serv_accept(fdarray[i].fd); - // future_fds[slot ++] = std::async( - // std::launch::async, [=]() { - // struct sock_guard{ - // int sock; - // sock_guard(int sock) : sock(sock) {} - // ~guard_sock() {sysCheck(close(sock));} - // } release(accept_sock); - // auto ret = un_recv_fd(accept_sock); - // return ret;}); - recv_socks[slot++] = serv_accept(fdarray[i].fd); - } else if ( - (send_progress & (1 << i)) 
== 0 && fdarray[i].revents & POLLOUT) { - un_send_fd(fdarray[i].fd, send_buf->fd, rank, send_buf->offset); - send_progress |= 1 << i; - } - } - } - - for (int i = 0; i < world - 1; ++i) { - // future_fds[i].wait(); - // auto [fd, peer, offset] = future_fds[i].get(); - auto [fd, peer, offset] = un_recv_fd(recv_socks[i]); - recv_buf[peer].fd = fd; - recv_buf[peer].offset = offset; - } - - recv_buf[rank] = *send_buf; -} - -class IpcChannels { - public: - IpcChannels() { - initialized = false; - } - void init(sycl::queue& queue, uint32_t rank_in, uint32_t world_in) { - if (initialized) - return; - - if (!load_level_zero_library()) { - throw std::runtime_error("Failed to initialize Level Zero"); - } - - zeCheck_dynamic(zeInit_dynamic(0)); - int tmp_rank, tmp_world; - - tmp_world = world_in; - tmp_rank = rank_in; - - rank = tmp_rank; - world = tmp_world; - initialized = true; - } - void release(sycl::queue& queue) { - if (!initialized) - return; - try { - auto l0_ctx = sycl::get_native( - queue.get_context()); - for (int i = 0; i < world; i++) { - if (i != rank) { - zeCheck_dynamic(zeMemCloseIpcHandle_dynamic( - l0_ctx, (char*)buffers[i] - offsets[i])); - } - } - } catch (const std::exception& e) { - std::cerr << "Warning: Level Zero cleanup failed: " << e.what() - << std::endl; - } - sycl::free(buffers[rank], queue); - initialized = false; - } - - // buffer_size as element size - void exchange_peer_ipc_mem( - sycl::queue& queue, - void* ptr, - uint32_t rank_in, - uint32_t world_in) { - if (!initialized) - init(queue, rank_in, world_in); - if (!load_level_zero_library()) { - throw std::runtime_error("Level Zero not available"); - } - - // Step 1: Get base address of the pointer - sycl::context ctx = queue.get_context(); - auto l0_ctx = sycl::get_native(ctx); - - void* base_addr; - size_t base_size; - zeCheck_dynamic( - zeMemGetAddressRange_dynamic(l0_ctx, ptr, &base_addr, &base_size)); - - // Step 2: Get IPC mem handle from base address - alignas(64) exchange_contents 
send_buf; - alignas(64) exchange_contents recv_buf[world]; - - // fill in the exchange info - zeCheck_dynamic( - zeMemGetIpcHandle_dynamic(l0_ctx, base_addr, &send_buf.ipc_handle)); - send_buf.offset = (char*)ptr - (char*)base_addr; - - send_buf.pid = getpid(); - - // Step 3: Exchange the handles and offsets - memset(recv_buf, 0, sizeof(recv_buf)); - // Overkill if we don't really needs all peer's handles - un_allgather(&send_buf, recv_buf, rank, world); - for (uint32_t i = 0; i < world; i++) { - // Step 4: Prepare pid file descriptor of next process - auto* peer = recv_buf + i; - // Step 6: Open IPC handle of remote peer - auto l0_device = sycl::get_native( - queue.get_device()); - void* peer_base; - - zeCheck_dynamic(zeMemOpenIpcHandle_dynamic( - l0_ctx, - l0_device, - peer->ipc_handle, - ZE_IPC_MEMORY_FLAG_BIAS_CACHED, - &peer_base)); - - buffers[i] = (char*)peer_base + peer->offset; - offsets[i] = peer->offset; - ipc_handle[i] = send_buf.ipc_handle; - } - } - - bool initialized; - static constexpr uint32_t max_rank = 16; - void* buffers[max_rank]; - void* sync_buffer[max_rank]; - size_t offsets[max_rank]; - ze_ipc_mem_handle_t ipc_handle[max_rank]; - int rank, world; - int size_per_buffer; - int data_size_per_buffer; - int buffer_index; -}; From b97ade9796a94107f6cc077f715fafe5eeed0392 Mon Sep 17 00:00:00 2001 From: hanchao Date: Fri, 15 May 2026 14:47:06 +0800 Subject: [PATCH 08/25] clean up code and use sycl ipc --- src/xccl/Signal.cpp | 24 +-- src/xccl/Signal.hpp | 60 +------- src/xccl/XPUSymmetricMemory.cpp | 213 +++++++++++++-------------- src/xccl/XPUSymmetricMemory.hpp | 16 +- src/xccl/XPUSymmetricMemoryUtils.cpp | 5 +- src/xccl/XPUSymmetricMemoryUtils.hpp | 48 +++--- 6 files changed, 155 insertions(+), 211 deletions(-) diff --git a/src/xccl/Signal.cpp b/src/xccl/Signal.cpp index 09ad623423..6e1c240938 100644 --- a/src/xccl/Signal.cpp +++ b/src/xccl/Signal.cpp @@ -16,15 +16,15 @@ struct barrierKernel { } auto put_success = try_put_signal_device( 
signal_pads[target_rank] + world_size * channel + rank, 10000000); - // if (!put_success) { - // assert(0); - // } + if (!put_success) { + SYCL_KERNEL_ASSERT(false); + } auto wait_success = try_wait_signal_device( signal_pads[rank] + world_size * channel + target_rank, 10000000); - // if (!wait_success) { - // assert(0); - // } + if (!wait_success) { + SYCL_KERNEL_ASSERT(false); + } } } @@ -79,9 +79,9 @@ struct putSignalKernel { if (thread_id == 0) { auto put_success = try_put_signal_device( signal_pads[dst_rank] + world_size * channel + rank, 10000000); - // if (!put_success) { - // assert(0); - // } + if (!put_success) { + SYCL_KERNEL_ASSERT(false); + } } } @@ -141,9 +141,9 @@ struct waitSignalKernel { if (thread_id == 0) { auto wait_success = try_wait_signal_device( signal_pads[rank] + world_size * channel + src_rank, 10000000); - // if (!wait_success) { - // assert(0); - // } + if (!wait_success) { + SYCL_KERNEL_ASSERT(false); + } sycl::atomic_fence(sycl::memory_order_seq_cst, sycl::memory_scope_system); } diff --git a/src/xccl/Signal.hpp b/src/xccl/Signal.hpp index a53f1f7c15..91b43e7f6b 100644 --- a/src/xccl/Signal.hpp +++ b/src/xccl/Signal.hpp @@ -29,15 +29,6 @@ inline uint32_t load_acquire(uint32_t* addr) { return val; } -inline size_t global_timer_ns() { - auto now = std::chrono::high_resolution_clock::now(); - return std::chrono::duration_cast( - now.time_since_epoch()) - .count(); -} - -constexpr size_t ns_per_ms = 1e6; - // ============================================================================= // Put signal: wait until addr == 0, then set to 1 (release semantics) // ============================================================================= @@ -57,31 +48,6 @@ bool try_put_signal_device(uint32_t* addr, size_t max_iterations = 1000) { return true; } -// Host version using timeout -template -bool try_put_signal(uint32_t* addr, size_t timeout_ms) { - size_t deadline = global_timer_ns() + timeout_ms * ns_per_ms; - // Wait until the slot is free 
(value == 0) - while (load_acquire(addr) != 0) { - if (timeout_ms != 0 && global_timer_ns() > deadline) { - return false; - } - } - // Set signal to 1 with release semantics - store_release(addr, 1); - return true; -} - -// Blocking version -template -void put_signal(uint32_t* addr) { - // Wait until the slot is free (value == 0) - while (load_acquire(addr) != 0) - ; - // Set signal to 1 with release semantics - store_release(addr, 1); -} - // ============================================================================= // Wait signal: wait until addr == 1, then set to 0 (acquire semantics) // ============================================================================= @@ -89,7 +55,7 @@ void put_signal(uint32_t* addr) { // Device-compatible version using iteration count template bool try_wait_signal_device(uint32_t* addr, size_t max_iterations = 1000) { - size_t iterations = 0; + (void)max_iterations; // Wait until signal is set (value == 1) while (load_acquire(addr) != 1) { // Spin wait (no timeout check to avoid early exit) @@ -100,30 +66,6 @@ bool try_wait_signal_device(uint32_t* addr, size_t max_iterations = 1000) { return true; } -// Host version using timeout -template -bool try_wait_signal(uint32_t* addr, size_t timeout_ms) { - size_t deadline = global_timer_ns() + timeout_ms * ns_per_ms; - // Wait until signal is set (value == 1) - while (load_acquire(addr) != 1) { - // Spin wait (no timeout check to avoid early exit) - continue; - } - // Clear signal to 0 with release semantics - store_release(addr, 0); - return true; -} - -// Blocking version -template -void wait_signal(uint32_t* addr) { - // Wait until signal is set (value == 1) - while (load_acquire(addr) != 1) - ; - // Clear signal to 0 with release semantics - store_release(addr, 0); -} - void barrier_impl_xpu( uint32_t** signal_pads, int channel, diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index 21c27cf309..6656b2b269 100644 --- a/src/xccl/XPUSymmetricMemory.cpp 
+++ b/src/xccl/XPUSymmetricMemory.cpp @@ -9,6 +9,9 @@ #include #include +#include + +#include #include #include @@ -17,6 +20,18 @@ namespace symmetric_memory { static StoreExchange storeExchange = StoreExchange("XPUSymmetricMemory"); +namespace { + +bool use_signal_barrier_enabled() { + static const bool cached_value = []() { + const char* env = std::getenv("USE_SIGNAL_BARRIER"); + return env != nullptr && std::string(env) == "1"; + }(); + return cached_value; +} + +} // namespace + AllocationRef::AllocationRef( void* ptr, HandleType handle, @@ -34,6 +49,9 @@ AllocationRef::~AllocationRef() { return; } // Currently, we cannot free virtual memory exchanged from other device. + // (SYCL `ipc_memory::close` is available but calling it during teardown + // has been observed to hang on this stack; match the original L0 path + // which also skips remote unmap.) if (!local_allocation) { return; } @@ -98,8 +116,8 @@ size_t XPUSymmetricMemory::get_buffer_size() { return buffer_size_; } -size_t XPUSymmetricMemory::get_signal_pad_size() { - return signal_pad_size; +size_t XPUSymmetricMemory::get_offset() { + return 0; } bool XPUSymmetricMemory::has_multicast_support() { @@ -149,7 +167,8 @@ void check_channel(int channel, int world_size) { "must be greater than 0 (got ", channel, ")"); - const size_t num_channels = signal_pad_size / sizeof(uint32_t) * world_size; + const size_t num_channels = + get_signal_pad_size() / sizeof(uint32_t) * world_size; TORCH_CHECK( static_cast(channel) < num_channels, "The maximum supported channel for barrier(), put_signal() and wait_signal() is ", @@ -164,56 +183,51 @@ void XPUSymmetricMemory::barrier(int channel, size_t timeout_ms) { c10::Device local_device(c10::DeviceType::XPU, local_device_idx_); c10::DeviceGuard guard(local_device); -// auto stream = at::xpu::getCurrentXPUStream(); - -// barrier_impl_xpu( -// reinterpret_cast(signal_pads_dev_), -// channel, -// rank_, -// world_size_, -// timeout_ms, -// stream); - // Currently, we leverage 
oneCCL for barrier. Later, we may move to SYCL - // implementation. - auto group = c10d::resolve_process_group(group_name_); - if (group == nullptr) { - TORCH_WARN( - "Process group '", - group_name_, - "' not found, please init process group first before calling - SymmetricMemory"); - throw std::runtime_error("Process group not found"); - } - auto* xcclPg = dynamic_cast( - group->getBackend(c10::DeviceType::XPU).get()); - - c10::Device local_device(c10::DeviceType::XPU, local_device_idx_); - c10::DeviceGuard guard(local_device); - - static thread_local at::Tensor barrier_tensor; - if (!barrier_tensor.defined() || barrier_tensor.device() != local_device) { - barrier_tensor = at::zeros( - {1}, at::TensorOptions().device(local_device).dtype(at::kFloat)); - } else { - barrier_tensor.zero_(); - } - - c10d::AllreduceOptions arOpts; - arOpts.asyncOp = false; - auto work = - xcclPg->allreduce_impl(barrier_tensor, "xccl:symm_mem_barrier", - arOpts); - - if (work) { - bool success = work->wait(std::chrono::milliseconds(timeout_ms)); - TORCH_CHECK( - success, - "Barrier timeout after ", - timeout_ms, - " ms for group '", - group_name_, - "'"); - } + if (use_signal_barrier_enabled()) { + auto stream = at::xpu::getCurrentXPUStream(); + barrier_impl_xpu( + reinterpret_cast(signal_pads_dev_), + channel, + rank_, + world_size_, + timeout_ms, + stream); + return; + } + + auto group = c10d::resolve_process_group(group_name_); + TORCH_CHECK( + group != nullptr, + "Process group '", + group_name_, + "' not found, please init process group first before calling " + "SymmetricMemory"); + + auto backend = group->getBackend(c10::DeviceType::XPU); + + static thread_local at::Tensor barrier_tensor; + if (!barrier_tensor.defined() || barrier_tensor.device() != local_device) { + barrier_tensor = at::zeros( + {1}, at::TensorOptions().device(local_device).dtype(at::kFloat)); + } else { + barrier_tensor.zero_(); + } + + c10d::AllreduceOptions arOpts; + arOpts.asyncOp = false; + std::vector 
tensors = {barrier_tensor}; + auto work = backend->allreduce(tensors, arOpts); + + if (work) { + bool success = work->wait(std::chrono::milliseconds(timeout_ms)); + TORCH_CHECK( + success, + "Barrier timeout after ", + timeout_ms, + " ms for group '", + group_name_, + "'"); + } } void XPUSymmetricMemory::put_signal( @@ -287,7 +301,7 @@ void* XPUSymmetricMemoryAllocator::alloc( int device_idx, const std::optional& group_name) { size_t signal_pad_offset = at::round_up(size, 16UL); - size_t block_size = signal_pad_offset + signal_pad_size; + size_t block_size = signal_pad_offset + get_signal_pad_size(); sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); void* ptr = sycl::malloc_device(block_size, current_queue); @@ -383,10 +397,10 @@ c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( c10::Device local_device(c10::DeviceType::XPU, block->device_idx); c10::DeviceGuard guard(local_device); - auto group_info = get_group_info(group_name_); - auto store = group_info.store; - int rank = group_info.rank; - int world_size = group_info.world_size; + auto group = resolve_process_group(group_name_); + auto rank = group->getRank(); + auto world_size = group->getSize(); + auto store = group->getStore(); sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); auto local_req = RendezvousRequest{ @@ -399,66 +413,49 @@ c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( auto reqs = storeExchange.all_gather(store, rank, world_size, local_req); validate_rendezvous_requests(reqs, world_size); - std::vector pids(world_size); - for (int r = 0; r < world_size; ++r) { - pids[r] = reqs[r].pid; - } - - // Step 1: Get base address and offset + // Step 1: Get SYCL experimental IPC handle from the allocation base. + // `ptr` is always a base pointer here: alloc() stores ptr_to_block_[ptr] + // keyed by the malloc_device return value, and find_block() does an exact + // lookup, so by the time rendezvous() is called we already have the base. 
sycl::context ctx = current_queue.get_context(); - auto l0_ctx = sycl::get_native(ctx); - auto l0_device = sycl::get_native( - current_queue.get_device()); - - void* base_addr; - size_t base_size; - ZE_CHECK(zeMemGetAddressRange(l0_ctx, ptr, &base_addr, &base_size)); - size_t offset = (char*)ptr - (char*)base_addr; - - // Step 2: Get IPC mem handle from base address - ze_ipc_mem_handle_t local_ipc_handle; - ZE_CHECK(zeMemGetIpcHandle(l0_ctx, base_addr, &local_ipc_handle)); - - // Step 3: Extract fd from IPC handle (ze_ipc_mem_handle_t's first field is - // fd) - int local_fd = *reinterpret_cast(&local_ipc_handle); - - // Step 4: Exchange offsets via store - auto offsets = storeExchange.all_gather(store, rank, world_size, offset); - - // Step 5: Exchange fds via IpcChannel (uses Unix domain socket + SCM_RIGHTS) - IpcChannel ipc_channel; - auto fds = ipc_channel.all_gather_fds(rank, pids, local_fd); - - // Step 6: Reconstruct remote IPC handles and open them + sycl::device dev = current_queue.get_device(); + + namespace syclexp = sycl::ext::oneapi::experimental; + syclexp::ipc_memory::handle local_handle = + syclexp::ipc_memory::get(ptr, ctx); + syclexp::ipc_memory::handle_data_t local_handle_bytes = local_handle.data(); + std::vector local_payload( + reinterpret_cast(local_handle_bytes.data()), + reinterpret_cast(local_handle_bytes.data()) + + local_handle_bytes.size()); + + // Step 2: Exchange IPC-handle bytes via store (variable-length payload). + auto peer_handle_payloads = + storeExchange.all_gather_bytes(store, rank, world_size, local_payload); + + // Step 3: Open peer IPC handles via SYCL API. 
std::vector handles(world_size); std::vector buffers(world_size, nullptr); std::vector signal_pads(world_size, nullptr); for (int r = 0; r < world_size; ++r) { if (r == rank) { - handles[r] = base_addr; // Store base address as handle + handles[r] = ptr; buffers[r] = ptr; signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); continue; } - // Reconstruct remote IPC handle by setting the fd field - ze_ipc_mem_handle_t remote_ipc_handle = local_ipc_handle; // Copy structure - *reinterpret_cast(&remote_ipc_handle) = fds[r]; // Set remote fd - - // Open IPC handle to get remote base address - void* remote_base; - ZE_CHECK(zeMemOpenIpcHandle( - l0_ctx, - l0_device, - remote_ipc_handle, - ZE_IPC_MEMORY_FLAG_BIAS_CACHED, - &remote_base)); - - handles[r] = remote_base; // Store remote base address as handle - buffers[r] = (char*)remote_base + offsets[r]; - signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset); + const auto& payload = peer_handle_payloads[r]; + syclexp::ipc_memory::handle_data_t peer_bytes( + reinterpret_cast(payload.data()), + reinterpret_cast(payload.data()) + payload.size()); + void* remote_base = syclexp::ipc_memory::open(peer_bytes, ctx, dev); + + handles[r] = remote_base; + buffers[r] = remote_base; + signal_pads[r] = + (void*)((uintptr_t)remote_base + block->signal_pad_offset); } storeExchange.barrier(store, rank, world_size); @@ -483,8 +480,8 @@ c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( mc_addr, block->buffer_size, block->device_idx, - group_info.rank, - group_info.world_size); + rank, + world_size); symm_mem->set_group_name(group_name_); block->symm_mems[group_name_] = symm_mem; return symm_mem; diff --git a/src/xccl/XPUSymmetricMemory.hpp b/src/xccl/XPUSymmetricMemory.hpp index 4137a033d1..b108ea8d18 100644 --- a/src/xccl/XPUSymmetricMemory.hpp +++ b/src/xccl/XPUSymmetricMemory.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include @@ -9,16 +8,8 @@ #include #include -#define ZE_CHECK(call) 
\ - do { \ - ze_result_t result = (call); \ - TORCH_CHECK( \ - result == ZE_RESULT_SUCCESS, \ - "Level Zero error: ", \ - #call, \ - " returned ", \ - result); \ - } while (0) +#include + namespace c10d::symmetric_memory { // Resource wrapper that owns a (vaddr, allocation handle) pair. Upon @@ -60,7 +51,8 @@ class XPUSymmetricMemory : public SymmetricMemory { void** get_buffer_ptrs_dev() override; void** get_signal_pad_ptrs_dev() override; size_t get_buffer_size() override; - size_t get_signal_pad_size() override; + + size_t get_offset() override; bool has_multicast_support() override; void* get_multicast_ptr() override; diff --git a/src/xccl/XPUSymmetricMemoryUtils.cpp b/src/xccl/XPUSymmetricMemoryUtils.cpp index f518166361..03894826ee 100644 --- a/src/xccl/XPUSymmetricMemoryUtils.cpp +++ b/src/xccl/XPUSymmetricMemoryUtils.cpp @@ -9,15 +9,14 @@ namespace c10d::symmetric_memory { -// Query environment variable to get the backend used for XPU Symmetric Memory. std::string getSymmMemBackendXPU() { // TORCH_SYMMMEM environment variable can be used to indicate the preferred // backend. static auto val = c10::utils::get_env("TORCH_SYMMMEM"); if (val.has_value()) { TORCH_CHECK( - val.value() == "XPU" || val.value() == "ISHMEM", - "TORCH_SYMMMEM environment variable must be one of 'XPU', 'ISHMEM'.") + val.value() == "XPU", + "TORCH_SYMMMEM environment variable must be 'XPU'.") return val.value(); } return "XPU"; diff --git a/src/xccl/XPUSymmetricMemoryUtils.hpp b/src/xccl/XPUSymmetricMemoryUtils.hpp index 36204d2f7d..4594c08e82 100644 --- a/src/xccl/XPUSymmetricMemoryUtils.hpp +++ b/src/xccl/XPUSymmetricMemoryUtils.hpp @@ -3,9 +3,7 @@ #include #include #include -// A set of store-based exchange methods with a preset prefix typically type of -// the SymmetricMemory. Most used as static instances at respective -// SymmetricMemory implementation files. 
+ #include #include #include @@ -20,10 +18,6 @@ namespace symmetric_memory { std::string getSymmMemBackendXPU(); -// bool device_has_multicast_support(int device_idx); - -// bool allow_overlapping_devices(); - class IpcChannel { public: IpcChannel(); @@ -105,20 +99,40 @@ class StoreExchange { all_gather(store, rank, world_size, 0); } + // Variable-length byte all_gather (used to exchange SYCL IPC handles, whose + // serialized size is opaque and may differ from platform to platform). + std::vector> all_gather_bytes( + const c10::intrusive_ptr& store, + int rank, + int world_size, + const std::vector& payload) { + std::vector peer_keys; + peer_keys.reserve(world_size); + for (int r = 0; r < world_size; ++r) { + std::ostringstream oss; + oss << store_prefix_ << "/bytes/" << seq_id_ << "/" << r; + peer_keys.push_back(oss.str()); + } + ++seq_id_; + + store->set(peer_keys[rank], payload); + + std::vector> peer_vals(world_size); + peer_vals[rank] = payload; + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + continue; + } + store->wait({peer_keys[r]}); + peer_vals[r] = store->get(peer_keys[r]); + } + return peer_vals; + } + private: const std::string store_prefix_; size_t seq_id_ = 0; }; -// Returns a pointer of virtual address that is mapped to the physical memory -// held by the handle. -// todo: will follow such physical memory handle map with virtual address, -// when L0 provides physical handle exchange API and we have multicast support. 
-// void map_block( -// void** ptr, -// ze_physical_mem_handle_t handle, -// size_t size, -// int device_idx); - } // namespace symmetric_memory } // namespace c10d From 381b96b0024fa6ea7c1bdb7b185f6769e17a6bc1 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Fri, 15 May 2026 05:54:22 +0000 Subject: [PATCH 09/25] Add basic test case --- test/xpu/distributed/test_c10d_xccl.py | 167 +++++++++++++++++++++++-- 1 file changed, 158 insertions(+), 9 deletions(-) diff --git a/test/xpu/distributed/test_c10d_xccl.py b/test/xpu/distributed/test_c10d_xccl.py index e880e95363..26ad702113 100644 --- a/test/xpu/distributed/test_c10d_xccl.py +++ b/test/xpu/distributed/test_c10d_xccl.py @@ -29,15 +29,19 @@ import torch import torch._C._distributed_c10d import torch.distributed as c10d -import torch.distributed._functional_collectives as _functional_collectives - -if not c10d.is_available() or not c10d.is_xccl_available(): - print("c10d XCCL not available, skipping tests", file=sys.stderr) - sys.exit(0) - import torch.distributed as dist +import torch.distributed._functional_collectives as _functional_collectives +import torch.distributed._symmetric_memory as symm_mem +from torch._C._distributed_c10d import _SymmetricMemory +from torch.distributed._symmetric_memory import ( + _fused_all_gather_matmul_fallback, + _fused_matmul_reduce_scatter_fallback, +) import torch.testing._internal.common_utils as common -from torch.testing._internal.common_distributed import MultiProcessTestCase +from torch.testing._internal.common_distributed import ( + MultiProcContinuousTest, + MultiProcessTestCase, +) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, IS_SANDCASTLE, @@ -1406,14 +1410,159 @@ def test_coalescing_manager_collective(self, timing_enabled): self.assertTrue("duration_ms" not in t["entries"][0]) +# ------------------------------------------------------------------ +# XPU SymmetricMemory tests (SYCL IPC backend) +# 
------------------------------------------------------------------ + +# Use the SYCL signal-pad kernel for barrier / put_signal / wait_signal. +os.environ.setdefault("USE_SIGNAL_BARRIER", "1") +# XPU does not support multicast. +os.environ["TORCH_SYMM_MEM_DISABLE_MULTICAST"] = "1" + +device_type = "xpu" + + +@instantiate_parametrized_tests +class SymmetricMemoryTest(MultiProcContinuousTest): + """XPU SymmetricMemory tests (SYCL IPC backend).""" + + @property + def device(self) -> torch.device: + return torch.device("xpu", self.rank) + + def _init_process(self): + torch.xpu.set_device(self.device) + torch.manual_seed(42 + self.rank) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_rendezvous_basic(self) -> None: + """Smoke-test the SYCL IPC rendezvous path: allocate → rendezvous → + write → barrier → read peer buffer.""" + self._init_process() + + numel = 1024 + t = symm_mem.empty(numel, dtype=torch.float32, device=self.device) + hdl = symm_mem.rendezvous(t, group=dist.group.WORLD) + + self.assertEqual(hdl.rank, self.rank) + self.assertEqual(hdl.world_size, self.world_size) + self.assertEqual(len(hdl.buffer_ptrs), self.world_size) + + t.fill_(float(self.rank)) + hdl.barrier() + + for r in range(self.world_size): + buf = hdl.get_buffer(r, (numel,), torch.float32) + self.assertTrue( + buf.eq(float(r)).all().item(), + f"peer {r} buffer != {r} (seen from rank {self.rank})", + ) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_get_signal_pad(self) -> None: + """Verify that signal-pad views (dtype, numel, data_ptr) match the + handle's metadata, and that buffer writes do not corrupt the pad.""" + self._init_process() + + t = symm_mem.empty(1, device="xpu") + hdl = symm_mem.rendezvous(t, group=dist.group.WORLD) + peer = (self.rank + 1) % self.world_size + + # Local pad pointer must match what the handle advertises. 
+ local_pad = hdl.get_signal_pad(self.rank) + self.assertEqual(local_pad.data_ptr(), hdl.signal_pad_ptrs[hdl.rank]) + + # Default: uint32, signal_pad_size // 4 elements. + pad = hdl.get_signal_pad(peer) + self.assertEqual(pad.dtype, torch.uint32) + self.assertEqual(pad.numel(), hdl.signal_pad_size // 4) + + # Sizes only. + pad = hdl.get_signal_pad(peer, (8, 8)) + self.assertEqual(pad.dtype, torch.uint32) + self.assertEqual(pad.numel(), 64) + + # dtype only. + pad = hdl.get_signal_pad(peer, dtype=torch.uint64) + self.assertEqual(pad.dtype, torch.uint64) + self.assertEqual(pad.numel(), hdl.signal_pad_size // 8) + + # Sizes + dtype. + pad = hdl.get_signal_pad(peer, (8, 8), dtype=torch.uint64) + self.assertEqual(pad.dtype, torch.uint64) + self.assertEqual(pad.numel(), 64) + + # Writes to buffer must not corrupt the signal pad. + t2 = symm_mem.empty(1, device="xpu") + hdl2 = symm_mem.rendezvous(t2, group=dist.group.WORLD) + local_pad2 = hdl2.get_signal_pad(self.rank) + local_pad2.fill_(42) + t2.fill_(0) + self.assertTrue(local_pad2.eq(42).all().item()) + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_subgroup(self) -> None: + """Two disjoint subgroups rendezvous on the same tensor; each can + observe its peers correctly via the SYCL IPC mapping.""" + self._init_process() + + ranks = list(range(self.world_size)) + subgroup_0 = dist.new_group(ranks[: len(ranks) // 2]) + subgroup_1 = dist.new_group(ranks[len(ranks) // 2 :]) + + world = dist.group.WORLD + subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1 + + t = symm_mem.empty(64, device="xpu") + sm_world = symm_mem.rendezvous(t, group=world) + sm_sub = symm_mem.rendezvous(t, group=subgroup) + + self.assertEqual(sm_world.world_size, world.size()) + self.assertEqual(sm_world.rank, world.rank()) + self.assertEqual(sm_sub.world_size, world.size() // 2) + self.assertEqual(sm_sub.rank, world.rank() % subgroup.size()) + + t.fill_(world.rank()) + sm_world.barrier() + + peer = (world.rank() + 1) % 
world.size() + buf = sm_world.get_buffer(peer, (64,), torch.float32) + self.assertTrue(buf.eq(peer).all().item()) + + peer_sub = (subgroup.rank() + 1) % subgroup.size() + buf = sm_sub.get_buffer(peer_sub, (64,), torch.float32) + if world.rank() < world.size() // 2: + self.assertTrue(buf.eq(peer_sub).all().item()) + else: + self.assertTrue( + buf.eq(peer_sub + world.size() // 2).all().item() + ) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_put_wait_signal(self) -> None: + """Verify put_signal / wait_signal over the SYCL IPC peer mapping.""" + self._init_process() + + t = symm_mem.empty(1, device="xpu") + hdl = symm_mem.rendezvous(t, group=dist.group.WORLD) + + # Ring: each rank sends a signal to its right neighbor and waits for + # a signal from its left neighbor. + dst = (self.rank + 1) % self.world_size + src = (self.rank - 1) % self.world_size + hdl.put_signal(dst_rank=dst, channel=0, timeout_ms=10_000) + hdl.wait_signal(src_rank=src, channel=0, timeout_ms=10_000) + instantiate_parametrized_tests(XCCLTraceTest) instantiate_parametrized_tests(ProcessGroupXCCLTest) - class SetDeviceMethod(Enum): TORCH_XPU_SET = auto() # torch.xpu.set_device COLLECTIVE_ARGUMENT = auto() # broadcast_object_list(device=) - if __name__ == "__main__": run_tests() From 03cedc142138859faafd05ad2ed72a2dc210e084 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Fri, 15 May 2026 06:26:43 +0000 Subject: [PATCH 10/25] lint --- src/xccl/Signal.cpp | 24 ++++++++++++------------ src/xccl/XPUSymmetricMemory.cpp | 7 ++----- test/xpu/distributed/test_c10d_xccl.py | 12 ++++-------- 3 files changed, 18 insertions(+), 25 deletions(-) diff --git a/src/xccl/Signal.cpp b/src/xccl/Signal.cpp index 6e1c240938..5a8460e09a 100644 --- a/src/xccl/Signal.cpp +++ b/src/xccl/Signal.cpp @@ -16,15 +16,15 @@ struct barrierKernel { } auto put_success = try_put_signal_device( signal_pads[target_rank] + world_size * channel + rank, 10000000); - if (!put_success) { - SYCL_KERNEL_ASSERT(false); - } + if 
(!put_success) { + SYCL_KERNEL_ASSERT(false); + } auto wait_success = try_wait_signal_device( signal_pads[rank] + world_size * channel + target_rank, 10000000); - if (!wait_success) { - SYCL_KERNEL_ASSERT(false); - } + if (!wait_success) { + SYCL_KERNEL_ASSERT(false); + } } } @@ -79,9 +79,9 @@ struct putSignalKernel { if (thread_id == 0) { auto put_success = try_put_signal_device( signal_pads[dst_rank] + world_size * channel + rank, 10000000); - if (!put_success) { - SYCL_KERNEL_ASSERT(false); - } + if (!put_success) { + SYCL_KERNEL_ASSERT(false); + } } } @@ -141,9 +141,9 @@ struct waitSignalKernel { if (thread_id == 0) { auto wait_success = try_wait_signal_device( signal_pads[rank] + world_size * channel + src_rank, 10000000); - if (!wait_success) { - SYCL_KERNEL_ASSERT(false); - } + if (!wait_success) { + SYCL_KERNEL_ASSERT(false); + } sycl::atomic_fence(sycl::memory_order_seq_cst, sycl::memory_scope_system); } diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index 6656b2b269..3affa9d479 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -11,7 +11,6 @@ #include -#include #include #include @@ -421,8 +420,7 @@ c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( sycl::device dev = current_queue.get_device(); namespace syclexp = sycl::ext::oneapi::experimental; - syclexp::ipc_memory::handle local_handle = - syclexp::ipc_memory::get(ptr, ctx); + syclexp::ipc_memory::handle local_handle = syclexp::ipc_memory::get(ptr, ctx); syclexp::ipc_memory::handle_data_t local_handle_bytes = local_handle.data(); std::vector local_payload( reinterpret_cast(local_handle_bytes.data()), @@ -454,8 +452,7 @@ c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( handles[r] = remote_base; buffers[r] = remote_base; - signal_pads[r] = - (void*)((uintptr_t)remote_base + block->signal_pad_offset); + signal_pads[r] = (void*)((uintptr_t)remote_base + block->signal_pad_offset); } storeExchange.barrier(store, rank, 
world_size); diff --git a/test/xpu/distributed/test_c10d_xccl.py b/test/xpu/distributed/test_c10d_xccl.py index 26ad702113..c411ccfbdc 100644 --- a/test/xpu/distributed/test_c10d_xccl.py +++ b/test/xpu/distributed/test_c10d_xccl.py @@ -32,11 +32,6 @@ import torch.distributed as dist import torch.distributed._functional_collectives as _functional_collectives import torch.distributed._symmetric_memory as symm_mem -from torch._C._distributed_c10d import _SymmetricMemory -from torch.distributed._symmetric_memory import ( - _fused_all_gather_matmul_fallback, - _fused_matmul_reduce_scatter_fallback, -) import torch.testing._internal.common_utils as common from torch.testing._internal.common_distributed import ( MultiProcContinuousTest, @@ -1537,9 +1532,7 @@ def test_subgroup(self) -> None: if world.rank() < world.size() // 2: self.assertTrue(buf.eq(peer_sub).all().item()) else: - self.assertTrue( - buf.eq(peer_sub + world.size() // 2).all().item() - ) + self.assertTrue(buf.eq(peer_sub + world.size() // 2).all().item()) @requires_xccl() @skip_if_lt_x_gpu(2) @@ -1557,12 +1550,15 @@ def test_put_wait_signal(self) -> None: hdl.put_signal(dst_rank=dst, channel=0, timeout_ms=10_000) hdl.wait_signal(src_rank=src, channel=0, timeout_ms=10_000) + instantiate_parametrized_tests(XCCLTraceTest) instantiate_parametrized_tests(ProcessGroupXCCLTest) + class SetDeviceMethod(Enum): TORCH_XPU_SET = auto() # torch.xpu.set_device COLLECTIVE_ARGUMENT = auto() # broadcast_object_list(device=) + if __name__ == "__main__": run_tests() From 48766d6973eb02f38dcc57492f6885168089e478 Mon Sep 17 00:00:00 2001 From: hanchao Date: Fri, 15 May 2026 14:54:37 +0800 Subject: [PATCH 11/25] remove ishmem related --- CMakeLists.txt | 10 --- cmake/ISHMEM.cmake | 24 ------ cmake/Modules/FindISHMEM.cmake | 65 -------------- src/BuildOnLinux.cmake | 6 -- .../distributed/test_symmetric_memory_xccl.py | 85 ------------------- 5 files changed, 190 deletions(-) delete mode 100644 cmake/ISHMEM.cmake delete mode 
100644 cmake/Modules/FindISHMEM.cmake delete mode 100644 test/xpu/distributed/test_symmetric_memory_xccl.py diff --git a/CMakeLists.txt b/CMakeLists.txt index ecd23ef8cf..8f6fa8d6d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,16 +64,6 @@ if(USE_XCCL) caffe2_update_option(USE_C10D_XCCL OFF) update_caffe2_macros_file() endif() - if(USE_ISHMEM) - include(${TORCH_XPU_OPS_ROOT}/cmake/ISHMEM.cmake) - if(NOT PYTORCH_FOUND_ISHMEM) - message(WARNING "ISHMEM not found, disabling ISHMEM support") - caffe2_update_option(USE_ISHMEM OFF) - update_caffe2_macros_file() - else() - message(STATUS "ISHMEM support enabled") - endif() - endif() endif() set(USE_SYCLTLA ON) diff --git a/cmake/ISHMEM.cmake b/cmake/ISHMEM.cmake deleted file mode 100644 index 434935ee2c..0000000000 --- a/cmake/ISHMEM.cmake +++ /dev/null @@ -1,24 +0,0 @@ -if(NOT __ISHMEM_INCLUDED) - set(__ISHMEM_INCLUDED TRUE) - - # ISHMEM_ROOT, ISHMEM_LIBRARY_DIR, ISHMEM_INCLUDE_DIR are handled by FindISHMEM.cmake. - find_package(ISHMEM REQUIRED) - if(NOT ISHMEM_FOUND) - set(PYTORCH_FOUND_ISHMEM FALSE) - message(WARNING "${ISHMEM_NOT_FOUND_MESSAGE}") - return() - endif() - - set(PYTORCH_FOUND_ISHMEM TRUE) - add_library(torch::ishmem INTERFACE IMPORTED) - set_property( - TARGET torch::ishmem PROPERTY INTERFACE_INCLUDE_DIRECTORIES - ${ISHMEM_INCLUDE_DIR}) - set_property( - TARGET torch::ishmem PROPERTY INTERFACE_LINK_LIBRARIES - ${ISHMEM_LIBRARY}) - - message(STATUS "Found Intel SHMEM: ${ISHMEM_ROOT}") - message(STATUS " ISHMEM include dir: ${ISHMEM_INCLUDE_DIR}") - message(STATUS " ISHMEM library: ${ISHMEM_LIBRARY}") -endif() diff --git a/cmake/Modules/FindISHMEM.cmake b/cmake/Modules/FindISHMEM.cmake deleted file mode 100644 index 96c7656b85..0000000000 --- a/cmake/Modules/FindISHMEM.cmake +++ /dev/null @@ -1,65 +0,0 @@ -# This will define the following variables: -# ISHMEM_FOUND : True if the system has the ISHMEM library. -# ISHMEM_INCLUDE_DIR : Include directories needed to use ISHMEM. 
-# ISHMEM_LIBRARY_DIR : The path to the ISHMEM library. -# ISHMEM_LIBRARY : ISHMEM library fullname. - -include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake) - -if(NOT CMAKE_SYSTEM_NAME MATCHES "Linux") - set(ISHMEM_FOUND False) - set(ISHMEM_NOT_FOUND_MESSAGE "Intel SHMEM library is only supported on Linux!") - return() -endif() - -set(ISHMEM_ROOT $ENV{ISHMEM_ROOT}) - -if(NOT ISHMEM_ROOT) - set(ISHMEM_FOUND False) - set(ISHMEM_NOT_FOUND_MESSAGE "ISHMEM_ROOT environment variable not set. Please set it to your ISHMEM installation directory.") - return() -endif() - -# Find include path from binary. -find_path( - ISHMEM_INCLUDE_DIR - NAMES ishmem.h - HINTS ${ISHMEM_ROOT}/include - NO_DEFAULT_PATH -) - -# Find library directory from binary. -find_path( - ISHMEM_LIBRARY_DIR - NAMES libishmem.a - HINTS ${ISHMEM_ROOT}/lib - NO_DEFAULT_PATH -) - -# Find ISHMEM library fullname (static library). -find_library( - ISHMEM_LIBRARY - NAMES ishmem - HINTS ${ISHMEM_LIBRARY_DIR} - NO_DEFAULT_PATH -) - -if((NOT ISHMEM_INCLUDE_DIR) OR (NOT ISHMEM_LIBRARY_DIR) OR (NOT ISHMEM_LIBRARY)) - set(ISHMEM_FOUND False) - set(ISHMEM_NOT_FOUND_MESSAGE "Intel SHMEM library not found! 
Please set ISHMEM_ROOT environment variable.") - return() -endif() - -SET(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH} - "${ISHMEM_INCLUDE_DIR}") -SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} - "${ISHMEM_LIBRARY_DIR}") - -find_package_handle_standard_args( - ISHMEM - FOUND_VAR ISHMEM_FOUND - REQUIRED_VARS ISHMEM_INCLUDE_DIR ISHMEM_LIBRARY_DIR ISHMEM_LIBRARY - REASON_FAILURE_MESSAGE "${ISHMEM_NOT_FOUND_MESSAGE}" -) - -mark_as_advanced(ISHMEM_INCLUDE_DIR ISHMEM_LIBRARY_DIR ISHMEM_LIBRARY) diff --git a/src/BuildOnLinux.cmake b/src/BuildOnLinux.cmake index ae0ab02248..af0e90c8e6 100644 --- a/src/BuildOnLinux.cmake +++ b/src/BuildOnLinux.cmake @@ -24,9 +24,6 @@ macro(setup_common_libraries) target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL) target_link_libraries(torch_xpu_ops PUBLIC torch::xccl) target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only) - if(USE_ISHMEM AND PYTORCH_FOUND_ISHMEM) - target_link_libraries(torch_xpu_ops PUBLIC torch::ishmem) - endif() endif() if(USE_SYCLTLA) @@ -60,9 +57,6 @@ else() target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL) target_link_libraries(torch_xpu_ops PUBLIC torch::xccl) target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only) - if(USE_ISHMEM AND PYTORCH_FOUND_ISHMEM) - target_link_libraries(torch_xpu_ops PUBLIC torch::ishmem) - endif() endif() if(USE_SYCLTLA) diff --git a/test/xpu/distributed/test_symmetric_memory_xccl.py b/test/xpu/distributed/test_symmetric_memory_xccl.py deleted file mode 100644 index 37f5d3e6da..0000000000 --- a/test/xpu/distributed/test_symmetric_memory_xccl.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -import torch.distributed as dist -from test_c10d_xccl import init_multigpu_helper, requires_xccl -from torch.distributed._symmetric_memory import ( - _fused_all_gather_matmul_fallback, - _fused_matmul_reduce_scatter_fallback, -) - -from torch.testing._internal.common_distributed import MultiProcContinuousTest -from torch.testing._internal.common_utils import ( - 
instantiate_parametrized_tests, - parametrize, - run_tests -) - -@instantiate_parametrized_tests -class AsyncTPTest(MultiProcContinuousTest): - @property - def device(self) -> torch.device: - return torch.device("xpu", self.rank) - - def _init_process(self): - torch.xpu.set_device(self.device) - torch.manual_seed(42 + self.rank) - torch.use_deterministic_algorithms(True) - torch.set_deterministic_debug_mode("warn") - torch.utils.deterministic.fill_uninitialized_memory = True - - @requires_xccl() - @parametrize("gather_dim", [0, 1]) - def test_fused_all_gather_matmul(self, gather_dim: int) -> None: - self._init_process() - BATCH = 8 - M = 64 - N = 16 - K = 32 - group = dist.group.WORLD - rank = self.rank - - torch.manual_seed(42 + rank) - A_shard = torch.rand(BATCH, M // self.world_size, K, device="xpu") - Bs = [torch.rand(K, N, device="xpu") for _ in range(3)] - - ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback( - A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name - ) - ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_matmul( - A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name - ) - - self.assertEqual(ag_output_0, ag_output_1) - self.assertEqual(ag_output_0.stride(), ag_output_1.stride()) - for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1): - self.assertEqual(mm_output_0, mm_output_1) - self.assertEqual(mm_output_0.stride(), mm_output_1.stride()) - - @requires_xccl() - @parametrize("scatter_dim", [0, 1]) - def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None: - self._init_process() - - BATCH = 8 - M = 64 - N = 16 - K = 32 - group = dist.group.WORLD - rank = self.rank - - torch.manual_seed(42 + rank) - A = torch.rand(BATCH, M, K, device="xpu") - B = torch.rand(K, N, device="xpu") - - output_0 = _fused_matmul_reduce_scatter_fallback( - A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name - ) - output_1 = torch.ops.symm_mem.fused_matmul_reduce_scatter( - A, B, "avg", 
scatter_dim=scatter_dim, group_name=group.group_name - ) - - self.assertEqual(output_0, output_1) - self.assertEqual(output_0.stride(), output_1.stride()) - - -if __name__ == "__main__": - run_tests() From 8805b9135a818be34048bf7fcc7a14227c53f099 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Mon, 18 May 2026 15:16:33 +0800 Subject: [PATCH 12/25] address some comments --- src/xccl/XPUSymmetricMemory.cpp | 19 ++++++++++--------- src/xccl/XPUSymmetricMemory.hpp | 1 - src/xccl/XPUSymmetricMemoryUtils.cpp | 21 +++++++++++++++++++++ src/xccl/XPUSymmetricMemoryUtils.hpp | 2 -- 4 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index 3affa9d479..2891cbf713 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -89,10 +89,10 @@ XPUSymmetricMemory::XPUSymmetricMemory( c10::Device local_device(c10::DeviceType::XPU, local_device_idx); c10::DeviceGuard guard(local_device); - at::xpu::getCurrentXPUStream().queue().memcpy( - buffers_dev_, buffers_.data(), arr_size); - at::xpu::getCurrentXPUStream().queue().memcpy( - signal_pads_dev_, signal_pads_.data(), arr_size); + auto& queue = at::xpu::getCurrentXPUStream().queue(); + queue.memcpy(buffers_dev_, buffers_.data(), arr_size); + queue.memcpy(signal_pads_dev_, signal_pads_.data(), arr_size); + queue.wait(); } std::vector XPUSymmetricMemory::get_buffer_ptrs() { @@ -148,9 +148,10 @@ at::Tensor XPUSymmetricMemory::get_buffer( " bytes)"); auto data_ptr = reinterpret_cast(buffers_[rank]) + storage_offset * element_size; - // check the device of this device buffer - auto ptr_to_device_id = c10::xpu::get_device_idx_from_pointer(data_ptr); - auto device = c10::Device(c10::DeviceType::XPU, ptr_to_device_id); + // Peer buffers are mapped into the local device's virtual address space via + // SYCL IPC, so the returned tensor always lives on the local device. This + // matches the contract of `CUDASymmetricMemory::get_buffer`. 
+ auto device = c10::Device(c10::DeviceType::XPU, local_device_idx_); auto options = at::TensorOptions().dtype(dtype).device(device); return at::for_blob(data_ptr, sizes) @@ -163,11 +164,11 @@ void check_channel(int channel, int world_size) { TORCH_CHECK( channel >= 0, "channel for barrier(), put_signal() and wait_signal() ", - "must be greater than 0 (got ", + "must be greater than or equal to 0 (got ", channel, ")"); const size_t num_channels = - get_signal_pad_size() / sizeof(uint32_t) * world_size; + get_signal_pad_size() / sizeof(uint32_t) / world_size; TORCH_CHECK( static_cast(channel) < num_channels, "The maximum supported channel for barrier(), put_signal() and wait_signal() is ", diff --git a/src/xccl/XPUSymmetricMemory.hpp b/src/xccl/XPUSymmetricMemory.hpp index b108ea8d18..2bf33bef95 100644 --- a/src/xccl/XPUSymmetricMemory.hpp +++ b/src/xccl/XPUSymmetricMemory.hpp @@ -121,7 +121,6 @@ class XPUSymmetricMemoryAllocator : public SymmetricMemoryAllocator { void* ptr, const std::optional& group_name) override; bool has_multicast_support(int device_idx) override; - // void exchange_peer_ipc_mem(sycl::queue& queue, void* ptr); c10::DeviceType supported_device_type() override; std::string name() override; diff --git a/src/xccl/XPUSymmetricMemoryUtils.cpp b/src/xccl/XPUSymmetricMemoryUtils.cpp index 03894826ee..daec394964 100644 --- a/src/xccl/XPUSymmetricMemoryUtils.cpp +++ b/src/xccl/XPUSymmetricMemoryUtils.cpp @@ -31,6 +31,13 @@ IpcChannel::IpcChannel() socket_ != -1, "Failed to create socket: ", c10::utils::str_error(errno)); struct sockaddr_un addr = {.sun_family = AF_UNIX}; + TORCH_CHECK( + socket_name_.size() < sizeof(addr.sun_path), + "IpcChannel: socket path '", + socket_name_, + "' is too long for sockaddr_un::sun_path (max ", + sizeof(addr.sun_path) - 1, + " bytes). 
Consider setting a shorter TMPDIR."); std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path); TORCH_CHECK( @@ -51,6 +58,13 @@ void IpcChannel::send_fd(int dst_pid, int fd) { // Define destination socket address struct sockaddr_un addr = {.sun_family = AF_UNIX}; auto socket_name = get_socket_name(dst_pid); + TORCH_CHECK( + socket_name.size() < sizeof(addr.sun_path), + "IpcChannel::send_fd: socket path '", + socket_name, + "' is too long for sockaddr_un::sun_path (max ", + sizeof(addr.sun_path) - 1, + " bytes). Consider setting a shorter TMPDIR."); std::copy(socket_name.begin(), socket_name.end(), addr.sun_path); // Prepare data to send @@ -141,6 +155,13 @@ int IpcChannel::recv_fd() { // Define socket address to receive on: family AF_UNIX means unix domain // socket struct sockaddr_un addr = {.sun_family = AF_UNIX}; + TORCH_CHECK( + socket_name_.size() < sizeof(addr.sun_path), + "IpcChannel::recv_fd: socket path '", + socket_name_, + "' is too long for sockaddr_un::sun_path (max ", + sizeof(addr.sun_path) - 1, + " bytes). 
Consider setting a shorter TMPDIR."); std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path); // Prepare message header diff --git a/src/xccl/XPUSymmetricMemoryUtils.hpp b/src/xccl/XPUSymmetricMemoryUtils.hpp index 4594c08e82..ffd75ce79a 100644 --- a/src/xccl/XPUSymmetricMemoryUtils.hpp +++ b/src/xccl/XPUSymmetricMemoryUtils.hpp @@ -11,8 +11,6 @@ #include -#include - namespace c10d { namespace symmetric_memory { From cff6db164293f43e6415d57d45ad6a74cc1c2f90 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Wed, 20 May 2026 15:58:22 +0800 Subject: [PATCH 13/25] Address PR review feedback on XPUSymmetricMemory And remove XPUSymmetricMemory::get_buffer --- src/xccl/XPUSymmetricMemory.cpp | 33 ---------------------------- src/xccl/XPUSymmetricMemory.hpp | 6 ----- src/xccl/XPUSymmetricMemoryUtils.cpp | 17 +++++++++++++- 3 files changed, 16 insertions(+), 40 deletions(-) diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index 2891cbf713..8348b3cfc5 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -127,39 +127,6 @@ void* XPUSymmetricMemory::get_multicast_ptr() { return nullptr; } -at::Tensor XPUSymmetricMemory::get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) { - const size_t numel = std::accumulate( - sizes.begin(), - sizes.end(), - static_cast(1), - std::multiplies()); - const auto element_size = c10::elementSize(dtype); - const auto req_size = (numel + storage_offset) * element_size; - TORCH_CHECK( - req_size <= buffer_size_, - "XPUSymmetricMemory::get_buffer: the requested size (", - req_size, - " bytes) exceeds the allocated size (", - buffer_size_, - " bytes)"); - auto data_ptr = reinterpret_cast(buffers_[rank]) + - storage_offset * element_size; - // Peer buffers are mapped into the local device's virtual address space via - // SYCL IPC, so the returned tensor always lives on the local device. 
This - // matches the contract of `CUDASymmetricMemory::get_buffer`. - auto device = c10::Device(c10::DeviceType::XPU, local_device_idx_); - auto options = at::TensorOptions().dtype(dtype).device(device); - - return at::for_blob(data_ptr, sizes) - .options(options) - .target_device(device) - .make_tensor(); -} - void check_channel(int channel, int world_size) { TORCH_CHECK( channel >= 0, diff --git a/src/xccl/XPUSymmetricMemory.hpp b/src/xccl/XPUSymmetricMemory.hpp index 2bf33bef95..f47c4c8a98 100644 --- a/src/xccl/XPUSymmetricMemory.hpp +++ b/src/xccl/XPUSymmetricMemory.hpp @@ -57,12 +57,6 @@ class XPUSymmetricMemory : public SymmetricMemory { bool has_multicast_support() override; void* get_multicast_ptr() override; - at::Tensor get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset); - void barrier(int channel, size_t timeout_ms) override; void put_signal(int dst_rank, int channel, size_t timeout_ms) override; void wait_signal(int src_rank, int channel, size_t timeout_ms) override; diff --git a/src/xccl/XPUSymmetricMemoryUtils.cpp b/src/xccl/XPUSymmetricMemoryUtils.cpp index daec394964..9042ef4ff0 100644 --- a/src/xccl/XPUSymmetricMemoryUtils.cpp +++ b/src/xccl/XPUSymmetricMemoryUtils.cpp @@ -238,7 +238,22 @@ std::string IpcChannel::get_socket_name(int pid) { } std::ostringstream oss; oss << tmp_dir << "/symm_mem-" << pid; - return oss.str(); + std::string socket_name = oss.str(); + + // sockaddr_un::sun_path is a fixed-size buffer (108 bytes on Linux). + // Silent truncation here would produce a corrupted path and a confusing + // bind()/sendmsg() failure later, so reject up front. + constexpr size_t kMaxSunPath = sizeof(sockaddr_un{}.sun_path); + TORCH_CHECK( + socket_name.size() < kMaxSunPath, + "IpcChannel: socket path '", + socket_name, + "' (", + socket_name.size(), + " bytes) is too long for sockaddr_un::sun_path (max ", + kMaxSunPath - 1, + " bytes). 
Please set TMPDIR/TMP/TEMP/TEMPDIR to a shorter directory."); + return socket_name; } } // namespace c10d::symmetric_memory From 307db6df587f4c72895092c8830470319870e29f Mon Sep 17 00:00:00 2001 From: Cherry Zhang Date: Fri, 22 May 2026 11:16:58 +0800 Subject: [PATCH 14/25] Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/xccl/XPUSymmetricMemory.cpp | 9 ++++++++- src/xccl/XPUSymmetricMemoryUtils.cpp | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index 8348b3cfc5..3cb5408f9b 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -58,7 +58,8 @@ AllocationRef::~AllocationRef() { c10::DeviceGuard guard(local_device); c10::xpu::syncStreamsOnDevice(); auto stream = at::xpu::getCurrentXPUStream(); - sycl::free(ptr, stream); + auto& queue = stream.queue(); + sycl::free(ptr, queue); } XPUSymmetricMemory::XPUSymmetricMemory( @@ -365,6 +366,12 @@ c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( c10::DeviceGuard guard(local_device); auto group = resolve_process_group(group_name_); + TORCH_CHECK( + group != nullptr, + "XPUSymmetricMemory::rendezvous: Could not resolve process group '", + group_name_, + "'. 
This can happen if rendezvous() is called before the process " + "group is initialized or if the group name is incorrect."); auto rank = group->getRank(); auto world_size = group->getSize(); auto store = group->getStore(); diff --git a/src/xccl/XPUSymmetricMemoryUtils.cpp b/src/xccl/XPUSymmetricMemoryUtils.cpp index 9042ef4ff0..37b78ffe86 100644 --- a/src/xccl/XPUSymmetricMemoryUtils.cpp +++ b/src/xccl/XPUSymmetricMemoryUtils.cpp @@ -16,7 +16,7 @@ std::string getSymmMemBackendXPU() { if (val.has_value()) { TORCH_CHECK( val.value() == "XPU", - "TORCH_SYMMMEM environment variable must be 'XPU'.") + "TORCH_SYMMMEM environment variable must be 'XPU'."); return val.value(); } return "XPU"; From 4d428fb404a7c05f5998005b10f5776d7194eae1 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Fri, 22 May 2026 13:21:14 +0800 Subject: [PATCH 15/25] default use signal barrier --- src/xccl/XPUSymmetricMemory.cpp | 3 ++- test/xpu/distributed/test_c10d_xccl.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index 3cb5408f9b..e96aa1b31a 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -24,7 +24,8 @@ namespace { bool use_signal_barrier_enabled() { static const bool cached_value = []() { const char* env = std::getenv("USE_SIGNAL_BARRIER"); - return env != nullptr && std::string(env) == "1"; + // Default to enabled; only opt out when explicitly set to "0". 
+ return env == nullptr || std::string(env) != "0"; }(); return cached_value; } diff --git a/test/xpu/distributed/test_c10d_xccl.py b/test/xpu/distributed/test_c10d_xccl.py index c411ccfbdc..afd3cd7acb 100644 --- a/test/xpu/distributed/test_c10d_xccl.py +++ b/test/xpu/distributed/test_c10d_xccl.py @@ -1409,8 +1409,6 @@ def test_coalescing_manager_collective(self, timing_enabled): # XPU SymmetricMemory tests (SYCL IPC backend) # ------------------------------------------------------------------ -# Use the SYCL signal-pad kernel for barrier / put_signal / wait_signal. -os.environ.setdefault("USE_SIGNAL_BARRIER", "1") # XPU does not support multicast. os.environ["TORCH_SYMM_MEM_DISABLE_MULTICAST"] = "1" From 9a488d9b17e892b584d392d2e01bcf927915f421 Mon Sep 17 00:00:00 2001 From: lzhang2 Date: Mon, 25 May 2026 11:07:55 +0800 Subject: [PATCH 16/25] align timeout check in api definition --- src/xccl/Signal.cpp | 4 ++-- src/xccl/Signal.hpp | 16 +++++----------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/xccl/Signal.cpp b/src/xccl/Signal.cpp index 5a8460e09a..61086ad0a2 100644 --- a/src/xccl/Signal.cpp +++ b/src/xccl/Signal.cpp @@ -15,13 +15,13 @@ struct barrierKernel { return; } auto put_success = try_put_signal_device( - signal_pads[target_rank] + world_size * channel + rank, 10000000); + signal_pads[target_rank] + world_size * channel + rank, timeout_ms); if (!put_success) { SYCL_KERNEL_ASSERT(false); } auto wait_success = try_wait_signal_device( - signal_pads[rank] + world_size * channel + target_rank, 10000000); + signal_pads[rank] + world_size * channel + target_rank, timeout_ms); if (!wait_success) { SYCL_KERNEL_ASSERT(false); } diff --git a/src/xccl/Signal.hpp b/src/xccl/Signal.hpp index 91b43e7f6b..b9089e31cf 100644 --- a/src/xccl/Signal.hpp +++ b/src/xccl/Signal.hpp @@ -33,15 +33,12 @@ inline uint32_t load_acquire(uint32_t* addr) { // Put signal: wait until addr == 0, then set to 1 (release semantics) // 
============================================================================= -// Device-compatible version using iteration count template -bool try_put_signal_device(uint32_t* addr, size_t max_iterations = 1000) { - size_t iterations = 0; +bool try_put_signal_device(uint32_t* addr, size_t timeout_ms) { // Wait until the slot is free (value == 0) while (load_acquire(addr) != 0) { - if (max_iterations != 0 && iterations++ > max_iterations) { - return false; - } + // Spin wait (no timeout check as IGC issue) + continue; } // Set signal to 1 with release semantics store_release(addr, 1); @@ -51,14 +48,11 @@ bool try_put_signal_device(uint32_t* addr, size_t max_iterations = 1000) { // ============================================================================= // Wait signal: wait until addr == 1, then set to 0 (acquire semantics) // ============================================================================= - -// Device-compatible version using iteration count template -bool try_wait_signal_device(uint32_t* addr, size_t max_iterations = 1000) { - (void)max_iterations; +bool try_wait_signal_device(uint32_t* addr, size_t timeout_ms) { // Wait until signal is set (value == 1) while (load_acquire(addr) != 1) { - // Spin wait (no timeout check to avoid early exit) + // Spin wait (no timeout check as IGC issue) continue; } // Clear signal to 0 with release semantics From ec2bfbaa82e144c2202e56c72c3737e174cdcb54 Mon Sep 17 00:00:00 2001 From: Cherry Zhang Date: Mon, 25 May 2026 12:58:49 +0800 Subject: [PATCH 17/25] Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/xccl/XPUSymmetricMemory.cpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index e96aa1b31a..3231cabef6 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -11,16 +11,23 @@ #include +#include 
#include #include namespace c10d { namespace symmetric_memory { -static StoreExchange storeExchange = StoreExchange("XPUSymmetricMemory"); - namespace { +std::atomic store_exchange_nonce{0}; + +thread_local StoreExchange storeExchange = []() { + const auto nonce = + store_exchange_nonce.fetch_add(1, std::memory_order_relaxed); + return StoreExchange("XPUSymmetricMemory_" + std::to_string(nonce)); +}(); + bool use_signal_barrier_enabled() { static const bool cached_value = []() { const char* env = std::getenv("USE_SIGNAL_BARRIER"); @@ -272,6 +279,8 @@ void* XPUSymmetricMemoryAllocator::alloc( size_t signal_pad_offset = at::round_up(size, 16UL); size_t block_size = signal_pad_offset + get_signal_pad_size(); + c10::DeviceGuard device_guard( + c10::Device(c10::DeviceType::XPU, device_idx)); sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); void* ptr = sycl::malloc_device(block_size, current_queue); current_queue.memset(ptr, 0, block_size); From c1b66ccfc6768e0ba9af8010c812f41310dd0ad0 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Mon, 25 May 2026 13:12:20 +0800 Subject: [PATCH 18/25] remove unused IpcChannel --- src/xccl/XPUSymmetricMemoryUtils.cpp | 239 --------------------------- src/xccl/XPUSymmetricMemoryUtils.hpp | 31 ---- 2 files changed, 270 deletions(-) diff --git a/src/xccl/XPUSymmetricMemoryUtils.cpp b/src/xccl/XPUSymmetricMemoryUtils.cpp index 37b78ffe86..ec165a08c7 100644 --- a/src/xccl/XPUSymmetricMemoryUtils.cpp +++ b/src/xccl/XPUSymmetricMemoryUtils.cpp @@ -1,10 +1,5 @@ #include -#include -#include -#include -#include - #include namespace c10d::symmetric_memory { @@ -22,238 +17,4 @@ std::string getSymmMemBackendXPU() { return "XPU"; } -IpcChannel::IpcChannel() - : socket_name_(get_socket_name(getpid())), - socket_(socket(AF_UNIX, SOCK_DGRAM, 0)) { - // On success, a file descriptor for the new socket is returned. - // On error, -1 is returned, and errno is set to indicate the error. 
- TORCH_CHECK( - socket_ != -1, "Failed to create socket: ", c10::utils::str_error(errno)); - - struct sockaddr_un addr = {.sun_family = AF_UNIX}; - TORCH_CHECK( - socket_name_.size() < sizeof(addr.sun_path), - "IpcChannel: socket path '", - socket_name_, - "' is too long for sockaddr_un::sun_path (max ", - sizeof(addr.sun_path) - 1, - " bytes). Consider setting a shorter TMPDIR."); - std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path); - - TORCH_CHECK( - bind(socket_, (struct sockaddr*)&addr, SUN_LEN(&addr)) == 0, - "Failed to bind socket: ", - c10::utils::str_error(errno)); -} - -IpcChannel::~IpcChannel() { - close(socket_); - unlink(socket_name_.c_str()); -} - -void IpcChannel::send_fd(int dst_pid, int fd) { - // Because file descriptors are process-local kernel objects, and we can’t - // pass them via normal socket payloads (like write() or send()). Unix domain - // sockets provide a mechanism to pass actual FDs via sendmsg()/recvmsg(). - // Define destination socket address - struct sockaddr_un addr = {.sun_family = AF_UNIX}; - auto socket_name = get_socket_name(dst_pid); - TORCH_CHECK( - socket_name.size() < sizeof(addr.sun_path), - "IpcChannel::send_fd: socket path '", - socket_name, - "' is too long for sockaddr_un::sun_path (max ", - sizeof(addr.sun_path) - 1, - " bytes). 
Consider setting a shorter TMPDIR."); - std::copy(socket_name.begin(), socket_name.end(), addr.sun_path); - - // Prepare data to send - // Data being sent is "fd", the value of fd will be sent as auxiliary data - // (control message) - struct iovec io = {.iov_base = (void*)("fd"), .iov_len = 2}; - - // Prepare control message data buffer and zero it out - // NOLINTNEXTLINE(*array*) - char cbuf[CMSG_SPACE(sizeof(int))]; - memset(cbuf, 0, sizeof(cbuf)); - - // Create message header - struct msghdr msg { - // destination socket address and size of it - // message content in msg_iov and number of such structs (1 in our case) - // auxiliary data with the value of fd and size of it - .msg_name = (void*)&addr, .msg_namelen = sizeof(struct sockaddr_un), - .msg_iov = &io, .msg_iovlen = 1, .msg_control = cbuf, - .msg_controllen = sizeof(cbuf) - }; - - // This points to the first control message header - // With SCM_RIGHTS we let the kernel know that we are passing file - // descriptors. - auto cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_len = CMSG_LEN(sizeof(int)); - // Specify socket level message - cmsg->cmsg_level = SOL_SOCKET; - // SCM_RIGHTS is the type used to pass file descriptors - cmsg->cmsg_type = SCM_RIGHTS; - - if (fd != -1) { - std::copy( - reinterpret_cast(&fd), - reinterpret_cast(&fd) + sizeof(fd), - reinterpret_cast(CMSG_DATA(cmsg))); - } else { - msg.msg_controllen = 0; - } - - // Retry sending with exponential backoff (wait for destination socket to be - // ready) - const int max_retries = 100; - int retry = 0; - ssize_t result = -1; - - while (retry < max_retries) { - result = sendmsg(socket_, &msg, 0); - if (result > 0) { - return; // Success - } - - // Check if error is because destination doesn't exist yet - if (errno == ENOENT || errno == ECONNREFUSED) { - // Exponential backoff: 1ms, 2ms, 4ms, ..., up to 100ms - int sleep_ms = std::min(1 << retry, 100); - usleep(sleep_ms * 1000); - retry++; - continue; - } - - // Other errors should fail immediately - 
break; - } - - // Finally check if we succeeded or report error - TORCH_CHECK( - result > 0, - "Failed to send fd after ", - retry, - " retries: ", - c10::utils::str_error(errno)); -} - -int IpcChannel::recv_fd() { - // Prepare buffer for regular message "fd" - // NOLINTNEXTLINE(*array*) - char buf[2]; - memset(&buf, 0, sizeof(buf)); - struct iovec io = {.iov_base = (void*)buf, .iov_len = sizeof(buf)}; - - // Prepare buffer for control message and zero it out - // NOLINTNEXTLINE(*array*) - char cbuf[CMSG_SPACE(sizeof(int))]; - memset(cbuf, 0, sizeof(cbuf)); - - // Define socket address to receive on: family AF_UNIX means unix domain - // socket - struct sockaddr_un addr = {.sun_family = AF_UNIX}; - TORCH_CHECK( - socket_name_.size() < sizeof(addr.sun_path), - "IpcChannel::recv_fd: socket path '", - socket_name_, - "' is too long for sockaddr_un::sun_path (max ", - sizeof(addr.sun_path) - 1, - " bytes). Consider setting a shorter TMPDIR."); - std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path); - - // Prepare message header - struct msghdr msg = { - .msg_name = (void*)&addr, - .msg_namelen = sizeof(struct sockaddr_un), - .msg_iov = &io, - .msg_iovlen = 1, - .msg_control = cbuf, - .msg_controllen = sizeof(cbuf)}; - - // Receive message on socket_ - TORCH_CHECK( - recvmsg(socket_, &msg, 0) > 0, - "Failed to receive fd: ", - c10::utils::str_error(errno)); - - if (msg.msg_controllen == 0) { - return -1; - } - - // Extract control message and validate its content - auto cmsg = CMSG_FIRSTHDR(&msg); - TORCH_CHECK(cmsg != nullptr); - TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int))); - TORCH_CHECK(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS); - return *reinterpret_cast(CMSG_DATA(cmsg)); -} - -std::vector IpcChannel::all_gather_fds( - int rank, - const std::vector& pids, - int fd) { - int world_size = static_cast(pids.size()); - std::vector fds(pids.size()); - fds[rank] = fd; - - int dst_rank = (rank + 1) % world_size; - for (int step = 
1; step < world_size; ++step) { - int src_rank = (rank + world_size - step) % world_size; - send_fd(pids[dst_rank], fd); - fd = recv_fd(); - fds[src_rank] = fd; - } - return fds; -} - -int IpcChannel::broadcast_fds( - int rank, - int src_rank, - const std::vector& pids, - int fd) { - int world_size = static_cast(pids.size()); - - if (rank == src_rank) { - for (int dst_rank = 0; dst_rank < world_size; ++dst_rank) { - if (dst_rank == rank) { - continue; - } - send_fd(pids[dst_rank], fd); - } - return fd; - } - return recv_fd(); -} - -std::string IpcChannel::get_socket_name(int pid) { - const char* tmp_dir = "/tmp"; - for (const char* env_var : {"TMPDIR", "TMP", "TEMP", "TEMPDIR"}) { - if (const char* path = getenv(env_var)) { - tmp_dir = path; - break; - } - } - std::ostringstream oss; - oss << tmp_dir << "/symm_mem-" << pid; - std::string socket_name = oss.str(); - - // sockaddr_un::sun_path is a fixed-size buffer (108 bytes on Linux). - // Silent truncation here would produce a corrupted path and a confusing - // bind()/sendmsg() failure later, so reject up front. - constexpr size_t kMaxSunPath = sizeof(sockaddr_un{}.sun_path); - TORCH_CHECK( - socket_name.size() < kMaxSunPath, - "IpcChannel: socket path '", - socket_name, - "' (", - socket_name.size(), - " bytes) is too long for sockaddr_un::sun_path (max ", - kMaxSunPath - 1, - " bytes). 
Please set TMPDIR/TMP/TEMP/TEMPDIR to a shorter directory."); - return socket_name; -} - } // namespace c10d::symmetric_memory diff --git a/src/xccl/XPUSymmetricMemoryUtils.hpp b/src/xccl/XPUSymmetricMemoryUtils.hpp index ffd75ce79a..b2a7661dcb 100644 --- a/src/xccl/XPUSymmetricMemoryUtils.hpp +++ b/src/xccl/XPUSymmetricMemoryUtils.hpp @@ -4,11 +4,6 @@ #include #include -#include -#include -#include -#include - #include namespace c10d { @@ -16,32 +11,6 @@ namespace symmetric_memory { std::string getSymmMemBackendXPU(); -class IpcChannel { - public: - IpcChannel(); - ~IpcChannel(); - - void send_fd(int dst_pid, int fd); - int recv_fd(); - - std::vector all_gather_fds( - int rank, - const std::vector& pids, - int fd); - - int broadcast_fds( - int rank, - int src_rank, - const std::vector& pids, - int fd); - - private: - static std::string get_socket_name(int pid); - - std::string socket_name_; - int socket_; -}; - class StoreExchange { public: StoreExchange(const std::string& store_prefix) From af0dc7402e38e3ddf92914565164c721db6136c3 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Mon, 25 May 2026 13:23:10 +0800 Subject: [PATCH 19/25] lint --- src/xccl/XPUSymmetricMemory.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index 3231cabef6..605abbff37 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -11,9 +11,9 @@ #include -#include #include #include +#include namespace c10d { namespace symmetric_memory { @@ -279,8 +279,7 @@ void* XPUSymmetricMemoryAllocator::alloc( size_t signal_pad_offset = at::round_up(size, 16UL); size_t block_size = signal_pad_offset + get_signal_pad_size(); - c10::DeviceGuard device_guard( - c10::Device(c10::DeviceType::XPU, device_idx)); + c10::DeviceGuard device_guard(c10::Device(c10::DeviceType::XPU, device_idx)); sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); void* ptr = 
sycl::malloc_device(block_size, current_queue); current_queue.memset(ptr, 0, block_size); From 9fb690fae45f2ed34984953b7da27de40926246d Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Tue, 26 May 2026 06:15:45 +0000 Subject: [PATCH 20/25] Implemented prctl(PR_SET_PTRACER, ppid, 0, 0, 0) during the initialization --- src/xccl/XPUSymmetricMemory.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index 605abbff37..f7aa4f36fc 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -11,9 +11,11 @@ #include +#include #include #include #include +#include namespace c10d { namespace symmetric_memory { @@ -386,6 +388,14 @@ c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( auto store = group->getStore(); sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); + // SYCL/L0 IPC import uses `pidfd_getfd` between peer processes. + // Using prctl(PR_SET_PTRACER, ppid) ensures that only the parent process can + // trace the current process. + static std::once_flag prctl_once; + std::call_once(prctl_once, []() { + (void)::prctl(PR_SET_PTRACER, ::getppid(), 0, 0, 0); + }); + auto local_req = RendezvousRequest{ .device_idx = block->device_idx, .pid = getpid(), From fd39f3d969bc5eb2eda9e46311632edf64a71f55 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Tue, 26 May 2026 06:21:02 +0000 Subject: [PATCH 21/25] lint --- src/xccl/XPUSymmetricMemory.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index f7aa4f36fc..52fdd1d13c 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -391,8 +391,8 @@ c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( // SYCL/L0 IPC import uses `pidfd_getfd` between peer processes. // Using prctl(PR_SET_PTRACER, ppid) ensures that only the parent process can // trace the current process. 
- static std::once_flag prctl_once; - std::call_once(prctl_once, []() { + static c10::once_flag prctl_once; + c10::call_once(prctl_once, []() { (void)::prctl(PR_SET_PTRACER, ::getppid(), 0, 0, 0); }); From 947bd9be79fb1967359c929d026e70f16e07650f Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Fri, 29 May 2026 07:59:38 +0000 Subject: [PATCH 22/25] ppid to any --- src/xccl/XPUSymmetricMemory.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index 52fdd1d13c..5a993786ae 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -388,12 +388,18 @@ c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( auto store = group->getStore(); sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); - // SYCL/L0 IPC import uses `pidfd_getfd` between peer processes. - // Using prctl(PR_SET_PTRACER, ppid) ensures that only the parent process can - // trace the current process. + // SYCL/L0 IPC import uses `pidfd_getfd` between peer processes, which + // requires PTRACE_MODE_ATTACH_REALCREDS on the target pid. With Yama + // (/proc/sys/kernel/yama/ptrace_scope >= 1, the default on most distros), + // only ancestor processes or those explicitly whitelisted via + // PR_SET_PTRACER can attach. Sibling ranks spawned by a launcher are + // neither ancestors nor descendants of each other, so + // PR_SET_PTRACER(getppid()) is NOT sufficient -- it only authorizes the + // launcher. We need PR_SET_PTRACER_ANY so any peer rank can import our + // IPC handles. This is scoped to the lifetime of this process. 
static c10::once_flag prctl_once; c10::call_once(prctl_once, []() { - (void)::prctl(PR_SET_PTRACER, ::getppid(), 0, 0, 0); + (void)::prctl(PR_SET_PTRACER, PR_SET_PTRACER_ANY, 0, 0, 0); }); auto local_req = RendezvousRequest{ From 97ae6ca2c6d6869cdc33da05145d6f9e4ca60c9b Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Tue, 2 Jun 2026 15:49:04 +0800 Subject: [PATCH 23/25] rm fallback barrier and first test symm --- src/xccl/XPUSymmetricMemory.cpp | 62 +----- test/xpu/distributed/test_c10d_xccl.py | 294 +++++++++++++------------ 2 files changed, 159 insertions(+), 197 deletions(-) diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index 5a993786ae..c2fdd8ffe5 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -30,15 +30,6 @@ thread_local StoreExchange storeExchange = []() { return StoreExchange("XPUSymmetricMemory_" + std::to_string(nonce)); }(); -bool use_signal_barrier_enabled() { - static const bool cached_value = []() { - const char* env = std::getenv("USE_SIGNAL_BARRIER"); - // Default to enabled; only opt out when explicitly set to "0". 
- return env == nullptr || std::string(env) != "0"; - }(); - return cached_value; -} - } // namespace AllocationRef::AllocationRef( @@ -161,51 +152,16 @@ void XPUSymmetricMemory::barrier(int channel, size_t timeout_ms) { c10::Device local_device(c10::DeviceType::XPU, local_device_idx_); c10::DeviceGuard guard(local_device); - if (use_signal_barrier_enabled()) { - auto stream = at::xpu::getCurrentXPUStream(); - barrier_impl_xpu( - reinterpret_cast(signal_pads_dev_), - channel, - rank_, - world_size_, - timeout_ms, - stream); - return; - } - - auto group = c10d::resolve_process_group(group_name_); - TORCH_CHECK( - group != nullptr, - "Process group '", - group_name_, - "' not found, please init process group first before calling " - "SymmetricMemory"); - - auto backend = group->getBackend(c10::DeviceType::XPU); - - static thread_local at::Tensor barrier_tensor; - if (!barrier_tensor.defined() || barrier_tensor.device() != local_device) { - barrier_tensor = at::zeros( - {1}, at::TensorOptions().device(local_device).dtype(at::kFloat)); - } else { - barrier_tensor.zero_(); - } - c10d::AllreduceOptions arOpts; - arOpts.asyncOp = false; - std::vector tensors = {barrier_tensor}; - auto work = backend->allreduce(tensors, arOpts); - - if (work) { - bool success = work->wait(std::chrono::milliseconds(timeout_ms)); - TORCH_CHECK( - success, - "Barrier timeout after ", - timeout_ms, - " ms for group '", - group_name_, - "'"); - } + auto stream = at::xpu::getCurrentXPUStream(); + barrier_impl_xpu( + reinterpret_cast(signal_pads_dev_), + channel, + rank_, + world_size_, + timeout_ms, + stream); + return; } void XPUSymmetricMemory::put_signal( diff --git a/test/xpu/distributed/test_c10d_xccl.py b/test/xpu/distributed/test_c10d_xccl.py index afd3cd7acb..0b6854085d 100644 --- a/test/xpu/distributed/test_c10d_xccl.py +++ b/test/xpu/distributed/test_c10d_xccl.py @@ -113,6 +113,156 @@ def simple_reduce_tests(rank, world_size): TEST_MULTIXPU = torch.xpu.device_count() > 1 +# 
------------------------------------------------------------------ +# XPU SymmetricMemory tests (SYCL IPC backend) +# ------------------------------------------------------------------ + +# XPU does not support multicast. +os.environ["TORCH_SYMM_MEM_DISABLE_MULTICAST"] = "1" + +device_type = "xpu" + +try: + from torch.testing._internal.inductor_utils import ( + HAS_XPU_AND_TRITON as _HAS_XPU_AND_TRITON, + ) +except ImportError: + _HAS_XPU_AND_TRITON = False + + +@instantiate_parametrized_tests +class SymmetricMemoryTest(MultiProcContinuousTest): + """XPU SymmetricMemory tests (SYCL IPC backend).""" + + @property + def device(self) -> torch.device: + return torch.device("xpu", self.rank) + + def _init_process(self): + torch.xpu.set_device(self.device) + torch.manual_seed(42 + self.rank) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_rendezvous_basic(self) -> None: + """Smoke-test the SYCL IPC rendezvous path: allocate → rendezvous → + write → barrier → read peer buffer.""" + self._init_process() + + numel = 1024 + t = symm_mem.empty(numel, dtype=torch.float32, device=self.device) + hdl = symm_mem.rendezvous(t, group=dist.group.WORLD) + + self.assertEqual(hdl.rank, self.rank) + self.assertEqual(hdl.world_size, self.world_size) + self.assertEqual(len(hdl.buffer_ptrs), self.world_size) + + t.fill_(float(self.rank)) + hdl.barrier() + + for r in range(self.world_size): + buf = hdl.get_buffer(r, (numel,), torch.float32) + self.assertTrue( + buf.eq(float(r)).all().item(), + f"peer {r} buffer != {r} (seen from rank {self.rank})", + ) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_get_signal_pad(self) -> None: + """Verify that signal-pad views (dtype, numel, data_ptr) match the + handle's metadata, and that buffer writes do not corrupt the pad.""" + self._init_process() + + t = symm_mem.empty(1, device="xpu") + hdl = symm_mem.rendezvous(t, group=dist.group.WORLD) + peer = (self.rank + 1) % self.world_size + + # Local pad pointer must match what the handle 
advertises. + local_pad = hdl.get_signal_pad(self.rank) + self.assertEqual(local_pad.data_ptr(), hdl.signal_pad_ptrs[hdl.rank]) + + # Default: uint32, signal_pad_size // 4 elements. + pad = hdl.get_signal_pad(peer) + self.assertEqual(pad.dtype, torch.uint32) + self.assertEqual(pad.numel(), hdl.signal_pad_size // 4) + + # Sizes only. + pad = hdl.get_signal_pad(peer, (8, 8)) + self.assertEqual(pad.dtype, torch.uint32) + self.assertEqual(pad.numel(), 64) + + # dtype only. + pad = hdl.get_signal_pad(peer, dtype=torch.uint64) + self.assertEqual(pad.dtype, torch.uint64) + self.assertEqual(pad.numel(), hdl.signal_pad_size // 8) + + # Sizes + dtype. + pad = hdl.get_signal_pad(peer, (8, 8), dtype=torch.uint64) + self.assertEqual(pad.dtype, torch.uint64) + self.assertEqual(pad.numel(), 64) + + # Writes to buffer must not corrupt the signal pad. + t2 = symm_mem.empty(1, device="xpu") + hdl2 = symm_mem.rendezvous(t2, group=dist.group.WORLD) + local_pad2 = hdl2.get_signal_pad(self.rank) + local_pad2.fill_(42) + t2.fill_(0) + self.assertTrue(local_pad2.eq(42).all().item()) + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_subgroup(self) -> None: + """Two disjoint subgroups rendezvous on the same tensor; each can + observe its peers correctly via the SYCL IPC mapping.""" + self._init_process() + + ranks = list(range(self.world_size)) + subgroup_0 = dist.new_group(ranks[: len(ranks) // 2]) + subgroup_1 = dist.new_group(ranks[len(ranks) // 2 :]) + + world = dist.group.WORLD + subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1 + + t = symm_mem.empty(64, device="xpu") + sm_world = symm_mem.rendezvous(t, group=world) + sm_sub = symm_mem.rendezvous(t, group=subgroup) + + self.assertEqual(sm_world.world_size, world.size()) + self.assertEqual(sm_world.rank, world.rank()) + self.assertEqual(sm_sub.world_size, world.size() // 2) + self.assertEqual(sm_sub.rank, world.rank() % subgroup.size()) + + t.fill_(world.rank()) + sm_world.barrier() + + peer = (world.rank() 
+ 1) % world.size() + buf = sm_world.get_buffer(peer, (64,), torch.float32) + self.assertTrue(buf.eq(peer).all().item()) + + peer_sub = (subgroup.rank() + 1) % subgroup.size() + buf = sm_sub.get_buffer(peer_sub, (64,), torch.float32) + if world.rank() < world.size() // 2: + self.assertTrue(buf.eq(peer_sub).all().item()) + else: + self.assertTrue(buf.eq(peer_sub + world.size() // 2).all().item()) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_put_wait_signal(self) -> None: + """Verify put_signal / wait_signal over the SYCL IPC peer mapping.""" + self._init_process() + + t = symm_mem.empty(1, device="xpu") + hdl = symm_mem.rendezvous(t, group=dist.group.WORLD) + + # Ring: each rank sends a signal to its right neighbor and waits for + # a signal from its left neighbor. + dst = (self.rank + 1) % self.world_size + src = (self.rank - 1) % self.world_size + hdl.put_signal(dst_rank=dst, channel=0, timeout_ms=10_000) + hdl.wait_signal(src_rank=src, channel=0, timeout_ms=10_000) + class RendezvousEnvTest(TestCase): @retry_on_connect_failures @@ -1405,150 +1555,6 @@ def test_coalescing_manager_collective(self, timing_enabled): self.assertTrue("duration_ms" not in t["entries"][0]) -# ------------------------------------------------------------------ -# XPU SymmetricMemory tests (SYCL IPC backend) -# ------------------------------------------------------------------ - -# XPU does not support multicast. 
-os.environ["TORCH_SYMM_MEM_DISABLE_MULTICAST"] = "1" - -device_type = "xpu" - - -@instantiate_parametrized_tests -class SymmetricMemoryTest(MultiProcContinuousTest): - """XPU SymmetricMemory tests (SYCL IPC backend).""" - - @property - def device(self) -> torch.device: - return torch.device("xpu", self.rank) - - def _init_process(self): - torch.xpu.set_device(self.device) - torch.manual_seed(42 + self.rank) - - @requires_xccl() - @skip_if_lt_x_gpu(2) - def test_rendezvous_basic(self) -> None: - """Smoke-test the SYCL IPC rendezvous path: allocate → rendezvous → - write → barrier → read peer buffer.""" - self._init_process() - - numel = 1024 - t = symm_mem.empty(numel, dtype=torch.float32, device=self.device) - hdl = symm_mem.rendezvous(t, group=dist.group.WORLD) - - self.assertEqual(hdl.rank, self.rank) - self.assertEqual(hdl.world_size, self.world_size) - self.assertEqual(len(hdl.buffer_ptrs), self.world_size) - - t.fill_(float(self.rank)) - hdl.barrier() - - for r in range(self.world_size): - buf = hdl.get_buffer(r, (numel,), torch.float32) - self.assertTrue( - buf.eq(float(r)).all().item(), - f"peer {r} buffer != {r} (seen from rank {self.rank})", - ) - - @requires_xccl() - @skip_if_lt_x_gpu(2) - def test_get_signal_pad(self) -> None: - """Verify that signal-pad views (dtype, numel, data_ptr) match the - handle's metadata, and that buffer writes do not corrupt the pad.""" - self._init_process() - - t = symm_mem.empty(1, device="xpu") - hdl = symm_mem.rendezvous(t, group=dist.group.WORLD) - peer = (self.rank + 1) % self.world_size - - # Local pad pointer must match what the handle advertises. - local_pad = hdl.get_signal_pad(self.rank) - self.assertEqual(local_pad.data_ptr(), hdl.signal_pad_ptrs[hdl.rank]) - - # Default: uint32, signal_pad_size // 4 elements. - pad = hdl.get_signal_pad(peer) - self.assertEqual(pad.dtype, torch.uint32) - self.assertEqual(pad.numel(), hdl.signal_pad_size // 4) - - # Sizes only. 
- pad = hdl.get_signal_pad(peer, (8, 8)) - self.assertEqual(pad.dtype, torch.uint32) - self.assertEqual(pad.numel(), 64) - - # dtype only. - pad = hdl.get_signal_pad(peer, dtype=torch.uint64) - self.assertEqual(pad.dtype, torch.uint64) - self.assertEqual(pad.numel(), hdl.signal_pad_size // 8) - - # Sizes + dtype. - pad = hdl.get_signal_pad(peer, (8, 8), dtype=torch.uint64) - self.assertEqual(pad.dtype, torch.uint64) - self.assertEqual(pad.numel(), 64) - - # Writes to buffer must not corrupt the signal pad. - t2 = symm_mem.empty(1, device="xpu") - hdl2 = symm_mem.rendezvous(t2, group=dist.group.WORLD) - local_pad2 = hdl2.get_signal_pad(self.rank) - local_pad2.fill_(42) - t2.fill_(0) - self.assertTrue(local_pad2.eq(42).all().item()) - - @requires_xccl() - @skip_if_lt_x_gpu(4) - def test_subgroup(self) -> None: - """Two disjoint subgroups rendezvous on the same tensor; each can - observe its peers correctly via the SYCL IPC mapping.""" - self._init_process() - - ranks = list(range(self.world_size)) - subgroup_0 = dist.new_group(ranks[: len(ranks) // 2]) - subgroup_1 = dist.new_group(ranks[len(ranks) // 2 :]) - - world = dist.group.WORLD - subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1 - - t = symm_mem.empty(64, device="xpu") - sm_world = symm_mem.rendezvous(t, group=world) - sm_sub = symm_mem.rendezvous(t, group=subgroup) - - self.assertEqual(sm_world.world_size, world.size()) - self.assertEqual(sm_world.rank, world.rank()) - self.assertEqual(sm_sub.world_size, world.size() // 2) - self.assertEqual(sm_sub.rank, world.rank() % subgroup.size()) - - t.fill_(world.rank()) - sm_world.barrier() - - peer = (world.rank() + 1) % world.size() - buf = sm_world.get_buffer(peer, (64,), torch.float32) - self.assertTrue(buf.eq(peer).all().item()) - - peer_sub = (subgroup.rank() + 1) % subgroup.size() - buf = sm_sub.get_buffer(peer_sub, (64,), torch.float32) - if world.rank() < world.size() // 2: - self.assertTrue(buf.eq(peer_sub).all().item()) - else: - 
self.assertTrue(buf.eq(peer_sub + world.size() // 2).all().item()) - - @requires_xccl() - @skip_if_lt_x_gpu(2) - def test_put_wait_signal(self) -> None: - """Verify put_signal / wait_signal over the SYCL IPC peer mapping.""" - self._init_process() - - t = symm_mem.empty(1, device="xpu") - hdl = symm_mem.rendezvous(t, group=dist.group.WORLD) - - # Ring: each rank sends a signal to its right neighbor and waits for - # a signal from its left neighbor. - dst = (self.rank + 1) % self.world_size - src = (self.rank - 1) % self.world_size - hdl.put_signal(dst_rank=dst, channel=0, timeout_ms=10_000) - hdl.wait_signal(src_rank=src, channel=0, timeout_ms=10_000) - - instantiate_parametrized_tests(XCCLTraceTest) instantiate_parametrized_tests(ProcessGroupXCCLTest) From 38cbd2518e00c12621dfa3e40557b8e402e3aac6 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Wed, 3 Jun 2026 14:07:49 +0800 Subject: [PATCH 24/25] remove unused code --- src/xccl/Signal.cpp | 8 ++++---- src/xccl/Signal.hpp | 6 ++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/xccl/Signal.cpp b/src/xccl/Signal.cpp index 61086ad0a2..fa5b02f944 100644 --- a/src/xccl/Signal.cpp +++ b/src/xccl/Signal.cpp @@ -14,13 +14,13 @@ struct barrierKernel { if (target_rank == rank) { return; } - auto put_success = try_put_signal_device( + auto put_success = try_put_signal_device( signal_pads[target_rank] + world_size * channel + rank, timeout_ms); if (!put_success) { SYCL_KERNEL_ASSERT(false); } - auto wait_success = try_wait_signal_device( + auto wait_success = try_wait_signal_device( signal_pads[rank] + world_size * channel + target_rank, timeout_ms); if (!wait_success) { SYCL_KERNEL_ASSERT(false); @@ -77,7 +77,7 @@ struct putSignalKernel { auto thread_id = item.get_local_id(0); if (thread_id == 0) { - auto put_success = try_put_signal_device( + auto put_success = try_put_signal_device( signal_pads[dst_rank] + world_size * channel + rank, 10000000); if (!put_success) { SYCL_KERNEL_ASSERT(false); @@ -139,7 
+139,7 @@ struct waitSignalKernel { auto thread_id = item.get_local_id(0); if (thread_id == 0) { - auto wait_success = try_wait_signal_device( + auto wait_success = try_wait_signal_device( signal_pads[rank] + world_size * channel + src_rank, 10000000); if (!wait_success) { SYCL_KERNEL_ASSERT(false); diff --git a/src/xccl/Signal.hpp b/src/xccl/Signal.hpp index b9089e31cf..da4023cab3 100644 --- a/src/xccl/Signal.hpp +++ b/src/xccl/Signal.hpp @@ -33,8 +33,7 @@ inline uint32_t load_acquire(uint32_t* addr) { // Put signal: wait until addr == 0, then set to 1 (release semantics) // ============================================================================= -template -bool try_put_signal_device(uint32_t* addr, size_t timeout_ms) { +inline bool try_put_signal_device(uint32_t* addr, size_t timeout_ms) { // Wait until the slot is free (value == 0) while (load_acquire(addr) != 0) { // Spin wait (no timeout check as IGC issue) @@ -48,8 +47,7 @@ bool try_put_signal_device(uint32_t* addr, size_t timeout_ms) { // ============================================================================= // Wait signal: wait until addr == 1, then set to 0 (acquire semantics) // ============================================================================= -template -bool try_wait_signal_device(uint32_t* addr, size_t timeout_ms) { +inline bool try_wait_signal_device(uint32_t* addr, size_t timeout_ms) { // Wait until signal is set (value == 1) while (load_acquire(addr) != 1) { // Spin wait (no timeout check as IGC issue) From 57032edfc028f3809953fae61a43d5518fedf275 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Wed, 3 Jun 2026 16:17:48 +0800 Subject: [PATCH 25/25] Add more comments --- src/xccl/Signal.hpp | 6 ++++++ src/xccl/XPUSymmetricMemory.cpp | 6 ++++++ src/xccl/XPUSymmetricMemoryTypes.hpp | 4 ++++ 3 files changed, 16 insertions(+) diff --git a/src/xccl/Signal.hpp b/src/xccl/Signal.hpp index da4023cab3..0e820fc38d 100644 --- a/src/xccl/Signal.hpp +++ b/src/xccl/Signal.hpp @@ -13,6 +13,12 
@@ using at::native::memory::get_alignment; // Signal primitives using store/load + atomic_fence // (sycl::atomic_ref is not supported, use explicit fence instead) // ============================================================================= +// +// Note on memory scope: +// We intentionally use memory_scope::system because signal pads are exchanged +// across ranks/devices (including peer/device-visible IPC mappings). These +// flags are polled and updated by kernels running on different devices, so a +// device/work-group scope is too narrow for this protocol. // Store value with release fence (for put_signal) // Order: store first, then release fence to flush the store diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp index c2fdd8ffe5..87c29a7522 100644 --- a/src/xccl/XPUSymmetricMemory.cpp +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -239,6 +239,12 @@ void* XPUSymmetricMemoryAllocator::alloc( c10::DeviceGuard device_guard(c10::Device(c10::DeviceType::XPU, device_idx)); sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); + // Allocate directly from SYCL runtime instead of XPUCachingAllocator: + // 1) keep behavior aligned with CUDA symmetric-memory implementation; + // 2) avoid allocator-level expandable-memory remapping side effects on + // exchanged IPC handles/addresses; + // 3) preserve flexibility for future handle-based features (for example, + // reconstructing multicast objects from physical/shared handles). void* ptr = sycl::malloc_device(block_size, current_queue); current_queue.memset(ptr, 0, block_size); auto alloc_ref = c10::make_intrusive( diff --git a/src/xccl/XPUSymmetricMemoryTypes.hpp b/src/xccl/XPUSymmetricMemoryTypes.hpp index 4cab3b81f7..ff80c3b8bb 100644 --- a/src/xccl/XPUSymmetricMemoryTypes.hpp +++ b/src/xccl/XPUSymmetricMemoryTypes.hpp @@ -2,6 +2,10 @@ namespace c10d::symmetric_memory { +// Default signal-pad size for each rank's control area. 
+// 2048 keeps parity with the CUDA-side default and has worked as a practical +// baseline for channelized signaling. This is a default value; higher-level +// symmetric-memory configuration can override the effective pad size. constexpr size_t signal_pad_size = 2048; using HandleType = void*;