diff --git a/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in b/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in index d102901cf..81283084f 100644 --- a/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in +++ b/cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in @@ -9,13 +9,12 @@ # This code was automatically generated with version 12.9.0. Do not modify it directly. {{if 'Windows' == platform.system()}} import os -import site -import struct import win32api -from pywintypes import error {{else}} cimport cuda.bindings._lib.dlfcn as dlfcn +from libc.stdint cimport uintptr_t {{endif}} +from cuda.bindings import path_finder from libc.stdint cimport intptr_t @@ -56,51 +55,10 @@ cdef int cuPythonInit() except -1 nogil: # Load library {{if 'Windows' == platform.system()}} with gil: - # First check if the DLL has been loaded by 3rd parties - try: - handle = win32api.GetModuleHandle("nvrtc64_120_0.dll") - except: - handle = None - - # Check if DLLs can be found within pip installations - if not handle: - LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000 - LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100 - site_packages = [site.getusersitepackages()] + site.getsitepackages() - for sp in site_packages: - mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin") - if os.path.isdir(mod_path): - os.add_dll_directory(mod_path) - try: - handle = win32api.LoadLibraryEx( - # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... - os.path.join(mod_path, "nvrtc64_120_0.dll"), - 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) - - # Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is - # located in the same mod_path. 
- # Update PATH environ so that the two dlls can find each other - os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path)) - except: - pass - else: - break - else: - # Else try default search - # Only reached if DLL wasn't found in any site-package path - LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000 - try: - handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) - except: - pass - - if not handle: - raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll') + handle = path_finder._load_nvidia_dynamic_library("nvrtc").handle {{else}} - handle = dlfcn.dlopen('libnvrtc.so.12', dlfcn.RTLD_NOW) - if handle == NULL: - with gil: - raise RuntimeError('Failed to dlopen libnvrtc.so.12') + with gil: + handle = path_finder._load_nvidia_dynamic_library("nvrtc").handle {{endif}} diff --git a/cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx b/cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx index bb61a3e22..36bdcb4f4 100644 --- a/cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx +++ b/cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx @@ -4,12 +4,12 @@ # # This code was automatically generated across versions from 12.0.1 to 12.9.0. Do not modify it directly. 
cdef void* load_library(int driver_ver) except* with gil:
    # Load the nvJitLink shared library via ``path_finder`` and return its raw
    # handle as a ``void*``.
    #
    # driver_ver: not consulted by this implementation; the parameter is kept
    #             so that existing call sites continue to compile unchanged.
    #
    # Raises RuntimeError if the loader hands back a zero handle. Any exception
    # raised by ``path_finder`` itself propagates unchanged (``except*``).

    # ``.handle`` arrives as a Python object; coercing it into ``uintptr_t``
    # gives a pointer-sized unsigned integer.
    # NOTE(review): assumes ``.handle`` is the library address as a
    # non-negative int -- confirm against path_finder.
    cdef uintptr_t handle = path_finder._load_nvidia_dynamic_library("nvJitLink").handle

    # Defensive: a zero handle would turn into NULL below, which callers could
    # not tell apart from a valid library handle.
    if handle == 0:
        raise RuntimeError('Failed to load nvJitLink')

    # Cython does not convert an integer to a pointer implicitly; the cast
    # must be spelled out.
    return <void*>handle
get_nvjitlink_dso_version_suffix(driver_ver): - if len(suffix) == 0: - continue - dll_name = f"nvJitLink_{suffix}0_0.dll" - - # First check if the DLL has been loaded by 3rd parties - try: - return win32api.GetModuleHandle(dll_name) - except: - pass - - # Next, check if DLLs are installed via pip - for sp in get_site_packages(): - mod_path = os.path.join(sp, "nvidia", "nvJitLink", "bin") - if os.path.isdir(mod_path): - os.add_dll_directory(mod_path) - try: - return win32api.LoadLibraryEx( - # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... - os.path.join(mod_path, dll_name), - 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) - except: - pass - # Finally, try default search - # Only reached if DLL wasn't found in any site-package path - try: - return win32api.LoadLibrary(dll_name) - except: - pass - - raise RuntimeError('Failed to load nvJitLink') - - cdef int _check_or_init_nvjitlink() except -1 nogil: global __py_nvjitlink_init if __py_nvjitlink_init: @@ -104,7 +61,7 @@ cdef int _check_or_init_nvjitlink() except -1 nogil: raise RuntimeError('something went wrong') # Load library - handle = load_library(driver_ver) + handle = path_finder._load_nvidia_dynamic_library("nvJitLink").handle # Load function global __nvJitLinkCreate diff --git a/cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx b/cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx index 53675b094..8759096a4 100644 --- a/cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx +++ b/cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx @@ -4,12 +4,12 @@ # # This code was automatically generated across versions from 11.0.3 to 12.9.0. Do not modify it directly. 
cdef void* load_library(const int driver_ver) except* with gil:
    # Load the NVVM shared library via ``path_finder`` and return its raw
    # handle as a ``void*``.
    #
    # driver_ver: not consulted by this implementation; the parameter is kept
    #             so that existing call sites continue to compile unchanged.
    #
    # Raises RuntimeError if the loader hands back a zero handle. Any exception
    # raised by ``path_finder`` itself propagates unchanged (``except*``).

    # ``.handle`` arrives as a Python object; coercing it into ``uintptr_t``
    # gives a pointer-sized unsigned integer.
    # NOTE(review): assumes ``.handle`` is the library address as a
    # non-negative int -- confirm against path_finder.
    cdef uintptr_t handle = path_finder._load_nvidia_dynamic_library("nvvm").handle

    # Defensive: a zero handle would turn into NULL below, which callers could
    # not tell apart from a valid library handle.
    if handle == 0:
        raise RuntimeError('Failed to load nvvm')

    # Cython does not convert an integer to a pointer implicitly; the cast
    # must be spelled out.
    return <void*>handle
