2525
2626
2727@functools .cache
28- def _get_cuda_major_version () -> str :
28+ def _get_cuda_paths ():
29+ CUDA_PATH = os .environ .get ("CUDA_PATH" , os .environ .get ("CUDA_HOME" , None ))
30+ if not CUDA_PATH :
31+ raise RuntimeError ("Environment variable CUDA_PATH or CUDA_HOME is not set" )
32+ CUDA_PATH = CUDA_PATH .split (os .pathsep )
33+ print ("CUDA paths:" , CUDA_PATH )
34+ return CUDA_PATH
35+
36+ @functools .cache
37+ def _determine_cuda_major_version () -> str :
2938 """Determine the CUDA major version for building cuda.core.
3039
3140 This version is used for two purposes:
@@ -42,10 +51,11 @@ def _get_cuda_major_version() -> str:
4251 # Explicit override, e.g. in CI.
4352 cuda_major = os .environ .get ("CUDA_CORE_BUILD_MAJOR" )
4453 if cuda_major is not None :
54+ print ("CUDA MAJOR VERSION:" , cuda_major )
4555 return cuda_major
4656
4757 # Derive from the CUDA headers (the authoritative source for what we compile against).
48- cuda_path = os . environ . get ( "CUDA_PATH" , os . environ . get ( "CUDA_HOME" , None ) )
58+ cuda_path = _get_cuda_paths ( )
4959 if cuda_path :
5060 for root in cuda_path .split (os .pathsep ):
5161 cuda_h = os .path .join (root , "include" , "cuda.h" )
@@ -56,7 +66,9 @@ def _get_cuda_major_version() -> str:
5666 if m :
5767 v = int (m .group (1 ))
5868 # CUDA_VERSION is e.g. 12020 for 12.2.
59- return str (v // 1000 )
69+ cuda_major = str (v // 1000 )
70+ print ("CUDA MAJOR VERSION:" , cuda_major )
71+ return cuda_major
6072 except OSError :
6173 continue
6274
@@ -83,25 +95,12 @@ def _build_cuda_core():
8395
8496 # It seems setuptools' wildcard support has problems for namespace packages,
8597 # so we explicitly spell out all Extension instances.
86- root_module = "cuda.core"
87- root_path = f"{ os .path .sep } " .join (root_module .split ("." )) + os .path .sep
88- ext_files = glob .glob (f"{ root_path } /**/*.pyx" , recursive = True )
89-
90- def strip_prefix_suffix (filename ):
91- return filename [len (root_path ) : - 4 ]
92-
93- module_names = (strip_prefix_suffix (f ) for f in ext_files )
94-
95- @functools .cache
96- def get_cuda_paths ():
97- CUDA_PATH = os .environ .get ("CUDA_PATH" , os .environ .get ("CUDA_HOME" , None ))
98- if not CUDA_PATH :
99- raise RuntimeError ("Environment variable CUDA_PATH or CUDA_HOME is not set" )
100- CUDA_PATH = CUDA_PATH .split (os .pathsep )
101- print ("CUDA paths:" , CUDA_PATH )
102- return CUDA_PATH
98+ def module_names ():
99+ root_path = os .path .sep .join (["cuda" , "core" , "" ])
100+ for filename in glob .glob (f"{ root_path } /**/*.pyx" , recursive = True ):
101+ yield filename [len (root_path ) : - 4 ]
103102
104- all_include_dirs = list (os .path .join (root , "include" ) for root in get_cuda_paths ())
103+ all_include_dirs = list (os .path .join (root , "include" ) for root in _get_cuda_paths ())
105104 extra_compile_args = []
106105 if COMPILE_FOR_COVERAGE :
107106 # CYTHON_TRACE_NOGIL indicates to trace nogil functions. It is not
@@ -116,11 +115,11 @@ def get_cuda_paths():
116115 language = "c++" ,
117116 extra_compile_args = extra_compile_args ,
118117 )
119- for mod in module_names
118+ for mod in module_names ()
120119 )
121120
122121 nthreads = int (os .environ .get ("CUDA_PYTHON_PARALLEL_LEVEL" , os .cpu_count () // 2 ))
123- compile_time_env = {"CUDA_CORE_BUILD_MAJOR" : int (_get_cuda_major_version ())}
122+ compile_time_env = {"CUDA_CORE_BUILD_MAJOR" : int (_determine_cuda_major_version ())}
124123 compiler_directives = {"embedsignature" : True , "warn.deprecated.IF" : False , "freethreading_compatible" : True }
125124 if COMPILE_FOR_COVERAGE :
126125 compiler_directives ["linetrace" ] = True
@@ -147,7 +146,7 @@ def build_wheel(wheel_directory, config_settings=None, metadata_directory=None):
147146
148147
149148def _get_cuda_bindings_require ():
150- cuda_major = _get_cuda_major_version ()
149+ cuda_major = _determine_cuda_major_version ()
151150 return [f"cuda-bindings=={ cuda_major } .*" ]
152151
153152
0 commit comments