Skip to content

Commit 360b3a1

Browse files
committed
Minor refactoring.
1 parent af0f397 commit 360b3a1

1 file changed

Lines changed: 23 additions & 24 deletions

File tree

cuda_core/build_hooks.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,16 @@
2525

2626

2727
@functools.cache
28-
def _get_cuda_major_version() -> str:
28+
def _get_cuda_paths():
29+
CUDA_PATH = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME", None))
30+
if not CUDA_PATH:
31+
raise RuntimeError("Environment variable CUDA_PATH or CUDA_HOME is not set")
32+
CUDA_PATH = CUDA_PATH.split(os.pathsep)
33+
print("CUDA paths:", CUDA_PATH)
34+
return CUDA_PATH
35+
36+
@functools.cache
37+
def _determine_cuda_major_version() -> str:
2938
"""Determine the CUDA major version for building cuda.core.
3039
3140
This version is used for two purposes:
@@ -42,10 +51,11 @@ def _get_cuda_major_version() -> str:
4251
# Explicit override, e.g. in CI.
4352
cuda_major = os.environ.get("CUDA_CORE_BUILD_MAJOR")
4453
if cuda_major is not None:
54+
print("CUDA MAJOR VERSION:", cuda_major)
4555
return cuda_major
4656

4757
# Derive from the CUDA headers (the authoritative source for what we compile against).
48-
cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME", None))
58+
cuda_path = _get_cuda_paths()
4959
if cuda_path:
5060
for root in cuda_path.split(os.pathsep):
5161
cuda_h = os.path.join(root, "include", "cuda.h")
@@ -56,7 +66,9 @@ def _get_cuda_major_version() -> str:
5666
if m:
5767
v = int(m.group(1))
5868
# CUDA_VERSION is e.g. 12020 for 12.2.
59-
return str(v // 1000)
69+
cuda_major = str(v // 1000)
70+
print("CUDA MAJOR VERSION:", cuda_major)
71+
return cuda_major
6072
except OSError:
6173
continue
6274

@@ -83,25 +95,12 @@ def _build_cuda_core():
8395

8496
# It seems setuptools' wildcard support has problems for namespace packages,
8597
# so we explicitly spell out all Extension instances.
86-
root_module = "cuda.core"
87-
root_path = f"{os.path.sep}".join(root_module.split(".")) + os.path.sep
88-
ext_files = glob.glob(f"{root_path}/**/*.pyx", recursive=True)
89-
90-
def strip_prefix_suffix(filename):
91-
return filename[len(root_path) : -4]
92-
93-
module_names = (strip_prefix_suffix(f) for f in ext_files)
94-
95-
@functools.cache
96-
def get_cuda_paths():
97-
CUDA_PATH = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME", None))
98-
if not CUDA_PATH:
99-
raise RuntimeError("Environment variable CUDA_PATH or CUDA_HOME is not set")
100-
CUDA_PATH = CUDA_PATH.split(os.pathsep)
101-
print("CUDA paths:", CUDA_PATH)
102-
return CUDA_PATH
98+
def module_names():
99+
root_path = os.path.sep.join(["cuda", "core", ""])
100+
for filename in glob.glob(f"{root_path}/**/*.pyx", recursive=True):
101+
yield filename[len(root_path) : -4]
103102

104-
all_include_dirs = list(os.path.join(root, "include") for root in get_cuda_paths())
103+
all_include_dirs = list(os.path.join(root, "include") for root in _get_cuda_paths())
105104
extra_compile_args = []
106105
if COMPILE_FOR_COVERAGE:
107106
# CYTHON_TRACE_NOGIL indicates to trace nogil functions. It is not
@@ -116,11 +115,11 @@ def get_cuda_paths():
116115
language="c++",
117116
extra_compile_args=extra_compile_args,
118117
)
119-
for mod in module_names
118+
for mod in module_names()
120119
)
121120

122121
nthreads = int(os.environ.get("CUDA_PYTHON_PARALLEL_LEVEL", os.cpu_count() // 2))
123-
compile_time_env = {"CUDA_CORE_BUILD_MAJOR": int(_get_cuda_major_version())}
122+
compile_time_env = {"CUDA_CORE_BUILD_MAJOR": int(_determine_cuda_major_version())}
124123
compiler_directives = {"embedsignature": True, "warn.deprecated.IF": False, "freethreading_compatible": True}
125124
if COMPILE_FOR_COVERAGE:
126125
compiler_directives["linetrace"] = True
@@ -147,7 +146,7 @@ def build_wheel(wheel_directory, config_settings=None, metadata_directory=None):
147146

148147

149148
def _get_cuda_bindings_require():
150-
cuda_major = _get_cuda_major_version()
149+
cuda_major = _determine_cuda_major_version()
151150
return [f"cuda-bindings=={cuda_major}.*"]
152151

153152

0 commit comments

Comments
 (0)