Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 98 additions & 82 deletions py/torch_tensorrt/distributed/_nccl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
symlink workarounds.
"""

import ctypes
import logging
import os
import subprocess
from typing import Optional

import torch.distributed as dist
Expand Down Expand Up @@ -101,88 +101,122 @@ def ensure_nccl_symlink(nccl_lib_dir: str) -> bool:
return False


def check_nccl_library_path() -> bool:
    """Check if LD_LIBRARY_PATH includes PyTorch's NCCL directory.

    The directory reported by get_nccl_library_path() is compared against
    each ``:``-separated entry of LD_LIBRARY_PATH (after path normalization)
    instead of by substring search, so ``/opt/nccl/lib`` is not reported as
    present when only ``/opt/nccl/lib64`` is listed.

    Returns:
        True if configuration is correct (this includes the case where
        get_nccl_library_path() returns None, i.e. system NCCL), False if
        LD_LIBRARY_PATH needs updating.
    """
    nccl_lib_dir = get_nccl_library_path()

    if nccl_lib_dir is None:
        # System NCCL - no action needed
        return True

    wanted = os.path.normpath(nccl_lib_dir)
    entries = os.environ.get("LD_LIBRARY_PATH", "").split(":")
    # Empty entries are skipped: normpath("") is "." and must never count as
    # a match for the NCCL directory.
    return any(os.path.normpath(entry) == wanted for entry in entries if entry)


def _sys_libdir_on_ldso_path() -> str:
    """Pick a system library directory that ld.so searches by default.

    Returns the first existing directory from a portability-ordered list:
    Debian/Ubuntu x86_64 multiarch → ARM64 multiarch → RHEL/CentOS lib64 →
    bare /usr/lib (always on ld.so's search path as a final fallback).
    """
    for d in (
        "/usr/lib/x86_64-linux-gnu",  # Debian / Ubuntu x86_64
        "/usr/lib/aarch64-linux-gnu",  # Debian / Ubuntu ARM64 (Jetson)
        "/usr/lib64",  # RHEL / CentOS / Fedora x86_64
    ):
        if os.path.isdir(d):
            return d
    return "/usr/lib"


def setup_nccl_for_torch_tensorrt() -> None:
    """
    Point a `libnccl.so` symlink on ld.so's default search path at PyTorch's
    libnccl.so.2 so TRT and PyTorch share a single NCCL library in the process.

    What this function does:
      1. Locate the nvidia.nccl pip package's libnccl.so.2 via
         get_nccl_library_path(). If pip's nccl isn't installed (NGC /
         system-NCCL environments) returns immediately — no action needed.
      2. Pick a system library directory that ld.so already searches by
         default via _sys_libdir_on_ldso_path() (Debian/Ubuntu multiarch,
         RHEL lib64, or /usr/lib fallback).
      3. If <sys_libdir>/libnccl.so already points at that libnccl.so.2,
         return.
      4. Otherwise, atomically install a fresh symlink:
             <sys_libdir>/libnccl.so → <pip>/libnccl.so.2
         via "symlink to a unique per-pid temp name, then os.replace onto
         the target." This is multi-process safe: when several ranks of a
         distributed test call this function concurrently, none of them
         crash on FileExistsError, and the final on-disk state is the same
         regardless of execution order.
      5. Run `ldconfig` to refresh /etc/ld.so.cache. This step is best
         effort: if the `ldconfig` binary cannot be launched the failure is
         logged at debug level and the freshly installed symlink is kept.
      6. Guarded by a module-global flag so subsequent calls in the same
         process are a no-op (the flag is set before any work is attempted,
         so a call that raised is not retried either).

    Requires write access to the chosen sys_libdir (root inside Docker is
    the common case).

    Raises:
        RuntimeError: if the symlink cannot be written (OSError) and the
            target does not already point at libnccl.so.2. The message
            documents LD_PRELOAD / LD_LIBRARY_PATH workarounds for non-root
            setups.
    """
    global _nccl_setup_checked

    # Only check once per process
    if _nccl_setup_checked:
        return
    _nccl_setup_checked = True

    nccl_lib_dir = get_nccl_library_path()

    if nccl_lib_dir is None:
        # NGC container or system NCCL - no action needed
        logger.debug(
            "nvidia.nccl package not found; assuming system NCCL is shared by PyTorch and TensorRT."
        )
        return

    logger.debug(f"Found nvidia.nccl package at: {nccl_lib_dir}")

    nccl_so_2 = os.path.join(nccl_lib_dir, "libnccl.so.2")
    if not os.path.isfile(nccl_so_2):
        logger.warning(
            f"Expected {nccl_so_2} to exist but it doesn't; skipping NCCL setup."
        )
        return

    sys_libdir = _sys_libdir_on_ldso_path()
    target = os.path.join(sys_libdir, "libnccl.so")

    # Fast path: already set up by a prior process or rank.
    if os.path.islink(target) and os.readlink(target) == nccl_so_2:
        logger.debug(f"{target} already points at {nccl_so_2}; nothing to do.")
        return

    # Race-safe symlink swap. Multiple ranks may enter this function
    # concurrently (e.g. MultiProcessTestCase forks 2 children that each call
    # setup_nccl_for_torch_tensorrt simultaneously). Using `os.remove` +
    # `os.symlink` on the target itself opens a window where one rank's
    # symlink call races with another's, raising FileExistsError on the loser.
    #
    # Instead: create the new symlink under a unique per-pid temp name, then
    # rename it onto `target` with `os.replace`, which overwrites an existing
    # entry in a single step. All ranks converge on the same final symlink
    # without any of them crashing.
    tmp = f"{target}.torchtrt-{os.getpid()}"
    try:
        # A stale temp link can only be left over from an earlier process
        # that happened to have the same pid; clear it before re-creating.
        if os.path.lexists(tmp):
            os.remove(tmp)
        os.symlink(nccl_so_2, tmp)
        os.replace(tmp, target)
    except OSError as e:
        # Clean up our temp link if we left one behind.
        if os.path.lexists(tmp):
            try:
                os.remove(tmp)
            except OSError:
                pass
        # If another rank already produced the correct final symlink while
        # we were failing, accept that as success — the end state we wanted
        # is in place.
        if os.path.islink(target) and os.readlink(target) == nccl_so_2:
            return
        raise RuntimeError(
            f"setup_nccl_for_torch_tensorrt(): cannot write {target} "
            f"(needed so TRT's dlopen('libnccl.so') resolves to PyTorch's libnccl.so.2). "
            f"Workarounds without root: relaunch python with "
            f"LD_PRELOAD={nccl_so_2} ; or pre-set "
            f"LD_LIBRARY_PATH={nccl_lib_dir}:$LD_LIBRARY_PATH before python starts "
            f"(and create a libnccl.so symlink in that dir first). "
            f"Original error: {e}"
        ) from e

    # The symlink is in place. Refreshing the loader cache is kept outside the
    # try-block above so that a missing or unlaunchable `ldconfig` is not
    # mistaken for a failure to write the symlink.
    try:
        subprocess.run(["ldconfig"], check=False)
    except OSError as e:
        logger.debug(f"Could not run ldconfig after linking {target}: {e}")

    logger.info(
        f"NCCL: linked {target} -> {nccl_so_2} so TRT and PyTorch share one libnccl."
    )


def initialize_nccl_comm(device: Optional[int] = None) -> None:
Expand Down Expand Up @@ -253,29 +287,11 @@ def initialize_nccl_comm(device: Optional[int] = None) -> None:


def check_nccl_engine_requirements() -> None:
"""Warn if an requires_native_multidevice TRT engine's NCCL prerequisites are not satisfied.

Checks two conditions and logs a warning for each:
1. LD_LIBRARY_PATH does not include PyTorch's NCCL lib dir (too late to fix,
must be set before process launch — use torchtrtrun).
2. torch.distributed is not initialized or world_size == 1.
"""Warn if a requires_native_multidevice TRT engine's NCCL prerequisites are not satisfied.

Call this from both TorchTensorRTModule and PythonTorchTensorRTModule after
Called from TorchTensorRTModule and PythonTorchTensorRTModule after
confirming the engine has NCCL collective ops.
"""
if get_nccl_library_path() is not None and not check_nccl_library_path():
logger.warning(
"This TRT engine contains NCCL collective ops but "
"LD_LIBRARY_PATH does not include PyTorch's NCCL library directory. "
"TRT may load a different NCCL instance than PyTorch, causing "
"communicator sharing to fail. Use torchtrtrun to launch distributed "
"scripts, or set LD_PRELOAD and LD_LIBRARY_PATH before process start:\n"
" NCCL_LIB=$(python -c 'from torch_tensorrt.distributed._nccl_utils "
"import get_nccl_library_path; print(get_nccl_library_path())')\n"
" LD_PRELOAD=$NCCL_LIB/libnccl.so.2 "
"LD_LIBRARY_PATH=$NCCL_LIB:$LD_LIBRARY_PATH python ..."
)

