diff --git a/python/openimpala/__init__.py b/python/openimpala/__init__.py
index f4ca2ed..fa1b538 100644
--- a/python/openimpala/__init__.py
+++ b/python/openimpala/__init__.py
@@ -32,6 +32,62 @@
 except PackageNotFoundError:
     __version__ = "unknown"
 
+
+def _preload_cuda_libs():
+    """Pre-load CUDA shared libs shipped by PyPI nvidia-*-cu12 packages.
+
+    The openimpala-cuda wheel uses ``auditwheel --exclude libcublas.so.12``
+    (etc.) to keep the wheel under PyPI's 320 MiB cap, then declares
+    ``nvidia-cublas-cu12`` and friends as runtime deps. Those wheels install
+    their .so's under ``site-packages/nvidia/<pkg>/lib/``, which the
+    dynamic linker does NOT search by default — so loading ``_core.so``
+    fails with ``undefined symbol: cublasSetStream_v2`` (and similar) until
+    the libs are dlopened with RTLD_GLOBAL.
+
+    Same trick as PyTorch / CuPy / JAX. For each lib we try the bare soname
+    first (so HPC nodes with system CUDA on LD_LIBRARY_PATH win the race),
+    then fall back to the PyPI bundled path. No-op for pure-Python wheels
+    where the nvidia/ dir doesn't exist.
+
+    Load order matters: cudart first, cublasLt before cublas (cublas links
+    against cublasLt at the symbol level).
+    """
+    import os
+    import sys
+    import ctypes
+
+    if sys.platform != "linux":
+        return
+
+    pkg_root = os.path.dirname(os.path.abspath(__file__))
+    site_pkgs = os.path.dirname(pkg_root)
+    nvidia_root = os.path.join(site_pkgs, "nvidia")
+
+    libs = (
+        ("libcudart.so.12", "cuda_runtime/lib/libcudart.so.12"),
+        ("libnvJitLink.so.12", "nvjitlink/lib/libnvJitLink.so.12"),
+        ("libcublasLt.so.12", "cublas/lib/libcublasLt.so.12"),
+        ("libcublas.so.12", "cublas/lib/libcublas.so.12"),
+        ("libcusparse.so.12", "cusparse/lib/libcusparse.so.12"),
+        ("libcurand.so.10", "curand/lib/libcurand.so.10"),
+    )
+
+    for soname, fallback in libs:
+        try:
+            ctypes.CDLL(soname, mode=ctypes.RTLD_GLOBAL)
+            continue
+        except OSError:
+            pass
+        path = os.path.join(nvidia_root, fallback)
+        if os.path.exists(path):
+            try:
+                ctypes.CDLL(path, mode=ctypes.RTLD_GLOBAL)
+            except OSError:
+                pass
+
+
+_preload_cuda_libs()
+
 
 # Session context manager (pure Python — always available)
 from .session import Session