Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions python/openimpala/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,62 @@
except PackageNotFoundError:
__version__ = "unknown"


def _preload_cuda_libs():
"""Pre-load CUDA shared libs shipped by PyPI nvidia-*-cu12 packages.

The openimpala-cuda wheel uses ``auditwheel --exclude libcublas.so.12``
(etc.) to keep the wheel under PyPI's 320 MiB cap, then declares
``nvidia-cublas-cu12`` and friends as runtime deps. Those wheels install
their .so's under ``site-packages/nvidia/<component>/lib/``, which the
dynamic linker does NOT search by default — so loading ``_core.so``
fails with ``undefined symbol: cublasSetStream_v2`` (and similar) until
the libs are dlopened with RTLD_GLOBAL.

Same trick as PyTorch / CuPy / JAX. For each lib we try the bare soname
first (so HPC nodes with system CUDA on LD_LIBRARY_PATH win the race),
then fall back to the PyPI bundled path. No-op for pure-Python wheels
where the nvidia/ dir doesn't exist.

Load order matters: cudart first, cublasLt before cublas (cublas links
against cublasLt at the symbol level).
"""
import os
import sys
import ctypes

if sys.platform != "linux":
return

pkg_root = os.path.dirname(os.path.abspath(__file__))
site_pkgs = os.path.dirname(pkg_root)
nvidia_root = os.path.join(site_pkgs, "nvidia")

libs = (
("libcudart.so.12", "cuda_runtime/lib/libcudart.so.12"),
("libnvJitLink.so.12", "nvjitlink/lib/libnvJitLink.so.12"),
("libcublasLt.so.12", "cublas/lib/libcublasLt.so.12"),
("libcublas.so.12", "cublas/lib/libcublas.so.12"),
("libcusparse.so.12", "cusparse/lib/libcusparse.so.12"),
("libcurand.so.10", "curand/lib/libcurand.so.10"),
)

for soname, fallback in libs:
try:
ctypes.CDLL(soname, mode=ctypes.RTLD_GLOBAL)
continue
except OSError:
pass
path = os.path.join(nvidia_root, fallback)
if os.path.exists(path):
try:
ctypes.CDLL(path, mode=ctypes.RTLD_GLOBAL)
except OSError:
pass


_preload_cuda_libs()

# Session context manager (pure Python — always available)
from .session import Session

Expand Down
Loading