Skip to content

Commit b5cef1b

Browse files
committed
Save result of factoring out load_dl_common.py, load_dl_linux.py, load_dl_windows.py with the help of Cursor.
1 parent a649e7d commit b5cef1b

File tree

4 files changed

+354
-171
lines changed

4 files changed

+354
-171
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2025 NVIDIA Corporation. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
from dataclasses import dataclass
5+
from typing import Callable, Optional
6+
7+
from .supported_libs import DIRECT_DEPENDENCIES
8+
9+
10+
@dataclass
11+
class LoadedDL:
12+
"""Represents a loaded dynamic library.
13+
14+
Attributes:
15+
handle: The library handle (can be converted to void* in Cython)
16+
abs_path: The absolute path to the library file
17+
was_already_loaded_from_elsewhere: Whether the library was already loaded
18+
"""
19+
20+
# ATTENTION: To convert `handle` back to `void*` in cython:
21+
# Linux: `cdef void* ptr = <void*><uintptr_t>`
22+
# Windows: `cdef void* ptr = <void*><intptr_t>`
23+
handle: int
24+
abs_path: Optional[str]
25+
was_already_loaded_from_elsewhere: bool
26+
27+
28+
def add_dll_directory(dll_abs_path: str) -> None:
29+
"""Add a DLL directory to the search path and update PATH environment variable.
30+
31+
Args:
32+
dll_abs_path: Absolute path to the DLL file
33+
34+
Raises:
35+
AssertionError: If the directory containing the DLL does not exist
36+
"""
37+
import os
38+
39+
dirpath = os.path.dirname(dll_abs_path)
40+
assert os.path.isdir(dirpath), dll_abs_path
41+
# Add the DLL directory to the search path
42+
os.add_dll_directory(dirpath)
43+
# Update PATH as a fallback for dependent DLL resolution
44+
curr_path = os.environ.get("PATH")
45+
os.environ["PATH"] = dirpath if curr_path is None else os.pathsep.join((curr_path, dirpath))
46+
47+
48+
def load_dependencies(libname: str, load_func: Callable[[str], LoadedDL]) -> None:
49+
"""Load all dependencies for a given library.
50+
51+
Args:
52+
libname: The name of the library whose dependencies should be loaded
53+
load_func: The function to use for loading libraries (e.g. load_nvidia_dynamic_library)
54+
55+
Example:
56+
>>> load_dependencies("cudart", load_nvidia_dynamic_library)
57+
# This will load all dependencies of cudart using the provided loading function
58+
"""
59+
for dep in DIRECT_DEPENDENCIES.get(libname, ()):
60+
load_func(dep)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright 2025 NVIDIA Corporation. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
import ctypes
5+
import ctypes.util
6+
import os
7+
from typing import Optional
8+
9+
from .load_dl_common import LoadedDL
10+
11+
CDLL_MODE = os.RTLD_NOW | os.RTLD_GLOBAL
12+
13+
LIBDL_PATH = ctypes.util.find_library("dl") or "libdl.so.2"
14+
LIBDL = ctypes.CDLL(LIBDL_PATH)
15+
LIBDL.dladdr.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
16+
LIBDL.dladdr.restype = ctypes.c_int
17+
18+
19+
class Dl_info(ctypes.Structure):
20+
"""Structure used by dladdr to return information about a loaded symbol."""
21+
22+
_fields_ = [
23+
("dli_fname", ctypes.c_char_p), # path to .so
24+
("dli_fbase", ctypes.c_void_p),
25+
("dli_sname", ctypes.c_char_p),
26+
("dli_saddr", ctypes.c_void_p),
27+
]
28+
29+
30+
def abs_path_for_dynamic_library(libname: str, handle: int) -> Optional[str]:
31+
"""Get the absolute path of a loaded dynamic library on Linux.
32+
33+
Args:
34+
libname: The name of the library
35+
handle: The library handle
36+
37+
Returns:
38+
The absolute path to the library file, or None if no expected symbol is found
39+
40+
Raises:
41+
OSError: If dladdr fails to get information about the symbol
42+
"""
43+
from .supported_libs import EXPECTED_LIB_SYMBOLS
44+
45+
for symbol_name in EXPECTED_LIB_SYMBOLS[libname]:
46+
symbol = getattr(handle, symbol_name, None)
47+
if symbol is not None:
48+
break
49+
else:
50+
return None
51+
52+
addr = ctypes.cast(symbol, ctypes.c_void_p)
53+
info = Dl_info()
54+
if LIBDL.dladdr(addr, ctypes.byref(info)) == 0:
55+
raise OSError(f"dladdr failed for {libname=!r}")
56+
return info.dli_fname.decode()
57+
58+
59+
def check_if_already_loaded(libname: str) -> Optional[LoadedDL]:
60+
"""Check if the library is already loaded in the process.
61+
62+
Args:
63+
libname: The name of the library to check
64+
65+
Returns:
66+
A LoadedDL object if the library is already loaded, None otherwise
67+
68+
Example:
69+
>>> loaded = check_if_already_loaded("cudart")
70+
>>> if loaded is not None:
71+
... print(f"Library already loaded from {loaded.abs_path}")
72+
"""
73+
from .supported_libs import SUPPORTED_LINUX_SONAMES
74+
75+
for soname in SUPPORTED_LINUX_SONAMES.get(libname, ()):
76+
try:
77+
handle = ctypes.CDLL(soname, mode=os.RTLD_NOLOAD)
78+
except OSError:
79+
continue
80+
else:
81+
return LoadedDL(handle._handle, abs_path_for_dynamic_library(libname, handle), True)
82+
return None
83+
84+
85+
def load_with_system_search(libname: str, soname: str) -> Optional[LoadedDL]:
86+
"""Try to load a library using system search paths.
87+
88+
Args:
89+
libname: The name of the library to load
90+
soname: The soname to search for
91+
92+
Returns:
93+
A LoadedDL object if successful, None if the library cannot be loaded
94+
95+
Raises:
96+
RuntimeError: If the library is loaded but no expected symbol is found
97+
"""
98+
try:
99+
handle = ctypes.CDLL(soname, CDLL_MODE)
100+
abs_path = abs_path_for_dynamic_library(libname, handle)
101+
if abs_path is None:
102+
raise RuntimeError(f"No expected symbol for {libname=!r}")
103+
return LoadedDL(handle._handle, abs_path, False)
104+
except OSError:
105+
return None
106+
107+
108+
def load_with_abs_path(libname: str, found_path: str) -> LoadedDL:
109+
"""Load a dynamic library from the given path.
110+
111+
Args:
112+
libname: The name of the library to load
113+
found_path: The absolute path to the library file
114+
115+
Returns:
116+
A LoadedDL object representing the loaded library
117+
118+
Raises:
119+
RuntimeError: If the library cannot be loaded
120+
"""
121+
try:
122+
handle = ctypes.CDLL(found_path, CDLL_MODE)
123+
except OSError as e:
124+
raise RuntimeError(f"Failed to dlopen {found_path}: {e}") from e
125+
return LoadedDL(handle._handle, found_path, False)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright 2025 NVIDIA Corporation. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
import ctypes
5+
import ctypes.wintypes
6+
import functools
7+
from typing import Optional
8+
9+
import pywintypes
10+
import win32api
11+
12+
from .load_dl_common import LoadedDL, add_dll_directory
13+
14+
# Mirrors WinBase.h (unfortunately not defined already elsewhere)
15+
WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
16+
WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
17+
18+
19+
def abs_path_for_dynamic_library(handle: int) -> str:
20+
"""Get the absolute path of a loaded dynamic library on Windows.
21+
22+
Args:
23+
handle: The library handle
24+
25+
Returns:
26+
The absolute path to the DLL file
27+
28+
Raises:
29+
OSError: If GetModuleFileNameW fails
30+
"""
31+
buf = ctypes.create_unicode_buffer(260)
32+
n_chars = ctypes.windll.kernel32.GetModuleFileNameW(ctypes.wintypes.HMODULE(handle), buf, len(buf))
33+
if n_chars == 0:
34+
raise OSError("GetModuleFileNameW failed")
35+
return buf.value
36+
37+
38+
@functools.cache
39+
def cuDriverGetVersion() -> int:
40+
"""Get the CUDA driver version.
41+
42+
Returns:
43+
The CUDA driver version number
44+
45+
Raises:
46+
AssertionError: If the driver version cannot be obtained
47+
"""
48+
handle = win32api.LoadLibrary("nvcuda.dll")
49+
50+
kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
51+
GetProcAddress = kernel32.GetProcAddress
52+
GetProcAddress.argtypes = [ctypes.wintypes.HMODULE, ctypes.wintypes.LPCSTR]
53+
GetProcAddress.restype = ctypes.c_void_p
54+
cuDriverGetVersion = GetProcAddress(handle, b"cuDriverGetVersion")
55+
assert cuDriverGetVersion
56+
57+
FUNC_TYPE = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.POINTER(ctypes.c_int))
58+
cuDriverGetVersion_fn = FUNC_TYPE(cuDriverGetVersion)
59+
driver_ver = ctypes.c_int()
60+
err = cuDriverGetVersion_fn(ctypes.byref(driver_ver))
61+
assert err == 0
62+
return driver_ver.value
63+
64+
65+
def check_if_already_loaded(libname: str) -> Optional[LoadedDL]:
66+
"""Check if the library is already loaded in the process.
67+
68+
Args:
69+
libname: The name of the library to check
70+
71+
Returns:
72+
A LoadedDL object if the library is already loaded, None otherwise
73+
74+
Example:
75+
>>> loaded = check_if_already_loaded("cudart")
76+
>>> if loaded is not None:
77+
... print(f"Library already loaded from {loaded.abs_path}")
78+
"""
79+
from .supported_libs import SUPPORTED_WINDOWS_DLLS
80+
81+
for dll_name in SUPPORTED_WINDOWS_DLLS.get(libname, ()):
82+
try:
83+
handle = win32api.GetModuleHandle(dll_name)
84+
except pywintypes.error:
85+
continue
86+
else:
87+
return LoadedDL(handle, abs_path_for_dynamic_library(handle), True)
88+
return None
89+
90+
91+
def load_with_system_search(name: str, _unused: str) -> Optional[LoadedDL]:
92+
"""Try to load a DLL using system search paths.
93+
94+
Args:
95+
name: The name of the library to load
96+
_unused: Unused parameter (kept for interface consistency)
97+
98+
Returns:
99+
A LoadedDL object if successful, None if the library cannot be loaded
100+
"""
101+
from .supported_libs import SUPPORTED_WINDOWS_DLLS
102+
103+
driver_ver = cuDriverGetVersion()
104+
del driver_ver # Keeping this here because it will probably be needed in the future.
105+
106+
dll_names = SUPPORTED_WINDOWS_DLLS.get(name)
107+
if dll_names is None:
108+
return None
109+
110+
for dll_name in dll_names:
111+
handle = ctypes.windll.kernel32.LoadLibraryW(ctypes.c_wchar_p(dll_name))
112+
if handle:
113+
return LoadedDL(handle, abs_path_for_dynamic_library(handle), False)
114+
115+
return None
116+
117+
118+
def load_with_abs_path(libname: str, found_path: str) -> LoadedDL:
119+
"""Load a dynamic library from the given path.
120+
121+
Args:
122+
libname: The name of the library to load
123+
found_path: The absolute path to the DLL file
124+
125+
Returns:
126+
A LoadedDL object representing the loaded library
127+
128+
Raises:
129+
RuntimeError: If the DLL cannot be loaded
130+
"""
131+
from .supported_libs import LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY
132+
133+
if libname in LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY:
134+
add_dll_directory(found_path)
135+
136+
flags = WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
137+
try:
138+
handle = win32api.LoadLibraryEx(found_path, 0, flags)
139+
except pywintypes.error as e:
140+
raise RuntimeError(f"Failed to load DLL at {found_path}: {e}") from e
141+
return LoadedDL(handle, found_path, False)

0 commit comments

Comments
 (0)