Skip to content

Commit fdb3a7e

Browse files
committed
Refactor Device.__new__ into helper functions
Extract the Device.__new__ logic into cdef helper functions:
- Device_ensure_cuda_initialized(): cuInit + version check
- Device_resolve_device_id(): resolve None to the current device or 0
- Device_ensure_tls_devices(): create thread-local singletons

This reduces Device.__new__ from ~60 lines to ~12 lines. The helpers are placed after the Device class, following the memory module pattern.
1 parent b2083ed commit fdb3a7e

1 file changed

Lines changed: 60 additions & 51 deletions

File tree

cuda_core/cuda/core/_device.pyx

Lines changed: 60 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -965,60 +965,12 @@ class Device:
965965
__slots__ = ("_id", "_memory_resource", "_has_inited", "_properties", "_uuid")
966966

967967
def __new__(cls, device_id: Device | int | None = None):
968-
# Handle device_id argument.
969968
if isinstance(device_id, Device):
970969
return device_id
971-
else:
972-
device_id = getattr(device_id, 'device_id', device_id)
973-
974-
# Initialize CUDA.
975-
global _is_cuInit
976-
if _is_cuInit is False:
977-
with _lock, nogil:
978-
HANDLE_RETURN(cydriver.cuInit(0))
979-
_is_cuInit = True
980-
# Check version compatibility after CUDA is initialized
981-
try:
982-
from cuda.bindings.utils import warn_if_cuda_major_version_mismatch
983-
except ImportError:
984-
pass
985-
else:
986-
warn_if_cuda_major_version_mismatch()
987-
988-
# important: creating a Device instance does not initialize the GPU!
989-
cdef cydriver.CUdevice dev
990-
cdef cydriver.CUcontext ctx
991-
if device_id is None:
992-
with nogil:
993-
err = cydriver.cuCtxGetDevice(&dev)
994-
if err == cydriver.CUresult.CUDA_SUCCESS:
995-
device_id = int(dev)
996-
elif err == cydriver.CUresult.CUDA_ERROR_INVALID_CONTEXT:
997-
with nogil:
998-
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
999-
assert <void*>(ctx) == NULL
1000-
device_id = 0 # cudart behavior
1001-
else:
1002-
HANDLE_RETURN(err)
1003-
elif device_id < 0:
1004-
raise ValueError(f"device_id must be >= 0, got {device_id}")
1005970

1006-
# ensure Device is singleton
1007-
cdef int total
1008-
try:
1009-
devices = _tls.devices
1010-
except AttributeError:
1011-
with nogil:
1012-
HANDLE_RETURN(cydriver.cuDeviceGetCount(&total))
1013-
devices = _tls.devices = []
1014-
for dev_id in range(total):
1015-
device = super().__new__(cls)
1016-
device._id = dev_id
1017-
device._memory_resource = None
1018-
device._has_inited = False
1019-
device._properties = None
1020-
device._uuid = None
1021-
devices.append(device)
971+
Device_ensure_cuda_initialized()
972+
device_id = Device_resolve_device_id(device_id)
973+
devices = Device_ensure_tls_devices(cls)
1022974

1023975
try:
1024976
return devices[device_id]
@@ -1413,3 +1365,60 @@ class Device:
14131365
"""
14141366
self._check_context_initialized()
14151367
return GraphBuilder._init(stream=self.create_stream(), is_stream_owner=True)
1368+
1369+
1370+
cdef inline void Device_ensure_cuda_initialized() except *:
    """Initialize the CUDA driver on first use and run the optional version check.

    Returns immediately when the module-level ``_is_cuInit`` flag is already
    set. Otherwise calls ``cuInit(0)`` while holding ``_lock``, sets the flag,
    and then calls ``warn_if_cuda_major_version_mismatch`` from
    ``cuda.bindings.utils`` if that helper can be imported.

    Raises:
        Whatever ``HANDLE_RETURN`` raises when ``cuInit`` reports an error.
    """
    global _is_cuInit
    if _is_cuInit is not False:
        return

    with _lock:
        # Drop the GIL only for the driver call itself.
        with nogil:
            HANDLE_RETURN(cydriver.cuInit(0))
        _is_cuInit = True

    # The version check runs only after the driver has been initialized.
    try:
        from cuda.bindings.utils import warn_if_cuda_major_version_mismatch
    except ImportError:
        # The helper is optional: when it cannot be imported, skip the check.
        return
    warn_if_cuda_major_version_mismatch()
1383+
1384+
1385+
cdef inline int Device_resolve_device_id(device_id) except? -1:
    """Turn the ``device_id`` argument into a concrete device ordinal.

    Args:
        device_id: A device ordinal, or ``None`` to select a default.

    Returns:
        ``device_id`` itself when it is given and non-negative. When it is
        ``None``: the device reported by ``cuCtxGetDevice``, or ``0`` when
        that call reports ``CUDA_ERROR_INVALID_CONTEXT``.

    Raises:
        ValueError: If ``device_id`` is negative.
        Whatever ``HANDLE_RETURN`` raises for any other driver error.
    """
    cdef cydriver.CUdevice current_dev
    cdef cydriver.CUcontext current_ctx
    cdef cydriver.CUresult status

    # Explicit ordinal: validate it and hand it back unchanged.
    if device_id is not None:
        if device_id < 0:
            raise ValueError(f"device_id must be >= 0, got {device_id}")
        return device_id

    # No ordinal given: ask the driver which device the current context uses.
    with nogil:
        status = cydriver.cuCtxGetDevice(&current_dev)

    if status == cydriver.CUresult.CUDA_SUCCESS:
        return int(current_dev)

    if status == cydriver.CUresult.CUDA_ERROR_INVALID_CONTEXT:
        with nogil:
            HANDLE_RETURN(cydriver.cuCtxGetCurrent(&current_ctx))
        # Sanity check: the current-context handle is expected to be NULL here.
        assert <void*>(current_ctx) == NULL
        return 0  # cudart behavior

    # Any other status is passed to HANDLE_RETURN, which is presumably
    # expected to raise; the trailing return mirrors the original fall-through.
    HANDLE_RETURN(status)
    return device_id
1405+
1406+
1407+
cdef inline list Device_ensure_tls_devices(cls):
    """Return the calling thread's list of ``Device`` singletons.

    If ``_tls.devices`` already exists it is returned as is. Otherwise one
    instance per device reported by ``cuDeviceGetCount`` is allocated, its
    slots are set to their initial values, and the resulting list is stored
    in ``_tls.devices`` before being returned.

    Args:
        cls: The class to instantiate for each device.

    Returns:
        The list stored in ``_tls.devices``, indexed by device ordinal.
    """
    cdef int device_count
    try:
        return _tls.devices
    except AttributeError:
        # First access on this thread: build the per-thread list.
        with nogil:
            HANDLE_RETURN(cydriver.cuDeviceGetCount(&device_count))

        singletons = []
        _tls.devices = singletons
        for ordinal in range(device_count):
            # Allocate through the base class so Device.__new__ is not re-entered.
            instance = super(Device, cls).__new__(cls)
            instance._id = ordinal
            instance._memory_resource = None
            instance._has_inited = False
            instance._properties = None
            instance._uuid = None
            singletons.append(instance)
        return singletons

0 commit comments

Comments
 (0)