if not (
dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
):
Expand Down
36 changes: 0 additions & 36 deletions tests/py/dynamo/distributed/test_native_nccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,22 +713,6 @@ def test_get_nccl_library_path_returns_none_or_string(self) -> None:
f"libnccl.so.2 not found in {result}",
)

def test_check_nccl_library_path_system_nccl(self) -> None:
    """check_nccl_library_path returns True when nvidia.nccl not installed."""
    from torch_tensorrt.distributed._nccl_utils import (
        check_nccl_library_path,
        get_nccl_library_path,
    )

    if get_nccl_library_path() is None:
        # No pip-installed nvidia.nccl: the check must report success.
        self.assertTrue(check_nccl_library_path())
        return

    # With nvidia.nccl installed the answer depends on LD_LIBRARY_PATH, so
    # only the return type is pinned here.
    outcome = check_nccl_library_path()
    self.assertIsInstance(outcome, bool)

def test_setup_nccl_for_torch_tensorrt_idempotent(self) -> None:
"""Calling setup_nccl_for_torch_tensorrt() multiple times is safe."""
from torch_tensorrt.distributed import _nccl_utils
Expand All @@ -749,26 +733,6 @@ def test_ensure_nccl_symlink_nonexistent_dir(self) -> None:
# libnccl.so.2 doesn't exist there → returns False
self.assertFalse(result)

