Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 1 addition & 24 deletions cuda_core/build_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import glob
import os
import re
import subprocess
import sys
import tempfile
import zipfile
Expand Down Expand Up @@ -185,28 +184,6 @@ def get_sources(mod_name):
# related to free-threading builds.
extra_compile_args += ["-DCYTHON_TRACE_NOGIL=1", "-DCYTHON_USE_SYS_MONITORING=0"]

# On Windows, _tensor_bridge.pyx needs a stub import library so the MSVC
# linker can resolve the AOTI symbols (they live in torch_cpu.dll at
# runtime). We generate the .lib from a .def file at build time.
# Note: aoti_torch_get_current_cuda_stream lives in torch_cuda.dll and
# is resolved lazily at runtime (not via the stub lib) — see
# _tensor_bridge.pyx.
_aoti_extra_link_args = []
if sys.platform == "win32":
_def_file = os.path.join("cuda", "core", "_include", "aoti_shim.def")
_lib_file = os.path.join("build", "aoti_shim.lib")
os.makedirs("build", exist_ok=True)
subprocess.check_call( # noqa: S603
["lib", f"/DEF:{_def_file}", f"/OUT:{_lib_file}", "/MACHINE:X64"], # noqa: S607
stdout=subprocess.DEVNULL,
)
_aoti_extra_link_args = [_lib_file]

def get_extra_link_args(mod_name):
if mod_name == "_tensor_bridge" and _aoti_extra_link_args:
return extra_link_args + _aoti_extra_link_args
return extra_link_args

ext_modules = tuple(
Extension(
f"cuda.core.{mod.replace(os.path.sep, '.')}",
Expand All @@ -218,7 +195,7 @@ def get_extra_link_args(mod_name):
+ all_include_dirs,
language="c++",
extra_compile_args=extra_compile_args,
extra_link_args=get_extra_link_args(mod),
extra_link_args=extra_link_args,
)
for mod in module_names()
)
Expand Down
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_include/aoti_shim.def
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
; At runtime the symbols resolve from torch_cpu.dll (loaded by 'import torch').
;
; IMPORTANT: Keep this export list in sync with the AOTI_SHIM_API declarations
; in aoti_shim.h. build_hooks.py turns this file into the stub import library
; in aoti_shim.h. setup.py turns this file into the stub import library
; that MSVC uses to link _tensor_bridge, so any added/removed/renamed AOTI
; symbol must be updated in both files.
LIBRARY torch_cpu.dll
Expand Down
8 changes: 4 additions & 4 deletions cuda_core/cuda/core/_include/aoti_shim.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ typedef struct AtenTensorOpaque* AtenTensorHandle;

/*
* IMPORTANT: Keep the AOTI_SHIM_API declaration list below in sync with
* aoti_shim.def. On Windows, build_hooks.py turns that .def file into the
* stub import library that MSVC needs to link _tensor_bridge without making
* PyTorch a build-time dependency. If you add, remove, or rename an
* imported AOTI symbol here, update aoti_shim.def in the same change.
* aoti_shim.def. On Windows, setup.py generates that stub import library
* during build_ext so MSVC can link _tensor_bridge without making PyTorch a
* build-time dependency. If you add, remove, or rename an imported AOTI
* symbol here, update aoti_shim.def in the same change.
*
* Exception: aoti_torch_get_current_cuda_stream lives in torch_cuda (not
* torch_cpu) and is resolved lazily at runtime — see _tensor_bridge.pyx.
Expand Down
53 changes: 53 additions & 0 deletions cuda_core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import os
from pathlib import Path

import build_hooks # our build backend
from setuptools import setup
Expand All @@ -11,11 +12,63 @@

nthreads = int(os.environ.get("CUDA_PYTHON_PARALLEL_LEVEL", os.cpu_count() // 2))
coverage_mode = bool(int(os.environ.get("CUDA_PYTHON_COVERAGE", "0")))
_ROOT_DIR = Path(__file__).resolve().parent
_AOTI_SHIM_DEF_FILE = _ROOT_DIR / "cuda" / "core" / "_include" / "aoti_shim.def"
_AOTI_SHIM_LIB_FILE = _ROOT_DIR / "build" / "aoti_shim.lib"
_TENSOR_BRIDGE_EXT_NAME = "cuda.core._tensor_bridge"


def _ensure_compiler_initialized(compiler, plat_name):
initialize = getattr(compiler, "initialize", None)
if callable(initialize) and not getattr(compiler, "initialized", False):
if plat_name is None:
initialize()
else:
initialize(plat_name)


def _build_aoti_shim_lib(compiler):
# Reuse setuptools' initialized MSVC compiler instead of rediscovering
# lib.exe separately in the build backend.
lib_exe = getattr(compiler, "lib", None)
if not lib_exe:
raise RuntimeError("MSVC compiler did not expose lib.exe after initialization.")

_AOTI_SHIM_LIB_FILE.parent.mkdir(exist_ok=True)
compiler.spawn(
[
lib_exe,
f"/DEF:{_AOTI_SHIM_DEF_FILE}",
f"/OUT:{_AOTI_SHIM_LIB_FILE}",
"/MACHINE:X64",
]
)
return str(_AOTI_SHIM_LIB_FILE)


class build_ext(_build_ext): # noqa: N801
def _configure_windows_tensor_bridge(self):
if os.name != "nt" or getattr(self.compiler, "compiler_type", None) != "msvc":
return

# _tensor_bridge imports AOTI symbols from torch_cpu.dll, which on
# Windows requires a stub import library for the MSVC linker.
for ext in self.extensions:
if ext.name != _TENSOR_BRIDGE_EXT_NAME:
continue

_ensure_compiler_initialized(self.compiler, self.plat_name)
shim_lib = _build_aoti_shim_lib(self.compiler)
link_args = list(ext.extra_link_args or [])
if shim_lib not in link_args:
ext.extra_link_args = [*link_args, shim_lib]
return

raise RuntimeError(f"Failed to find extension {_TENSOR_BRIDGE_EXT_NAME!r} for Windows build.")

def build_extensions(self):
self.parallel = nthreads
self._configure_windows_tensor_bridge()
super().build_extensions()


Expand Down
Loading