try: - return win32api.GetModuleHandle(dll_name) - except: - pass - - # Next, check if DLLs are installed via pip or conda - for sp in get_site_packages(): - if sp == "conda": - # nvvm is not under $CONDA_PREFIX/lib, so it's not in the default search path - conda_prefix = os.environ.get("CONDA_PREFIX") - if conda_prefix is None: - continue - mod_path = os.path.join(conda_prefix, "Library", "nvvm", "bin") - else: - mod_path = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", "bin") - if os.path.isdir(mod_path): - os.add_dll_directory(mod_path) - try: - return win32api.LoadLibraryEx( - # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... - os.path.join(mod_path, dll_name), - 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) - except: - pass - - # Finally, try default search - # Only reached if DLL wasn't found in any site-package path - try: - return win32api.LoadLibrary(dll_name) - except: - pass - - raise RuntimeError('Failed to load nvvm') - - cdef int _check_or_init_nvvm() except -1 nogil: global __py_nvvm_init if __py_nvvm_init: @@ -110,7 +59,7 @@ cdef int _check_or_init_nvvm() except -1 nogil: raise RuntimeError('something went wrong') # Load library - handle = load_library(driver_ver) + handle = path_finder._load_nvidia_dynamic_library("nvvm").handle # Load function global __nvvmVersion diff --git a/cuda_bindings/cuda/bindings/_internal/utils.pxd b/cuda_bindings/cuda/bindings/_internal/utils.pxd index cac7846ff..a4b71c531 100644 --- a/cuda_bindings/cuda/bindings/_internal/utils.pxd +++ b/cuda_bindings/cuda/bindings/_internal/utils.pxd @@ -165,6 +165,3 @@ cdef int get_nested_resource_ptr(nested_resource[ResT] &in_out_ptr, object obj, cdef bint is_nested_sequence(data) cdef void* get_buffer_pointer(buf, Py_ssize_t size, readonly=*) except* - -cdef tuple get_nvjitlink_dso_version_suffix(int driver_ver) -cdef tuple get_nvvm_dso_version_suffix(int driver_ver) diff --git a/cuda_bindings/cuda/bindings/_internal/utils.pyx 
b/cuda_bindings/cuda/bindings/_internal/utils.pyx index 0a693c052..7fc77b22c 100644 --- a/cuda_bindings/cuda/bindings/_internal/utils.pyx +++ b/cuda_bindings/cuda/bindings/_internal/utils.pyx @@ -127,17 +127,3 @@ cdef int get_nested_resource_ptr(nested_resource[ResT] &in_out_ptr, object obj, class FunctionNotFoundError(RuntimeError): pass class NotSupportedError(RuntimeError): pass - - -cdef tuple get_nvjitlink_dso_version_suffix(int driver_ver): - if 12000 <= driver_ver < 13000: - return ('12', '') - raise NotSupportedError(f'CUDA driver version {driver_ver} is not supported') - - -cdef tuple get_nvvm_dso_version_suffix(int driver_ver): - if 11000 <= driver_ver < 11020: - return ('3', '') - if 11020 <= driver_ver < 13000: - return ('4', '') - raise NotSupportedError(f'CUDA driver version {driver_ver} is not supported') diff --git a/cuda_bindings/setup.py b/cuda_bindings/setup.py index 84a4c86f2..8b951b2e0 100644 --- a/cuda_bindings/setup.py +++ b/cuda_bindings/setup.py @@ -379,31 +379,7 @@ def initialize_options(self): def build_extension(self, ext): if building_wheel and sys.platform == "linux": # Strip binaries to remove debug symbols - extra_linker_flags = ["-Wl,--strip-all"] - - # Allow extensions to discover libraries at runtime - # relative their wheels installation. 
- if ext.name == "cuda.bindings._bindings.cynvrtc": - ldflag = "-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib" - elif ext.name == "cuda.bindings._internal.nvjitlink": - ldflag = "-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/nvjitlink/lib" - elif ext.name == "cuda.bindings._internal.nvvm": - # from /site-packages/cuda/bindings/_internal/ - # to /site-packages/nvidia/cuda_nvcc/nvvm/lib64/ - rel1 = "$ORIGIN/../../../nvidia/cuda_nvcc/nvvm/lib64" - # from /lib/python3.*/site-packages/cuda/bindings/_internal/ - # to /nvvm/lib64/ - rel2 = "$ORIGIN/../../../../../../nvvm/lib64" - ldflag = f"-Wl,--disable-new-dtags,-rpath,{rel1},-rpath,{rel2}" - else: - ldflag = None - - if ldflag: - extra_linker_flags.append(ldflag) - else: - extra_linker_flags = [] - - ext.extra_link_args += extra_linker_flags + ext.extra_link_args.append("-Wl,--strip-all") super().build_extension(ext)