def test_check_nccl_library_path_detects_missing_ld_path(self) -> None:
    """check_nccl_library_path returns False when LD_LIBRARY_PATH is absent.

    The NCCL directory is stripped from LD_LIBRARY_PATH for the duration of
    the check. Afterwards the variable is restored exactly: if it was unset
    before the test it is removed again rather than left behind as an empty
    string.
    """
    from torch_tensorrt.distributed._nccl_utils import (
        check_nccl_library_path,
        get_nccl_library_path,
    )

    nccl_lib_dir = get_nccl_library_path()
    if nccl_lib_dir is None:
        self.skipTest("nvidia.nccl not installed; system NCCL path is always OK")

    # None means "was not set at all", which is distinct from an empty value.
    original = os.environ.get("LD_LIBRARY_PATH")
    # Remove nccl_lib_dir from LD_LIBRARY_PATH
    paths = [p for p in (original or "").split(":") if p and p != nccl_lib_dir]
    os.environ["LD_LIBRARY_PATH"] = ":".join(paths)
    try:
        result = check_nccl_library_path()
        self.assertFalse(result)
    finally:
        if original is None:
            os.environ.pop("LD_LIBRARY_PATH", None)
        else:
            os.environ["LD_LIBRARY_PATH"] = original


# ============================================================================
# Section 4 — fuse_distributed_ops graph pass (no GPU, no dist) [was Section 3]
Expand Down
Loading