Skip to content

Commit 0a81ed5

Browse files
authored
Allow for patching tvm-ffi (Dao-AILab#83)
1 parent b6a22fc commit 0a81ed5

1 file changed

Lines changed: 105 additions & 19 deletions

File tree

quack/cute_dsl_ptxas.py

Lines changed: 105 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
"""
22
System ptxas replacement for CUTLASS DSL.
3+
4+
Usage::
5+
6+
CUTE_DSL_KEEP_PTX=1 CUTE_DSL_PTXAS_PATH=/usr/local/cuda/bin/ptxas pytest tests/
7+
38
Environment variables:
49
CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)
10+
CUTE_DSL_KEEP_PTX - Must be 1 before cutlass is imported (this module sets it to 1 at import time when CUTE_DSL_PTXAS_PATH is set)
511
CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output
12+
CUTE_DSL_DUMP_DIR - Directory for dumped PTX files (default: cwd)
13+
CUTE_DSL_KEEP_CUBIN - Set to 1 to save compiled cubin files
614
"""
715

816
import os
@@ -16,29 +24,81 @@
1624

1725

1826
# Path to the system ptxas binary; None when the variable is unset.
CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH")

if CUTE_DSL_PTXAS_PATH:
    # patch() requires CUTE_DSL_KEEP_PTX=1, so force it on whenever a system
    # ptxas is configured.
    # NOTE(review): patch() later reads CUTE_DSL_KEEP_PTX to decide whether the
    # user originally wanted PTX kept; after this override that read always
    # sees "1". Confirm whether the user's value should be captured first.
    os.environ["CUTE_DSL_KEEP_PTX"] = "1"

# Enables the stderr diagnostics emitted by _log().
VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1"

# Original methods saved by patch() so the patched versions can delegate to them.
_original_load_cuda_library = None
_original_create_tvm_ffi_function = None
_user_wanted_ptx = False  # True if user originally set CUTE_DSL_KEEP_PTX=1
2335

2436

25-
def _log(msg: str) -> None:
    """Print ``msg`` to stderr with a ``[ptxas]`` prefix when ``VERBOSE`` is true."""
    if VERBOSE:
        print(f"[ptxas] {msg}", file=sys.stderr)
2840

2941

42+
def _read_ptx(ptx_path: Path) -> str | None:
    """Read a PTX file and return its text.

    The file is read as bytes and decoded as UTF-8 with undecodable bytes
    dropped; trailing NUL characters are stripped from the result.

    Returns:
        The decoded text, or None if reading the file raises ``OSError``
        (the failure is reported through ``_log``).
    """
    # Keep the try body limited to the I/O call that can actually fail.
    try:
        raw = ptx_path.read_bytes()
    except OSError as exc:
        _log(f"Failed to read {ptx_path}: {exc}")
        return None
    return raw.decode("utf-8", errors="ignore").rstrip("\x00")
48+
49+
50+
def _read_complete_ptx(ptx_path: Path) -> str | None:
    """Read a PTX file, accepting it only if its text ends with ``}``.

    Returns:
        The text from ``_read_ptx`` when, ignoring trailing whitespace, it ends
        with a closing brace; None when the file could not be read or the text
        does not end that way.
    """
    content = _read_ptx(ptx_path)
    if content is None:
        return None
    if not content.rstrip().endswith("}"):
        return None
    return content
55+
56+
3057
def _ptx_mtime_ns(path: Path) -> int:
    """Return the file's mtime in nanoseconds, or -1 if it cannot be stat'ed."""
    try:
        return path.stat().st_mtime_ns
    except OSError:
        # The file may have been removed after it was listed; sort it last.
        return -1


def _get_ptx(compiled_func) -> tuple[str, Path] | None:
    """Find dumped PTX for the compiled function.

    Searches ``CUTE_DSL_DUMP_DIR`` (default: current working directory)
    recursively for ``*.ptx`` files, newest first, and tries in order:

    1. files whose name contains the function name;
    2. files whose text contains a matching ``.entry <function name>`` directive;
    3. the only candidate, when exactly one PTX file exists.

    Only files accepted by ``_read_complete_ptx`` are considered.

    Args:
        compiled_func: Object whose ``function_name`` attribute names the kernel.

    Returns:
        ``(ptx_text, ptx_path)`` for the first usable match, or None when
        ``function_name`` is missing/empty or no usable PTX file is found.
    """
    func_name = getattr(compiled_func, "function_name", None)
    if not func_name:
        _log("Compiled function is missing function_name")
        return None

    dump_dir = Path(os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd()))
    dump_dir.mkdir(parents=True, exist_ok=True)

    # Newest files first. _ptx_mtime_ns tolerates files that vanish between
    # listing and stat, so a concurrent delete cannot abort the search.
    ptx_paths = sorted(dump_dir.rglob("*.ptx"), key=_ptx_mtime_ns, reverse=True)
    _log(f"Searching dumped PTX for {func_name} in {dump_dir}")
    _log(f"Found {len(ptx_paths)} PTX candidate files in {dump_dir}")

    # Strategy 1: match by filename
    filename_matches = [ptx_path for ptx_path in ptx_paths if func_name in ptx_path.name]
    if filename_matches:
        _log(f"Found {len(filename_matches)} filename matches for {func_name}")
    for ptx_path in filename_matches:
        content = _read_complete_ptx(ptx_path)
        if content is None:
            continue
        _log(f"Using PTX filename match for {func_name}: {ptx_path}")
        return content, ptx_path

    # Strategy 2: match by .entry directive inside PTX
    entry_pattern = re.compile(rf"\.entry\s+{re.escape(func_name)}(?:\s|\()", re.MULTILINE)
    for ptx_path in ptx_paths:
        content = _read_complete_ptx(ptx_path)
        if content is None:
            continue
        if entry_pattern.search(content):
            _log(f"Found PTX for {func_name}: {ptx_path}")
            return content, ptx_path

    # Strategy 3: use sole candidate as fallback
    if len(ptx_paths) == 1:
        content = _read_complete_ptx(ptx_paths[0])
        if content is not None:
            _log(f"Using sole PTX candidate for {func_name}: {ptx_paths[0]}")
            return content, ptx_paths[0]

    _log(f"No PTX found for function {func_name} in {dump_dir}")
    return None
