-
Notifications
You must be signed in to change notification settings - Fork 278
Expand file tree
/
Copy pathload_dl_windows.py
More file actions
200 lines (155 loc) · 7.43 KB
/
load_dl_windows.py
File metadata and controls
200 lines (155 loc) · 7.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import ctypes
import ctypes.wintypes
import os
import struct
from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import (
LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY,
SUPPORTED_WINDOWS_DLLS,
)
# LoadLibraryExW flag values copied from WinBase.h; neither ctypes nor
# ctypes.wintypes exposes them, so they are mirrored here.
WINBASE_LOAD_WITH_ALTERED_SEARCH_PATH = 0x0008
WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x0100
WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x1000

# Number of distinct values a native pointer can hold: 2 ** (bits per pointer).
# Adding this to a negative (signed) handle value yields its unsigned form.
POINTER_ADDRESS_SPACE = 1 << (8 * struct.calcsize("P"))
# Set up kernel32 functions with proper types.
# A private WinDLL instance is used rather than the process-wide cached
# ``ctypes.windll.kernel32``: prototypes (argtypes/restype) assigned on the
# shared object would leak into every other module that uses it, and could be
# overwritten by them. A private instance keeps these declarations local.
kernel32 = ctypes.WinDLL("kernel32")  # type: ignore[attr-defined]

# GetModuleHandleW
kernel32.GetModuleHandleW.argtypes = [ctypes.wintypes.LPCWSTR]
kernel32.GetModuleHandleW.restype = ctypes.wintypes.HMODULE

# LoadLibraryExW
kernel32.LoadLibraryExW.argtypes = [
    ctypes.wintypes.LPCWSTR,  # lpLibFileName
    ctypes.wintypes.HANDLE,  # hFile (reserved, must be NULL)
    ctypes.wintypes.DWORD,  # dwFlags
]
kernel32.LoadLibraryExW.restype = ctypes.wintypes.HMODULE

# GetModuleFileNameW
kernel32.GetModuleFileNameW.argtypes = [
    ctypes.wintypes.HMODULE,  # hModule
    ctypes.wintypes.LPWSTR,  # lpFilename
    ctypes.wintypes.DWORD,  # nSize
]
kernel32.GetModuleFileNameW.restype = ctypes.wintypes.DWORD

# AddDllDirectory (Windows 7+)
kernel32.AddDllDirectory.argtypes = [ctypes.wintypes.LPCWSTR]
kernel32.AddDllDirectory.restype = ctypes.c_void_p  # DLL_DIRECTORY_COOKIE

# SearchPathW - find a file in the system search path
kernel32.SearchPathW.argtypes = [
    ctypes.wintypes.LPCWSTR,  # lpPath (NULL to use standard search)
    ctypes.wintypes.LPCWSTR,  # lpFileName
    ctypes.wintypes.LPCWSTR,  # lpExtension
    ctypes.wintypes.DWORD,  # nBufferLength
    ctypes.wintypes.LPWSTR,  # lpBuffer
    ctypes.POINTER(ctypes.wintypes.LPWSTR),  # lpFilePart
]
kernel32.SearchPathW.restype = ctypes.wintypes.DWORD
def ctypes_handle_to_unsigned_int(handle: ctypes.wintypes.HMODULE) -> int:
    """Return ``handle`` as a non-negative integer.

    A negative value (a signed reading of the pointer bits) is wrapped into the
    unsigned range by adding ``POINTER_ADDRESS_SPACE``; non-negative values are
    returned unchanged.
    """
    value = int(handle)
    return value + POINTER_ADDRESS_SPACE if value < 0 else value
def add_dll_directory(dll_abs_path: str) -> None:
    """Make the directory containing ``dll_abs_path`` available for DLL resolution.

    The directory is registered via ``AddDllDirectory`` and appended to the
    ``PATH`` environment variable. It is appended only if PATH does not already
    list it, so repeated calls do not grow PATH.

    Args:
        dll_abs_path: Absolute path to the DLL file

    Raises:
        AssertionError: If the directory containing the DLL does not exist
    """
    dirpath = os.path.dirname(dll_abs_path)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # AssertionError is kept for backward compatibility with existing callers.
    if not os.path.isdir(dirpath):
        raise AssertionError(dll_abs_path)

    # Best effort: the returned DLL_DIRECTORY_COOKIE (NULL on failure) is
    # intentionally ignored because PATH is updated below either way.
    kernel32.AddDllDirectory(dirpath)

    # Update PATH as a fallback for dependent DLL resolution.
    curr_path = os.environ.get("PATH")
    if not curr_path:
        # Unset or empty PATH: avoid producing a leading empty entry.
        os.environ["PATH"] = dirpath
        return

    # Re-appending a directory that is already listed cannot change which DLL
    # is found (earlier entries win), so skip it instead of bloating PATH.
    wanted = os.path.normcase(os.path.normpath(dirpath))
    already_listed = any(
        os.path.normcase(os.path.normpath(entry)) == wanted for entry in curr_path.split(os.pathsep) if entry
    )
    if not already_listed:
        os.environ["PATH"] = os.pathsep.join((curr_path, dirpath))
