diff --git a/cuda_core/build_hooks.py b/cuda_core/build_hooks.py index a94120616e..812c5a76a1 100644 --- a/cuda_core/build_hooks.py +++ b/cuda_core/build_hooks.py @@ -11,7 +11,6 @@ import glob import os import re -import subprocess import sys import tempfile import zipfile @@ -185,28 +184,6 @@ def get_sources(mod_name): # related to free-threading builds. extra_compile_args += ["-DCYTHON_TRACE_NOGIL=1", "-DCYTHON_USE_SYS_MONITORING=0"] - # On Windows, _tensor_bridge.pyx needs a stub import library so the MSVC - # linker can resolve the AOTI symbols (they live in torch_cpu.dll at - # runtime). We generate the .lib from a .def file at build time. - # Note: aoti_torch_get_current_cuda_stream lives in torch_cuda.dll and - # is resolved lazily at runtime (not via the stub lib) — see - # _tensor_bridge.pyx. - _aoti_extra_link_args = [] - if sys.platform == "win32": - _def_file = os.path.join("cuda", "core", "_include", "aoti_shim.def") - _lib_file = os.path.join("build", "aoti_shim.lib") - os.makedirs("build", exist_ok=True) - subprocess.check_call( # noqa: S603 - ["lib", f"/DEF:{_def_file}", f"/OUT:{_lib_file}", "/MACHINE:X64"], # noqa: S607 - stdout=subprocess.DEVNULL, - ) - _aoti_extra_link_args = [_lib_file] - - def get_extra_link_args(mod_name): - if mod_name == "_tensor_bridge" and _aoti_extra_link_args: - return extra_link_args + _aoti_extra_link_args - return extra_link_args - ext_modules = tuple( Extension( f"cuda.core.{mod.replace(os.path.sep, '.')}", @@ -218,7 +195,7 @@ def get_extra_link_args(mod_name): + all_include_dirs, language="c++", extra_compile_args=extra_compile_args, - extra_link_args=get_extra_link_args(mod), + extra_link_args=extra_link_args, ) for mod in module_names() ) diff --git a/cuda_core/cuda/core/_include/aoti_shim.def b/cuda_core/cuda/core/_include/aoti_shim.def index e21097bd25..3cd1e31aea 100644 --- a/cuda_core/cuda/core/_include/aoti_shim.def +++ b/cuda_core/cuda/core/_include/aoti_shim.def @@ -4,7 +4,7 @@ ; At runtime the symbols resolve from torch_cpu.dll (loaded by 'import torch'). ; ; IMPORTANT: Keep this export list in sync with the AOTI_SHIM_API declarations -; in aoti_shim.h. build_hooks.py turns this file into the stub import library +; in aoti_shim.h. setup.py turns this file into the stub import library ; that MSVC uses to link _tensor_bridge, so any added/removed/renamed AOTI ; symbol must be updated in both files. LIBRARY torch_cpu.dll diff --git a/cuda_core/cuda/core/_include/aoti_shim.h b/cuda_core/cuda/core/_include/aoti_shim.h index 464d27de46..bf8894d940 100644 --- a/cuda_core/cuda/core/_include/aoti_shim.h +++ b/cuda_core/cuda/core/_include/aoti_shim.h @@ -52,10 +52,10 @@ typedef struct AtenTensorOpaque* AtenTensorHandle; /* * IMPORTANT: Keep the AOTI_SHIM_API declaration list below in sync with - * aoti_shim.def. On Windows, build_hooks.py turns that .def file into the - * stub import library that MSVC needs to link _tensor_bridge without making - * PyTorch a build-time dependency. If you add, remove, or rename an - * imported AOTI symbol here, update aoti_shim.def in the same change. + * aoti_shim.def. On Windows, setup.py generates that stub import library + * during build_ext so MSVC can link _tensor_bridge without making PyTorch a + * build-time dependency. If you add, remove, or rename an imported AOTI + * symbol here, update aoti_shim.def in the same change. * * Exception: aoti_torch_get_current_cuda_stream lives in torch_cuda (not * torch_cpu) and is resolved lazily at runtime — see _tensor_bridge.pyx. diff --git a/cuda_core/setup.py b/cuda_core/setup.py index 71d1548755..bde1fe22fe 100644 --- a/cuda_core/setup.py +++ b/cuda_core/setup.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import os +from pathlib import Path import build_hooks # our build backend from setuptools import setup @@ -11,11 +12,63 @@ nthreads = int(os.environ.get("CUDA_PYTHON_PARALLEL_LEVEL", os.cpu_count() // 2)) coverage_mode = bool(int(os.environ.get("CUDA_PYTHON_COVERAGE", "0"))) +_ROOT_DIR = Path(__file__).resolve().parent +_AOTI_SHIM_DEF_FILE = _ROOT_DIR / "cuda" / "core" / "_include" / "aoti_shim.def" +_AOTI_SHIM_LIB_FILE = _ROOT_DIR / "build" / "aoti_shim.lib" +_TENSOR_BRIDGE_EXT_NAME = "cuda.core._tensor_bridge" + + +def _ensure_compiler_initialized(compiler, plat_name): + initialize = getattr(compiler, "initialize", None) + if callable(initialize) and not getattr(compiler, "initialized", False): + if plat_name is None: + initialize() + else: + initialize(plat_name) + + +def _build_aoti_shim_lib(compiler): + # Reuse setuptools' initialized MSVC compiler instead of rediscovering + # lib.exe separately in the build backend. + lib_exe = getattr(compiler, "lib", None) + if not lib_exe: + raise RuntimeError("MSVC compiler did not expose lib.exe after initialization.") + + _AOTI_SHIM_LIB_FILE.parent.mkdir(exist_ok=True) + compiler.spawn( + [ + lib_exe, + f"/DEF:{_AOTI_SHIM_DEF_FILE}", + f"/OUT:{_AOTI_SHIM_LIB_FILE}", + "/MACHINE:X64", + ] + ) + return str(_AOTI_SHIM_LIB_FILE) class build_ext(_build_ext): # noqa: N801 + def _configure_windows_tensor_bridge(self): + if os.name != "nt" or getattr(self.compiler, "compiler_type", None) != "msvc": + return + + # _tensor_bridge imports AOTI symbols from torch_cpu.dll, which on + # Windows requires a stub import library for the MSVC linker. + for ext in self.extensions: + if ext.name != _TENSOR_BRIDGE_EXT_NAME: + continue + + _ensure_compiler_initialized(self.compiler, self.plat_name) + shim_lib = _build_aoti_shim_lib(self.compiler) + link_args = list(ext.extra_link_args or []) + if shim_lib not in link_args: + ext.extra_link_args = [*link_args, shim_lib] + return + + raise RuntimeError(f"Failed to find extension {_TENSOR_BRIDGE_EXT_NAME!r} for Windows build.") + def build_extensions(self): self.parallel = nthreads + self._configure_windows_tensor_bridge() super().build_extensions()