@@ -965,60 +965,12 @@ class Device:
965965 __slots__ = (" _id" , " _memory_resource" , " _has_inited" , " _properties" , " _uuid" )
966966
967967 def __new__(cls , device_id: Device | int | None = None ):
968- # Handle device_id argument.
969968 if isinstance (device_id, Device):
970969 return device_id
971- else :
972- device_id = getattr (device_id, ' device_id' , device_id)
973-
974- # Initialize CUDA.
975- global _is_cuInit
976- if _is_cuInit is False :
977- with _lock, nogil:
978- HANDLE_RETURN(cydriver.cuInit(0 ))
979- _is_cuInit = True
980- # Check version compatibility after CUDA is initialized
981- try :
982- from cuda.bindings.utils import warn_if_cuda_major_version_mismatch
983- except ImportError :
984- pass
985- else :
986- warn_if_cuda_major_version_mismatch()
987-
988- # important: creating a Device instance does not initialize the GPU!
989- cdef cydriver.CUdevice dev
990- cdef cydriver.CUcontext ctx
991- if device_id is None :
992- with nogil:
993- err = cydriver.cuCtxGetDevice(& dev)
994- if err == cydriver.CUresult.CUDA_SUCCESS:
995- device_id = int (dev)
996- elif err == cydriver.CUresult.CUDA_ERROR_INVALID_CONTEXT:
997- with nogil:
998- HANDLE_RETURN(cydriver.cuCtxGetCurrent(& ctx))
999- assert < void * > (ctx) == NULL
1000- device_id = 0 # cudart behavior
1001- else :
1002- HANDLE_RETURN(err)
1003- elif device_id < 0 :
1004- raise ValueError (f" device_id must be >= 0, got {device_id}" )
1005970
1006- # ensure Device is singleton
1007- cdef int total
1008- try :
1009- devices = _tls.devices
1010- except AttributeError :
1011- with nogil:
1012- HANDLE_RETURN(cydriver.cuDeviceGetCount(& total))
1013- devices = _tls.devices = []
1014- for dev_id in range (total):
1015- device = super ().__new__(cls )
1016- device._id = dev_id
1017- device._memory_resource = None
1018- device._has_inited = False
1019- device._properties = None
1020- device._uuid = None
1021- devices.append(device)
971+ Device_ensure_cuda_initialized()
972+ device_id = Device_resolve_device_id(device_id)
973+ devices = Device_ensure_tls_devices(cls )
1022974
1023975 try :
1024976 return devices[device_id]
@@ -1413,3 +1365,60 @@ class Device:
14131365 """
14141366 self._check_context_initialized()
14151367 return GraphBuilder._init(stream = self .create_stream(), is_stream_owner = True )
1368+
1369+
1370+ cdef inline void Device_ensure_cuda_initialized() except *:
1371+ """Initialize CUDA driver and check version compatibility (once per process )."""
1372+ global _is_cuInit
1373+ if _is_cuInit is False:
1374+ with _lock , nogil:
1375+ HANDLE_RETURN(cydriver.cuInit(0))
1376+ _is_cuInit = True
1377+ try:
1378+ from cuda.bindings.utils import warn_if_cuda_major_version_mismatch
1379+ except ImportError:
1380+ pass
1381+ else:
1382+ warn_if_cuda_major_version_mismatch()
1383+
1384+
1385+ cdef inline int Device_resolve_device_id(device_id ) except? -1:
1386+ """Resolve device_id , defaulting to current device or 0."""
1387+ cdef cydriver.CUdevice dev
1388+ cdef cydriver.CUcontext ctx
1389+ cdef cydriver.CUresult err
1390+ if device_id is None:
1391+ with nogil:
1392+ err = cydriver.cuCtxGetDevice(& dev)
1393+ if err == cydriver.CUresult.CUDA_SUCCESS:
1394+ return int(dev )
1395+ elif err == cydriver.CUresult.CUDA_ERROR_INVALID_CONTEXT:
1396+ with nogil:
1397+ HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx ))
1398+ assert <void*>(ctx ) == NULL
1399+ return 0 # cudart behavior
1400+ else:
1401+ HANDLE_RETURN(err )
1402+ elif device_id < 0:
1403+ raise ValueError(f"device_id must be >= 0, got {device_id}")
1404+ return device_id
1405+
1406+
1407+ cdef inline list Device_ensure_tls_devices(cls ):
1408+ """ Ensure thread-local Device singletons exist, creating if needed."""
1409+ cdef int total
1410+ try :
1411+ return _tls.devices
1412+ except AttributeError :
1413+ with nogil:
1414+ HANDLE_RETURN(cydriver.cuDeviceGetCount(& total))
1415+ devices = _tls.devices = []
1416+ for dev_id in range (total):
1417+ device = super (Device, cls ).__new__(cls )
1418+ device._id = dev_id
1419+ device._memory_resource = None
1420+ device._has_inited = False
1421+ device._properties = None
1422+ device._uuid = None
1423+ devices.append(device)
1424+ return devices
0 commit comments