def abs_path_for_dynamic_library(libname: str, handle: ctypes.wintypes.HMODULE) -> str:
    """Return the file path of the loaded module ``handle`` via GetModuleFileNameW.

    A MAX_PATH-sized buffer is tried first. If the result fills it completely
    (the API's truncation signal), the query is repeated once with a
    32768-character buffer.

    Args:
        libname: Library name, used only in error messages.
        handle: Module handle of the loaded DLL.

    Returns:
        The path reported by GetModuleFileNameW.

    Raises:
        RuntimeError: If GetModuleFileNameW reports failure (returns 0).
    """
    for capacity in (260, 32768):  # MAX_PATH, then extended path length
        path_buf = ctypes.create_unicode_buffer(capacity)
        written = kernel32.GetModuleFileNameW(handle, path_buf, capacity)
        if written == 0:
            error_code = ctypes.GetLastError()  # type: ignore[attr-defined]
            raise RuntimeError(f"GetModuleFileNameW failed for {libname!r} (error code: {error_code})")
        if written != capacity:
            # The path fit without truncation.
            break
    return path_buf.value
def check_if_already_loaded_from_elsewhere(libname: str, have_abs_path: bool) -> LoadedDL | None:
    """Report a DLL for ``libname`` that is already present in this process.

    Each DLL name tabulated for ``libname`` is looked up with GetModuleHandleW;
    the first hit is returned.

    Args:
        libname: Key into ``SUPPORTED_WINDOWS_DLLS``.
        have_abs_path: When true and ``libname`` needs its directory registered,
            ``add_dll_directory`` is applied to the already-loaded DLL's path.

    Returns:
        A LoadedDL tagged "was-already-loaded-from-elsewhere", or None if no
        tabulated DLL name is currently loaded.
    """
    for candidate in SUPPORTED_WINDOWS_DLLS.get(libname, ()):
        module_handle = kernel32.GetModuleHandleW(candidate)
        if not module_handle:
            continue
        loaded_path = abs_path_for_dynamic_library(libname, module_handle)
        if have_abs_path and libname in LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY:
            # load_with_abs_path() registers the DLL directory as a side effect;
            # do the same here so the outcome does not depend on who loaded first.
            add_dll_directory(loaded_path)
        return LoadedDL(
            loaded_path,
            True,
            ctypes_handle_to_unsigned_int(module_handle),
            "was-already-loaded-from-elsewhere",
        )
    return None
def _search_path_for_dll(dll_name: str) -> str | None:
    """Locate ``dll_name`` with SearchPathW using the default search (lpPath=NULL).

    Args:
        dll_name: The name of the DLL to find

    Returns:
        The path written by SearchPathW, or None when SearchPathW returns 0.
    """
    # NOTE(review): Microsoft's SearchPath documentation warns that its search
    # order differs from LoadLibrary's; confirm that is acceptable here.
    capacity = 260  # MAX_PATH
    path_buf = ctypes.create_unicode_buffer(capacity)
    required = kernel32.SearchPathW(None, dll_name, None, capacity, path_buf, None)
    if required > capacity:
        # Buffer too small: SearchPathW returned the size it needs; retry once.
        capacity = required
        path_buf = ctypes.create_unicode_buffer(capacity)
        required = kernel32.SearchPathW(None, dll_name, None, capacity, path_buf, None)
    return path_buf.value if required else None
def load_with_system_search(libname: str) -> LoadedDL | None:
    """Load the first DLL for ``libname`` that the system search path can resolve.

    Args:
        libname: The name of the library to load

    Returns:
        A LoadedDL tagged "system-search" on success; None if no tabulated DLL
        name could be both found and loaded.
    """
    # The tabulated names are walked in reverse so newer names are tried first.
    for dll_name in reversed(SUPPORTED_WINDOWS_DLLS.get(libname, ())):
        resolved = _search_path_for_dll(dll_name)
        if not resolved:
            continue
        # LOAD_WITH_ALTERED_SEARCH_PATH makes Windows resolve dependencies from
        # the DLL's own directory (CUDA DLLs ship co-located dependencies).
        module_handle = kernel32.LoadLibraryExW(resolved, None, WINBASE_LOAD_WITH_ALTERED_SEARCH_PATH)
        if not module_handle:
            continue
        return LoadedDL(resolved, False, ctypes_handle_to_unsigned_int(module_handle), "system-search")
    return None
def load_with_abs_path(libname: str, found_path: str, found_via: str | None = None) -> LoadedDL:
    """Load the DLL at ``found_path`` with LoadLibraryExW.

    Args:
        libname: The name of the library to load; if it is listed in
            ``LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY`` its directory is
            registered via ``add_dll_directory`` before loading.
        found_path: The absolute path to the DLL file
        found_via: Passed through unchanged as the last field of the returned
            LoadedDL (presumably a label for how the path was found).

    Returns:
        A LoadedDL object representing the loaded library

    Raises:
        RuntimeError: If the DLL cannot be loaded
    """
    if libname in LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY:
        add_dll_directory(found_path)

    # Resolve dependencies from the DLL's own directory plus the default dirs.
    search_flags = WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR | WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS
    module_handle = kernel32.LoadLibraryExW(found_path, None, search_flags)
    if module_handle:
        return LoadedDL(found_path, False, ctypes_handle_to_unsigned_int(module_handle), found_via)

    error_code = ctypes.GetLastError()  # type: ignore[attr-defined]
    raise RuntimeError(f"Failed to load DLL at {found_path}: Windows error {error_code}")