43103

44104

@@ -102,13 +162,15 @@ def _patched_load_cuda_library(self):
102162
_log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas")
103163
return _original_load_cuda_library(self)
104164

105-
# Register kernels on all devices
165+
# Register kernels on all devices (must match cuda_load_to_device's void*** convention)
106166
_, cuda_load_to_device = self._get_cuda_init_and_load()
107-
lib_ptr = ctypes.c_void_p(int(library))
167+
lib_handle = ctypes.c_void_p(int(library))
168+
ptr_to_lib = ctypes.pointer(lib_handle)
169+
ptr_to_ptr_to_lib = ctypes.pointer(ptr_to_lib)
108170
dev_id = ctypes.c_int32(0)
109171
err_val = ctypes.c_int32(0)
110172
args = (ctypes.c_void_p * 3)(
111-
ctypes.cast(ctypes.pointer(lib_ptr), ctypes.c_void_p),
173+
ctypes.cast(ptr_to_ptr_to_lib, ctypes.c_void_p),
112174
ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p),
113175
ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p),
114176
)
@@ -126,26 +188,50 @@ def _patched_load_cuda_library(self):
126188
if not _user_wanted_ptx:
127189
ptx_path.unlink(missing_ok=True)
128190

129-
return [cuda_runtime.cudaLibrary_t(lib_ptr.value)]
191+
return [cuda_runtime.cudaLibrary_t(lib_handle.value)]
192+
193+
194+
def _patched_create_tvm_ffi_function(self):
    """Replacement for ``_create_tvm_ffi_function`` installed by ``patch()``.

    Calls ``self._load_cuda_library()`` the first time it runs on an instance,
    stores the result on ``self._ptxas_cuda_library``, and then delegates to the
    saved original method.

    Returns:
        Whatever the original ``_create_tvm_ffi_function`` returns.
    """
    # Ensure CUDA library is loaded before TVM FFI creation
    if getattr(self, "_ptxas_cuda_library", None) is None:
        self._ptxas_cuda_library = self._load_cuda_library()
        _log(
            f"Loaded {len(self._ptxas_cuda_library)} CUDA libraries before creating TVM FFI function"
        )
    return _original_create_tvm_ffi_function(self)
130202

131203

132204
def patch():
    """Install system ptxas hook. Call before importing cutlass.

    Replaces ``CudaDialectJitCompiledFunction._load_cuda_library`` and
    ``TVMFFIJitCompiledFunctionBase._create_tvm_ffi_function`` with the patched
    versions from this module, saving the originals in module globals. Calling
    it again is harmless: a method that is already patched is left untouched.

    NOTE(review): the body uses the module-level name ``cutlass``, which seems
    at odds with "call before importing cutlass" -- confirm the intended order.

    Raises:
        AssertionError: if ``CUTE_DSL_PTXAS_PATH`` is None or
            ``CUTE_DSL_KEEP_PTX`` is not ``"1"``.
        RuntimeError: if ``CUTE_DSL_PTXAS_PATH`` is not an executable file.
    """
    global _original_load_cuda_library, _original_create_tvm_ffi_function, _user_wanted_ptx

    assert CUTE_DSL_PTXAS_PATH is not None
    if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK):
        raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}")

    # NOTE(review): module import forces CUTE_DSL_KEEP_PTX=1 whenever
    # CUTE_DSL_PTXAS_PATH is set, so this reads True even if the user never
    # asked to keep PTX -- confirm whether the user's value should be captured
    # before that override.
    _user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1"
    assert os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1", (
        "Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas"
    )

    patched = False
    cuda_jit_function_cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction
    # Skip when already patched so the saved "original" never points at our own wrapper.
    if cuda_jit_function_cls._load_cuda_library is not _patched_load_cuda_library:
        _original_load_cuda_library = cuda_jit_function_cls._load_cuda_library
        cuda_jit_function_cls._load_cuda_library = _patched_load_cuda_library
        patched = True

    from cutlass.cutlass_dsl.tvm_ffi_provider import TVMFFIJitCompiledFunctionBase

    if (
        TVMFFIJitCompiledFunctionBase._create_tvm_ffi_function
        is not _patched_create_tvm_ffi_function
    ):
        _original_create_tvm_ffi_function = TVMFFIJitCompiledFunctionBase._create_tvm_ffi_function
        TVMFFIJitCompiledFunctionBase._create_tvm_ffi_function = _patched_create_tvm_ffi_function
        patched = True

    if patched:
        _log(f"Installed system ptxas patch with {CUTE_DSL_PTXAS_PATH}")
    else:
        _log("System ptxas patch already installed")

0 commit comments

Comments
 (0)