11"""
22System ptxas replacement for CUTLASS DSL.
3+
4+ Usage::
5+
6+ CUTE_DSL_KEEP_PTX=1 CUTE_DSL_PTXAS_PATH=/usr/local/cuda/bin/ptxas pytest tests/
7+
38Environment variables:
49 CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)
10+ CUTE_DSL_KEEP_PTX - Must be set to 1 before cutlass is imported
511 CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output
12+ CUTE_DSL_DUMP_DIR - Directory for dumped PTX files (default: cwd)
13+ CUTE_DSL_KEEP_CUBIN - Set to 1 to save compiled cubin files
614"""
715
816import os
1624
1725
1826CUTE_DSL_PTXAS_PATH = os .environ .get ("CUTE_DSL_PTXAS_PATH" , None )
27+
28+ if CUTE_DSL_PTXAS_PATH :
29+ os .environ ["CUTE_DSL_KEEP_PTX" ] = "1"
1930VERBOSE = os .environ .get ("CUTE_DSL_PTXAS_VERBOSE" , "0" ) == "1"
2031
2132_original_load_cuda_library = None
33+ _original_create_tvm_ffi_function = None
2234_user_wanted_ptx = False # True if user originally set CUTE_DSL_KEEP_PTX=1
2335
2436
25- def _log (msg ):
37+ def _log (msg : str ):
2638 if VERBOSE :
2739 print (f"[ptxas] { msg } " , file = sys .stderr )
2840
2941
42+ def _read_ptx (ptx_path : Path ) -> str | None :
43+ try :
44+ return ptx_path .read_bytes ().decode ("utf-8" , errors = "ignore" ).rstrip ("\x00 " )
45+ except OSError as exc :
46+ _log (f"Failed to read { ptx_path } : { exc } " )
47+ return None
48+
49+
50+ def _read_complete_ptx (ptx_path : Path ) -> str | None :
51+ content = _read_ptx (ptx_path )
52+ if content is None or not content .rstrip ().endswith ("}" ):
53+ return None
54+ return content
55+
56+
3057def _get_ptx (compiled_func ) -> tuple [str , Path ] | None :
31- """Find and read PTX file, stripping null bytes ."""
58+ """Find dumped PTX for the compiled function ."""
3259 func_name = getattr (compiled_func , "function_name" , None )
3360 if not func_name :
61+ _log ("Compiled function is missing function_name" )
3462 return None
3563
36- dump_dir = os .environ .get ("CUTE_DSL_DUMP_DIR" , Path .cwd ())
37- for ptx_path in Path (dump_dir ).glob (f"*{ func_name } *.ptx" ):
38- content = ptx_path .read_text ().rstrip ("\x00 " )
39- if ".entry " in content and content .rstrip ().endswith ("}" ):
40- _log (f"Found PTX: { ptx_path } " )
64+ dump_dir = Path (os .environ .get ("CUTE_DSL_DUMP_DIR" , Path .cwd ()))
65+ dump_dir .mkdir (parents = True , exist_ok = True )
66+
67+ ptx_paths = sorted (
68+ dump_dir .rglob ("*.ptx" ), key = lambda path : path .stat ().st_mtime_ns , reverse = True
69+ )
70+ _log (f"Searching dumped PTX for { func_name } in { dump_dir } " )
71+ _log (f"Found { len (ptx_paths )} PTX candidate files in { dump_dir } " )
72+
73+ # Strategy 1: match by filename
74+ filename_matches = [ptx_path for ptx_path in ptx_paths if func_name in ptx_path .name ]
75+ if filename_matches :
76+ _log (f"Found { len (filename_matches )} filename matches for { func_name } " )
77+ for ptx_path in filename_matches :
78+ content = _read_complete_ptx (ptx_path )
79+ if content is None :
80+ continue
81+ _log (f"Using PTX filename match for { func_name } : { ptx_path } " )
82+ return content , ptx_path
83+
84+ # Strategy 2: match by .entry directive inside PTX
85+ entry_pattern = re .compile (rf"\.entry\s+{ re .escape (func_name )} (?:\s|\()" , re .MULTILINE )
86+ for ptx_path in ptx_paths :
87+ content = _read_complete_ptx (ptx_path )
88+ if content is None :
89+ continue
90+ if entry_pattern .search (content ):
91+ _log (f"Found PTX for { func_name } : { ptx_path } " )
4192 return content , ptx_path
93+
94+ # Strategy 3: use sole candidate as fallback
95+ if len (ptx_paths ) == 1 :
96+ content = _read_complete_ptx (ptx_paths [0 ])
97+ if content is not None :
98+ _log (f"Using sole PTX candidate for { func_name } : { ptx_paths [0 ]} " )
99+ return content , ptx_paths [0 ]
100+
101+ _log (f"No PTX found for function { func_name } in { dump_dir } " )
42102 return None
43103
44104
@@ -102,13 +162,15 @@ def _patched_load_cuda_library(self):
102162 _log (f"cudaLibraryLoadData failed ({ err } ), falling back to embedded ptxas" )
103163 return _original_load_cuda_library (self )
104164
105- # Register kernels on all devices
165+ # Register kernels on all devices (must match cuda_load_to_device's void*** convention)
106166 _ , cuda_load_to_device = self ._get_cuda_init_and_load ()
107- lib_ptr = ctypes .c_void_p (int (library ))
167+ lib_handle = ctypes .c_void_p (int (library ))
168+ ptr_to_lib = ctypes .pointer (lib_handle )
169+ ptr_to_ptr_to_lib = ctypes .pointer (ptr_to_lib )
108170 dev_id = ctypes .c_int32 (0 )
109171 err_val = ctypes .c_int32 (0 )
110172 args = (ctypes .c_void_p * 3 )(
111- ctypes .cast (ctypes . pointer ( lib_ptr ) , ctypes .c_void_p ),
173+ ctypes .cast (ptr_to_ptr_to_lib , ctypes .c_void_p ),
112174 ctypes .cast (ctypes .pointer (dev_id ), ctypes .c_void_p ),
113175 ctypes .cast (ctypes .pointer (err_val ), ctypes .c_void_p ),
114176 )
@@ -126,26 +188,50 @@ def _patched_load_cuda_library(self):
126188 if not _user_wanted_ptx :
127189 ptx_path .unlink (missing_ok = True )
128190
129- return [cuda_runtime .cudaLibrary_t (lib_ptr .value )]
191+ return [cuda_runtime .cudaLibrary_t (lib_handle .value )]
192+
193+
194+ def _patched_create_tvm_ffi_function (self ):
195+ # Ensure CUDA library is loaded before TVM FFI creation
196+ if getattr (self , "_ptxas_cuda_library" , None ) is None :
197+ self ._ptxas_cuda_library = self ._load_cuda_library ()
198+ _log (
199+ f"Loaded { len (self ._ptxas_cuda_library )} CUDA libraries before creating TVM FFI function"
200+ )
201+ return _original_create_tvm_ffi_function (self )
130202
131203
132204def patch ():
133205 """Install system ptxas hook. Call before importing cutlass."""
134- global _original_load_cuda_library , _user_wanted_ptx
206+ global _original_load_cuda_library , _original_create_tvm_ffi_function , _user_wanted_ptx
135207
136208 assert CUTE_DSL_PTXAS_PATH is not None
137209 if not os .path .isfile (CUTE_DSL_PTXAS_PATH ) or not os .access (CUTE_DSL_PTXAS_PATH , os .X_OK ):
138210 raise RuntimeError (f"ptxas not found: { CUTE_DSL_PTXAS_PATH } " )
139211
140- # Track if user originally wanted PTX kept
141212 _user_wanted_ptx = os .environ .get ("CUTE_DSL_KEEP_PTX" , "0" ) == "1"
142- # os.environ['CUTE_DSL_KEEP_PTX'] = '1'
143213 assert os .environ .get ("CUTE_DSL_KEEP_PTX" , "0" ) == "1" , (
144214 "Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas"
145215 )
146216
147- cls = cutlass .cutlass_dsl .cuda_jit_executor .CudaDialectJitCompiledFunction
148- _original_load_cuda_library = cls ._load_cuda_library
149- cls ._load_cuda_library = _patched_load_cuda_library
150- _log ("Patch applied" )
151- return
217+ patched = False
218+ cuda_jit_function_cls = cutlass .cutlass_dsl .cuda_jit_executor .CudaDialectJitCompiledFunction
219+ if cuda_jit_function_cls ._load_cuda_library is not _patched_load_cuda_library :
220+ _original_load_cuda_library = cuda_jit_function_cls ._load_cuda_library
221+ cuda_jit_function_cls ._load_cuda_library = _patched_load_cuda_library
222+ patched = True
223+
224+ from cutlass .cutlass_dsl .tvm_ffi_provider import TVMFFIJitCompiledFunctionBase
225+
226+ if (
227+ TVMFFIJitCompiledFunctionBase ._create_tvm_ffi_function
228+ is not _patched_create_tvm_ffi_function
229+ ):
230+ _original_create_tvm_ffi_function = TVMFFIJitCompiledFunctionBase ._create_tvm_ffi_function
231+ TVMFFIJitCompiledFunctionBase ._create_tvm_ffi_function = _patched_create_tvm_ffi_function
232+ patched = True
233+
234+ if patched :
235+ _log (f"Installed system ptxas patch with { CUTE_DSL_PTXAS_PATH } " )
236+ else :
237+ _log ("System ptxas patch already installed" )
0 commit comments