diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp
index 7692d34287..12962bf491 100644
--- a/cuda_core/cuda/core/_cpp/resource_handles.cpp
+++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp
@@ -56,6 +56,9 @@ decltype(&cuLibraryLoadData) p_cuLibraryLoadData = nullptr;
 decltype(&cuLibraryUnload) p_cuLibraryUnload = nullptr;
 decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel = nullptr;
 
+// Graph
+decltype(&cuGraphDestroy) p_cuGraphDestroy = nullptr;
+
 // Linker
 decltype(&cuLinkDestroy) p_cuLinkDestroy = nullptr;
 
@@ -160,6 +163,53 @@ class GILAcquireGuard {
 
 } // namespace
 
+// ============================================================================
+// Handle reverse-lookup registry
+//
+// Maps raw CUDA handles (CUevent, CUkernel, etc.) back to their owning
+// shared_ptr so that _ref constructors can recover full metadata.
+// Uses weak_ptr to avoid preventing destruction.
+// ============================================================================
+
+template<typename Key, typename Handle, typename Hash = std::hash<Key>>
+class HandleRegistry {
+public:
+    // Record (or overwrite) the weak back-reference stored for `key`.
+    void register_handle(const Key& key, const Handle& h) {
+        std::lock_guard<std::mutex> lock(mutex_);
+        map_[key] = h;
+    }
+
+    // Drop `key` only if its entry has expired, so a newer live registration
+    // under the same key survives. Swallows all exceptions (called from deleters).
+    void unregister_handle(const Key& key) noexcept {
+        try {
+            std::lock_guard<std::mutex> lock(mutex_);
+            auto it = map_.find(key);
+            if (it != map_.end() && it->second.expired()) {
+                map_.erase(it);
+            }
+        } catch (...) {}
+    }
+
+    // Return the live handle for `key`, or an empty handle; prunes an expired entry.
+    Handle lookup(const Key& key) {
+        std::lock_guard<std::mutex> lock(mutex_);
+        auto it = map_.find(key);
+        if (it != map_.end()) {
+            if (auto h = it->second.lock()) {
+                return h;
+            }
+            map_.erase(it);
+        }
+        return {};
+    }
+
+private:
+    std::mutex mutex_;
+    std::unordered_map<Key, std::weak_ptr<typename Handle::element_type>, Hash> map_;
+};
+
 // ============================================================================
 // Thread-local error handling
 // ============================================================================
@@ -318,10 +368,46 @@ StreamHandle get_per_thread_stream() {
 namespace {
 struct EventBox {
     CUevent resource;
+    bool timing_disabled;
+    bool busy_waited;
+    bool ipc_enabled;
+    int device_id;
+    ContextHandle h_context;
 };
 } // namespace
 
-EventHandle create_event_handle(const ContextHandle& h_ctx, unsigned int flags) {
+static const EventBox* get_box(const EventHandle& h) {
+    const CUevent* p = h.get();
+    return reinterpret_cast<const EventBox*>(
+        reinterpret_cast<const char*>(p) - offsetof(EventBox, resource)
+    );
+}
+
+bool get_event_timing_disabled(const EventHandle& h) noexcept {
+    return h ? get_box(h)->timing_disabled : true;
+}
+
+bool get_event_busy_waited(const EventHandle& h) noexcept {
+    return h ? get_box(h)->busy_waited : false;
+}
+
+bool get_event_ipc_enabled(const EventHandle& h) noexcept {
+    return h ? get_box(h)->ipc_enabled : false;
+}
+
+int get_event_device_id(const EventHandle& h) noexcept {
+    return h ? get_box(h)->device_id : -1;
+}
+
+ContextHandle get_event_context(const EventHandle& h) noexcept {
+    return h ? get_box(h)->h_context : ContextHandle{};
+}
+
+static HandleRegistry<CUevent, EventHandle> event_registry;
+
+EventHandle create_event_handle(const ContextHandle& h_ctx, unsigned int flags,
+                                bool timing_disabled, bool busy_waited,
+                                bool ipc_enabled, int device_id) {
     GILReleaseGuard gil;
     CUevent event;
     if (CUDA_SUCCESS != (err = p_cuEventCreate(&event, flags))) {
@@ -329,21 +415,33 @@ EventHandle create_event_handle(const ContextHandle& h_ctx, unsigned int flags) {
     }
 
     auto box = std::shared_ptr<const EventBox>(
-        new EventBox{event},
+        new EventBox{event, timing_disabled, busy_waited, ipc_enabled, device_id, h_ctx},
         [h_ctx](const EventBox* b) {
+            event_registry.unregister_handle(b->resource);
             GILReleaseGuard gil;
             p_cuEventDestroy(b->resource);
             delete b;
         }
     );
-    return EventHandle(box, &box->resource);
+    EventHandle h(box, &box->resource);
+    event_registry.register_handle(event, h);
+    return h;
 }
 
 EventHandle create_event_handle_noctx(unsigned int flags) {
-    return create_event_handle(ContextHandle{}, flags);
+    return create_event_handle(ContextHandle{}, flags, true, false, false, -1);
+}
+
+EventHandle create_event_handle_ref(CUevent event) {
+    if (auto h = event_registry.lookup(event)) {
+        return h;
+    }
+    auto box = std::make_shared<const EventBox>(EventBox{event, true, false, false, -1, {}});
+    return EventHandle(box, &box->resource);
 }
 
-EventHandle create_event_handle_ipc(const CUipcEventHandle& ipc_handle) {
+EventHandle create_event_handle_ipc(const CUipcEventHandle& ipc_handle,
+                                    bool busy_waited) {
     GILReleaseGuard gil;
     CUevent event;
     if (CUDA_SUCCESS != (err = p_cuIpcOpenEventHandle(&event, ipc_handle))) {
@@ -351,14 +449,17 @@ EventHandle create_event_handle_ipc(const CUipcEventHandle& ipc_handle) {
     }
 
     auto box = std::shared_ptr<const EventBox>(
-        new EventBox{event},
+        new EventBox{event, true, busy_waited, true, -1, {}},
         [](const EventBox* b) {
+            event_registry.unregister_handle(b->resource);
             GILReleaseGuard gil;
             p_cuEventDestroy(b->resource);
             delete b;
         }
     );
-    return EventHandle(box, &box->resource);
+    EventHandle h(box, &box->resource);
+    event_registry.register_handle(event, h);
+    return h;
 }
 
 // ============================================================================
@@ -665,61 +766,45 @@ struct ExportDataKeyHash {
 }
 
-static std::mutex ipc_ptr_cache_mutex;
-static std::unordered_map<ExportDataKey, std::weak_ptr<DevicePtrBox>, ExportDataKeyHash> ipc_ptr_cache;
+static HandleRegistry<ExportDataKey, DevicePtrHandle, ExportDataKeyHash> ipc_ptr_cache;
+static std::mutex ipc_import_mutex;
 
 DevicePtrHandle deviceptr_import_ipc(const MemoryPoolHandle& h_pool,
                                      const void* export_data,
                                      const StreamHandle& h_stream) {
     auto data = const_cast<CUmemPoolPtrExportData*>(
         reinterpret_cast<const CUmemPoolPtrExportData*>(export_data));
 
     if (use_ipc_ptr_cache()) {
-        // Check cache before calling cuMemPoolImportPointer
         ExportDataKey key;
         std::memcpy(&key.data, data, sizeof(key.data));
 
-        std::lock_guard<std::mutex> lock(ipc_ptr_cache_mutex);
+        // Release the GIL before taking ipc_import_mutex: waiting on the mutex
+        // while holding the GIL can deadlock against a thread re-acquiring it.
+        GILReleaseGuard gil;
+        std::lock_guard<std::mutex> lock(ipc_import_mutex);
 
-        auto it = ipc_ptr_cache.find(key);
-        if (it != ipc_ptr_cache.end()) {
-            if (auto box = it->second.lock()) {
-                // Cache hit - return existing handle
-                return DevicePtrHandle(box, &box->resource);
-            }
-            ipc_ptr_cache.erase(it);  // Expired entry
+        if (auto h = ipc_ptr_cache.lookup(key)) {
+            return h;
         }
 
-        // Cache miss - import the pointer
-        GILReleaseGuard gil;
         CUdeviceptr ptr;
         if (CUDA_SUCCESS != (err = p_cuMemPoolImportPointer(&ptr, *h_pool, data))) {
             return {};
         }
 
-        // Create new handle with cache-clearing deleter
         auto box = std::shared_ptr<DevicePtrBox>(
             new DevicePtrBox{ptr, h_stream},
             [h_pool, key](DevicePtrBox* b) {
+                ipc_ptr_cache.unregister_handle(key);
                 GILReleaseGuard gil;
-                try {
-                    std::lock_guard<std::mutex> lock(ipc_ptr_cache_mutex);
-                    // Only erase if expired - avoids race where another thread
-                    // replaced the entry with a new import before we acquired the lock.
-                    auto it = ipc_ptr_cache.find(key);
-                    if (it != ipc_ptr_cache.end() && it->second.expired()) {
-                        ipc_ptr_cache.erase(it);
-                    }
-                } catch (...) {
-                    // Cache cleanup is best-effort - swallow exceptions in destructor context
-                }
                 p_cuMemFreeAsync(b->resource, as_cu(b->h_stream));
                 delete b;
             }
         );
-        ipc_ptr_cache[key] = box;
-        return DevicePtrHandle(box, &box->resource);
+        DevicePtrHandle h(box, &box->resource);
+        ipc_ptr_cache.register_handle(key, h);
+        return h;
     } else {
-        // No caching - simple handle creation
         GILReleaseGuard gil;
         CUdeviceptr ptr;
         if (CUDA_SUCCESS != (err = p_cuMemPoolImportPointer(&ptr, *h_pool, data))) {
@@ -798,10 +883,19 @@ LibraryHandle create_library_handle_ref(CUlibrary library) {
 namespace {
 struct KernelBox {
     CUkernel resource;
-    LibraryHandle h_library;  // Keeps library alive
+    LibraryHandle h_library;
 };
 } // namespace
 
+static const KernelBox* get_box(const KernelHandle& h) {
+    const CUkernel* p = h.get();
+    return reinterpret_cast<const KernelBox*>(
+        reinterpret_cast<const char*>(p) - offsetof(KernelBox, resource)
+    );
+}
+
+static HandleRegistry<CUkernel, KernelHandle> kernel_registry;
+
 KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* name) {
     GILReleaseGuard gil;
     CUkernel kernel;
@@ -809,14 +903,84 @@ KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* na
         return {};
     }
 
-    return create_kernel_handle_ref(kernel, h_library);
+    // Custom deleter drops the registry entry together with the handle, as the
+    // event deleters do; otherwise the stale weak_ptr entry is never removed.
+    auto box = std::shared_ptr<const KernelBox>(
+        new KernelBox{kernel, h_library},
+        [](const KernelBox* b) {
+            kernel_registry.unregister_handle(b->resource);
+            delete b;
+        }
+    );
+    KernelHandle h(box, &box->resource);
+    kernel_registry.register_handle(kernel, h);
+    return h;
 }
 
-KernelHandle create_kernel_handle_ref(CUkernel kernel, const LibraryHandle& h_library) {
-    auto box = std::make_shared<const KernelBox>(KernelBox{kernel, h_library});
+KernelHandle create_kernel_handle_ref(CUkernel kernel) {
+    if (auto h = kernel_registry.lookup(kernel)) {
+        return h;
+    }
+    auto box = std::make_shared<const KernelBox>(KernelBox{kernel, {}});
     return KernelHandle(box, &box->resource);
 }
 
+LibraryHandle get_kernel_library(const KernelHandle& h) noexcept {
+    if (!h) return {};
+    return get_box(h)->h_library;
+}
+
+// ============================================================================
+// Graph Handles
+// 
============================================================================
+
+namespace {
+struct GraphBox {
+    CUgraph resource;
+    GraphHandle h_parent;  // Keeps parent alive for child/branch graphs
+};
+} // namespace
+
+GraphHandle create_graph_handle(CUgraph graph) {
+    auto box = std::shared_ptr<const GraphBox>(
+        new GraphBox{graph, {}},
+        [](const GraphBox* b) {
+            GILReleaseGuard gil;
+            p_cuGraphDestroy(b->resource);
+            delete b;
+        }
+    );
+    return GraphHandle(box, &box->resource);
+}
+
+GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent) {
+    auto box = std::make_shared<const GraphBox>(GraphBox{graph, h_parent});
+    return GraphHandle(box, &box->resource);
+}
+
+namespace {
+struct GraphNodeBox {
+    CUgraphNode resource;
+    GraphHandle h_graph;
+};
+} // namespace
+
+static const GraphNodeBox* get_box(const GraphNodeHandle& h) {
+    const CUgraphNode* p = h.get();
+    return reinterpret_cast<const GraphNodeBox*>(
+        reinterpret_cast<const char*>(p) - offsetof(GraphNodeBox, resource)
+    );
+}
+
+GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) {
+    auto box = std::make_shared<const GraphNodeBox>(GraphNodeBox{node, h_graph});
+    return GraphNodeHandle(box, &box->resource);
+}
+
+GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept {
+    return h ? get_box(h)->h_graph : GraphHandle{};
+}
+
 // ============================================================================
 // Graphics Resource Handles
 // ============================================================================
diff --git a/cuda_core/cuda/core/_cpp/resource_handles.hpp b/cuda_core/cuda/core/_cpp/resource_handles.hpp
index 724d6e1bd9..090e5fa8cb 100644
--- a/cuda_core/cuda/core/_cpp/resource_handles.hpp
+++ b/cuda_core/cuda/core/_cpp/resource_handles.hpp
@@ -92,6 +92,9 @@ extern decltype(&cuLibraryLoadData) p_cuLibraryLoadData;
 extern decltype(&cuLibraryUnload) p_cuLibraryUnload;
 extern decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel;
 
+// Graph
+extern decltype(&cuGraphDestroy) p_cuGraphDestroy;
+
 // Linker
 extern decltype(&cuLinkDestroy) p_cuLinkDestroy;
 
@@ -143,6 +146,8 @@ using EventHandle = std::shared_ptr<const CUevent>;
 using MemoryPoolHandle = std::shared_ptr<const CUmemoryPool>;
 using LibraryHandle = std::shared_ptr<const CUlibrary>;
 using KernelHandle = std::shared_ptr<const CUkernel>;
+using GraphHandle = std::shared_ptr<const CUgraph>;
+using GraphNodeHandle = std::shared_ptr<const CUgraphNode>;
 using GraphicsResourceHandle = std::shared_ptr<const CUgraphicsResource>;
 using NvrtcProgramHandle = std::shared_ptr<const nvrtcProgram>;
 using NvvmProgramHandle = std::shared_ptr<const nvvmProgram>;
@@ -200,9 +205,12 @@ StreamHandle get_per_thread_stream();
 
 // Create an owning event handle by calling cuEventCreate.
 // The event structurally depends on the provided context handle.
+// Metadata fields are stored in the EventBox for later retrieval.
 // When the last reference is released, cuEventDestroy is called automatically.
 // Returns empty handle on error (caller must check).
-EventHandle create_event_handle(const ContextHandle& h_ctx, unsigned int flags);
+EventHandle create_event_handle(const ContextHandle& h_ctx, unsigned int flags,
+                                bool timing_disabled, bool busy_waited,
+                                bool ipc_enabled, int device_id);
 
 // Create an owning event handle without context dependency.
 // Use for temporary events that are created and destroyed in the same scope.
@@ -214,7 +222,21 @@ EventHandle create_event_handle_noctx(unsigned int flags);
 // The originating process owns the event and its context.
 // When the last reference is released, cuEventDestroy is called automatically.
 // Returns empty handle on error (caller must check).
-EventHandle create_event_handle_ipc(const CUipcEventHandle& ipc_handle);
+EventHandle create_event_handle_ipc(const CUipcEventHandle& ipc_handle,
+                                    bool busy_waited);
+
+// Create an event handle from a raw CUevent.
+// If the event is already managed (in the registry), returns the existing
+// owning handle with its metadata. Otherwise returns a non-owning ref that
+// never destroys the event; metadata is unknown (timing_disabled=true, device_id=-1).
+EventHandle create_event_handle_ref(CUevent event);
+
+// Event metadata accessors (read from EventBox via pointer arithmetic)
+bool get_event_timing_disabled(const EventHandle& h) noexcept;
+bool get_event_busy_waited(const EventHandle& h) noexcept;
+bool get_event_ipc_enabled(const EventHandle& h) noexcept;
+int get_event_device_id(const EventHandle& h) noexcept;
+ContextHandle get_event_context(const EventHandle& h) noexcept;
 
 // ============================================================================
 // Memory pool handle functions
@@ -345,9 +367,41 @@ LibraryHandle create_library_handle_ref(CUlibrary library);
 // Returns empty handle on error (caller must check).
 KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* name);
 
-// Create a non-owning kernel handle with library dependency.
-// Use for borrowed kernels. The library handle keeps the library alive.
-KernelHandle create_kernel_handle_ref(CUkernel kernel, const LibraryHandle& h_library);
+// Create a kernel handle from a raw CUkernel.
+// If the kernel is already managed (in the registry), returns the owning
+// handle with library dependency. Otherwise returns a non-owning ref.
+KernelHandle create_kernel_handle_ref(CUkernel kernel);
+
+// Get the library handle associated with a kernel (from KernelBox).
+// Returns empty handle if the kernel has no library dependency.
+LibraryHandle get_kernel_library(const KernelHandle& h) noexcept;
+
+// ============================================================================
+// Graph handle functions
+// ============================================================================
+
+// Wrap an externally-created CUgraph with RAII cleanup.
+// When the last reference is released, cuGraphDestroy is called automatically.
+// The caller must have already created the graph via cuGraphCreate.
+GraphHandle create_graph_handle(CUgraph graph);
+
+// Create a non-owning graph handle that keeps h_parent alive.
+// Use for graphs owned by a child/conditional node in a parent graph.
+// The child graph will NOT be destroyed when this handle is released,
+// but h_parent will be prevented from destruction while this handle exists.
+GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent);
+
+// ============================================================================
+// Graph node handle functions
+// ============================================================================
+
+// Create a node handle. Nodes are owned by their parent graph (not
+// independently destroyable). The GraphHandle dependency ensures the
+// graph outlives any node reference.
+GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph);
+
+// Extract the owning graph handle from a node handle.
+GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept;
 
 // ============================================================================
 // Graphics resource handle functions
 // ============================================================================
@@ -445,6 +499,14 @@ inline CUkernel as_cu(const KernelHandle& h) noexcept {
     return h ? *h : nullptr;
 }
 
+inline CUgraph as_cu(const GraphHandle& h) noexcept {
+    return h ? *h : nullptr;
+}
+
+inline CUgraphNode as_cu(const GraphNodeHandle& h) noexcept {
+    return h ? *h : nullptr;
+}
+
 inline CUgraphicsResource as_cu(const GraphicsResourceHandle& h) noexcept {
     return h ? *h : nullptr;
 }
@@ -495,6 +557,14 @@ inline std::intptr_t as_intptr(const KernelHandle& h) noexcept {
     return reinterpret_cast<std::intptr_t>(as_cu(h));
 }
 
+inline std::intptr_t as_intptr(const GraphHandle& h) noexcept {
+    return reinterpret_cast<std::intptr_t>(as_cu(h));
+}
+
+inline std::intptr_t as_intptr(const GraphNodeHandle& h) noexcept {
+    return reinterpret_cast<std::intptr_t>(as_cu(h));
+}
+
 inline std::intptr_t as_intptr(const GraphicsResourceHandle& h) noexcept {
     return reinterpret_cast<std::intptr_t>(as_cu(h));
 }
@@ -558,6 +628,17 @@ inline PyObject* as_py(const KernelHandle& h) noexcept {
     return detail::make_py("cuda.bindings.driver", "CUkernel", as_intptr(h));
 }
 
+inline PyObject* as_py(const GraphHandle& h) noexcept {
+    return detail::make_py("cuda.bindings.driver", "CUgraph", as_intptr(h));
+}
+
+inline PyObject* as_py(const GraphNodeHandle& h) noexcept {
+    if (!as_intptr(h)) {
+        Py_RETURN_NONE;
+    }
+    return detail::make_py("cuda.bindings.driver", "CUgraphNode", as_intptr(h));
+}
+
 inline PyObject* as_py(const NvrtcProgramHandle& h) noexcept {
     return detail::make_py("cuda.bindings.nvrtc", "nvrtcProgram", as_intptr(h));
 }
diff --git a/cuda_core/cuda/core/_event.pxd b/cuda_core/cuda/core/_event.pxd
index c393b29ebf..5710b13699 100644
--- a/cuda_core/cuda/core/_event.pxd
+++ b/cuda_core/cuda/core/_event.pxd
@@ -10,15 +10,13 @@ cdef class Event:
 
     cdef:
         EventHandle _h_event
-        ContextHandle _h_context
-        bint _timing_disabled
-        bint _busy_waited
-        bint _ipc_enabled
         object _ipc_descriptor
-        int _device_id
         object __weakref__
 
     @staticmethod
     cdef Event _init(type cls, int device_id, ContextHandle h_context, options, bint is_free)
 
+    @staticmethod
+    cdef Event _from_handle(EventHandle h_event)
+
     cpdef close(self)
diff --git a/cuda_core/cuda/core/_event.pyx b/cuda_core/cuda/core/_event.pyx
index 1ff87a1ea0..4a0491d865 100644
--- 
a/cuda_core/cuda/core/_event.pyx +++ b/cuda_core/cuda/core/_event.pyx @@ -13,6 +13,11 @@ from cuda.core._resource_handles cimport ( EventHandle, create_event_handle, create_event_handle_ipc, + get_event_timing_disabled, + get_event_busy_waited, + get_event_ipc_enabled, + get_event_device_id, + get_event_context, as_intptr, as_cu, as_py, @@ -95,36 +100,46 @@ cdef class Event: cdef Event self = cls.__new__(cls) cdef EventOptions opts = check_or_create_options(EventOptions, options, "Event options") cdef unsigned int flags = 0x0 - self._timing_disabled = False - self._busy_waited = False - self._ipc_enabled = False + cdef bint timing_disabled = False + cdef bint busy_waited = False + cdef bint ipc_enabled = False self._ipc_descriptor = None if not opts.enable_timing: flags |= cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING - self._timing_disabled = True + timing_disabled = True if opts.busy_waited_sync: flags |= cydriver.CUevent_flags.CU_EVENT_BLOCKING_SYNC - self._busy_waited = True + busy_waited = True if opts.ipc_enabled: if is_free: raise TypeError( "IPC-enabled events must be bound; use Stream.record for creation." ) flags |= cydriver.CUevent_flags.CU_EVENT_INTERPROCESS - self._ipc_enabled = True - if not self._timing_disabled: + ipc_enabled = True + if not timing_disabled: raise TypeError("IPC-enabled events cannot use timing.") - # C++ creates the event and returns owning handle with context dependency - cdef EventHandle h_event = create_event_handle(h_context, flags) + cdef EventHandle h_event = create_event_handle( + h_context, flags, timing_disabled, busy_waited, ipc_enabled, device_id) if not h_event: raise RuntimeError("Failed to create CUDA event") self._h_event = h_event - self._h_context = h_context - self._device_id = device_id - if opts.ipc_enabled: + if ipc_enabled: self.get_ipc_descriptor() return self + @staticmethod + cdef Event _from_handle(EventHandle h_event): + """Create an Event wrapping an existing EventHandle. 
+ + Metadata (timing, busy_waited, ipc, device_id) is read from the + EventBox via pointer arithmetic — no fields are cached on Event. + """ + cdef Event self = Event.__new__(Event) + self._h_event = h_event + self._ipc_descriptor = None + return self + cpdef close(self): """Destroy the event. @@ -191,7 +206,7 @@ cdef class Event: with nogil: HANDLE_RETURN(cydriver.cuIpcGetEventHandle(&data, as_cu(self._h_event))) cdef bytes data_b = cpython.PyBytes_FromStringAndSize((data.reserved), sizeof(data.reserved)) - self._ipc_descriptor = IPCEventDescriptor._init(data_b, self._busy_waited) + self._ipc_descriptor = IPCEventDescriptor._init(data_b, get_event_busy_waited(self._h_event)) return self._ipc_descriptor @classmethod @@ -200,33 +215,27 @@ cdef class Event: cdef cydriver.CUipcEventHandle data memcpy(data.reserved, (ipc_descriptor._reserved), sizeof(data.reserved)) cdef Event self = Event.__new__(cls) - # IPC events: the originating process owns the event and its context - cdef EventHandle h_event = create_event_handle_ipc(data) + cdef EventHandle h_event = create_event_handle_ipc(data, ipc_descriptor._busy_waited) if not h_event: raise RuntimeError("Failed to open IPC event handle") self._h_event = h_event - self._h_context = ContextHandle() - self._timing_disabled = True - self._busy_waited = ipc_descriptor._busy_waited - self._ipc_enabled = True self._ipc_descriptor = ipc_descriptor - self._device_id = -1 return self @property def is_ipc_enabled(self) -> bool: """Return True if the event can be shared across process boundaries, otherwise False.""" - return self._ipc_enabled + return get_event_ipc_enabled(self._h_event) @property def is_timing_disabled(self) -> bool: """Return True if the event does not record timing data, otherwise False.""" - return self._timing_disabled + return get_event_timing_disabled(self._h_event) @property def is_sync_busy_waited(self) -> bool: """Return True if the event synchronization would keep the CPU busy-waiting, otherwise False.""" 
- return self._busy_waited + return get_event_busy_waited(self._h_event) def sync(self): """Synchronize until the event completes. @@ -274,15 +283,18 @@ cdef class Event: context is set current after an event is created. """ - if self._device_id >= 0: + cdef int dev_id = get_event_device_id(self._h_event) + if dev_id >= 0: from ._device import Device # avoid circular import - return Device(self._device_id) + return Device(dev_id) @property def context(self) -> Context: """Return the :obj:`~_context.Context` associated with this event.""" - if self._h_context and self._device_id >= 0: - return Context._from_handle(Context, self._h_context, self._device_id) + cdef ContextHandle h_ctx = get_event_context(self._h_event) + cdef int dev_id = get_event_device_id(self._h_event) + if h_ctx and dev_id >= 0: + return Context._from_handle(Context, h_ctx, dev_id) cdef class IPCEventDescriptor: diff --git a/cuda_core/cuda/core/_graph.py b/cuda_core/cuda/core/_graph/__init__.py similarity index 93% rename from cuda_core/cuda/core/_graph.py rename to cuda_core/cuda/core/_graph/__init__.py index 80482c38ac..14b801137a 100644 --- a/cuda_core/cuda/core/_graph.py +++ b/cuda_core/cuda/core/_graph/__init__.py @@ -91,6 +91,43 @@ class GraphDebugPrintOptions: extra_topo_info: bool = False conditional_node_params: bool = False + def _to_flags(self) -> int: + """Convert options to CUDA driver API flags (internal use).""" + flags = 0 + if self.verbose: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE + if self.runtime_types: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES + if self.kernel_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS + if self.memcpy_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS + if self.memset_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS + if self.host_node_params: + 
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS + if self.event_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS + if self.ext_semas_signal_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS + if self.ext_semas_wait_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS + if self.kernel_node_attributes: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES + if self.handles: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES + if self.mem_alloc_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS + if self.mem_free_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS + if self.batch_mem_op_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS + if self.extra_topo_info: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO + if self.conditional_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS + return flags + @dataclass class GraphCompleteOptions: @@ -341,41 +378,7 @@ def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None): """ if not self._building_ended: raise RuntimeError("Graph has not finished building.") - flags = 0 - if options: - if options.verbose: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE - if options.runtime_types: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES - if options.kernel_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS - if options.memcpy_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS - if 
options.memset_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS - if options.host_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS - if options.event_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS - if options.ext_semas_signal_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS - if options.ext_semas_wait_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS - if options.kernel_node_attributes: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES - if options.handles: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES - if options.mem_alloc_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS - if options.mem_free_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS - if options.batch_mem_op_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS - if options.extra_topo_info: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO - if options.conditional_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS - + flags = options._to_flags() if options else 0 handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags)) def split(self, count: int) -> tuple[GraphBuilder, ...]: diff --git a/cuda_core/cuda/core/_graph/_graphdef.pxd b/cuda_core/cuda/core/_graph/_graphdef.pxd new file mode 100644 index 0000000000..83612cd6bb --- /dev/null +++ b/cuda_core/cuda/core/_graph/_graphdef.pxd @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from libc.stddef cimport size_t + +from cuda.bindings cimport cydriver +from cuda.core._resource_handles cimport EventHandle, GraphHandle, GraphNodeHandle, KernelHandle + + +cdef class Condition +cdef class GraphDef +cdef class Node +cdef class EmptyNode(Node) +cdef class KernelNode(Node) +cdef class AllocNode(Node) +cdef class FreeNode(Node) +cdef class MemsetNode(Node) +cdef class MemcpyNode(Node) +cdef class ChildGraphNode(Node) +cdef class EventRecordNode(Node) +cdef class EventWaitNode(Node) +cdef class HostCallbackNode(Node) +cdef class ConditionalNode(Node) +cdef class IfNode(ConditionalNode) +cdef class IfElseNode(ConditionalNode) +cdef class WhileNode(ConditionalNode) +cdef class SwitchNode(ConditionalNode) + + +cdef class Condition: + cdef: + cydriver.CUgraphConditionalHandle _c_handle + object __weakref__ + + +cdef class GraphDef: + cdef: + GraphHandle _h_graph + object __weakref__ + + @staticmethod + cdef GraphDef _from_handle(GraphHandle h_graph) + + +cdef class Node: + cdef: + GraphNodeHandle _h_node + tuple _pred_cache + tuple _succ_cache + object __weakref__ + + @staticmethod + cdef Node _create(GraphHandle h_graph, cydriver.CUgraphNode node) + + +cdef class EmptyNode(Node): + @staticmethod + cdef EmptyNode _create_impl(GraphNodeHandle h_node) + + +cdef class KernelNode(Node): + cdef: + tuple _grid + tuple _block + unsigned int _shmem_size + KernelHandle _h_kernel + + @staticmethod + cdef KernelNode _create_with_params(GraphNodeHandle h_node, + tuple grid, tuple block, unsigned int shmem_size, + KernelHandle h_kernel) + + @staticmethod + cdef KernelNode _create_from_driver(GraphNodeHandle h_node) + + +cdef class AllocNode(Node): + cdef: + cydriver.CUdeviceptr _dptr + size_t _bytesize + int _device_id + str _memory_type + tuple _peer_access + + @staticmethod + cdef AllocNode _create_with_params(GraphNodeHandle h_node, + cydriver.CUdeviceptr dptr, size_t bytesize, + int device_id, str memory_type, tuple 
peer_access) + + @staticmethod + cdef AllocNode _create_from_driver(GraphNodeHandle h_node) + + +cdef class FreeNode(Node): + cdef: + cydriver.CUdeviceptr _dptr + + @staticmethod + cdef FreeNode _create_with_params(GraphNodeHandle h_node, + cydriver.CUdeviceptr dptr) + + @staticmethod + cdef FreeNode _create_from_driver(GraphNodeHandle h_node) + + +cdef class MemsetNode(Node): + cdef: + cydriver.CUdeviceptr _dptr + unsigned int _value + unsigned int _element_size + size_t _width + size_t _height + size_t _pitch + + @staticmethod + cdef MemsetNode _create_with_params(GraphNodeHandle h_node, + cydriver.CUdeviceptr dptr, unsigned int value, + unsigned int element_size, size_t width, + size_t height, size_t pitch) + + @staticmethod + cdef MemsetNode _create_from_driver(GraphNodeHandle h_node) + + +cdef class MemcpyNode(Node): + cdef: + cydriver.CUdeviceptr _dst + cydriver.CUdeviceptr _src + size_t _size + cydriver.CUmemorytype _dst_type + cydriver.CUmemorytype _src_type + + @staticmethod + cdef MemcpyNode _create_with_params(GraphNodeHandle h_node, + cydriver.CUdeviceptr dst, cydriver.CUdeviceptr src, + size_t size, cydriver.CUmemorytype dst_type, + cydriver.CUmemorytype src_type) + + @staticmethod + cdef MemcpyNode _create_from_driver(GraphNodeHandle h_node) + + +cdef class ChildGraphNode(Node): + cdef: + GraphHandle _h_child_graph + + @staticmethod + cdef ChildGraphNode _create_with_params(GraphNodeHandle h_node, + GraphHandle h_child_graph) + + @staticmethod + cdef ChildGraphNode _create_from_driver(GraphNodeHandle h_node) + + +cdef class EventRecordNode(Node): + cdef: + EventHandle _h_event + + @staticmethod + cdef EventRecordNode _create_with_params(GraphNodeHandle h_node, + EventHandle h_event) + + @staticmethod + cdef EventRecordNode _create_from_driver(GraphNodeHandle h_node) + + +cdef class EventWaitNode(Node): + cdef: + EventHandle _h_event + + @staticmethod + cdef EventWaitNode _create_with_params(GraphNodeHandle h_node, + EventHandle h_event) + + 
@staticmethod + cdef EventWaitNode _create_from_driver(GraphNodeHandle h_node) + + +cdef class HostCallbackNode(Node): + cdef: + object _callable + cydriver.CUhostFn _fn + void* _user_data + + @staticmethod + cdef HostCallbackNode _create_with_params(GraphNodeHandle h_node, + object callable_obj, cydriver.CUhostFn fn, + void* user_data) + + @staticmethod + cdef HostCallbackNode _create_from_driver(GraphNodeHandle h_node) + + +cdef class ConditionalNode(Node): + cdef: + Condition _condition + cydriver.CUgraphConditionalNodeType _cond_type + tuple _branches # tuple of GraphDef (non-owning wrappers) + + @staticmethod + cdef ConditionalNode _create_from_driver(GraphNodeHandle h_node) + + +cdef class IfNode(ConditionalNode): + pass + + +cdef class IfElseNode(ConditionalNode): + pass + + +cdef class WhileNode(ConditionalNode): + pass + + +cdef class SwitchNode(ConditionalNode): + pass diff --git a/cuda_core/cuda/core/_graph/_graphdef.pyx b/cuda_core/cuda/core/_graph/_graphdef.pyx new file mode 100644 index 0000000000..4c06363293 --- /dev/null +++ b/cuda_core/cuda/core/_graph/_graphdef.pyx @@ -0,0 +1,2078 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Private module for explicit CUDA graph construction. + +This module provides GraphDef and a Node class hierarchy for building CUDA +graphs explicitly (as opposed to stream capture). Both approaches produce +the same public Graph type for execution. + +Node hierarchy: + Node (base — also used for the internal entry point) + ├── EmptyNode (synchronization / join point) + ├── KernelNode (kernel launch) + ├── AllocNode (memory allocation, exposes dptr and bytesize) + ├── FreeNode (memory free, exposes dptr) + ├── MemsetNode (memory set, exposes dptr, value, element_size, etc.) 
+ ├── MemcpyNode (memory copy, exposes dst, src, size) + ├── ChildGraphNode (embedded sub-graph) + ├── EventRecordNode (record an event) + ├── EventWaitNode (wait for an event) + ├── HostCallbackNode (host CPU callback) + └── ConditionalNode (conditional execution — base for reconstruction) + ├── IfNode (if-then conditional, 1 branch) + ├── IfElseNode (if-then-else conditional, 2 branches) + ├── WhileNode (while-loop conditional, 1 branch) + └── SwitchNode (switch conditional, N branches) +""" + +from __future__ import annotations + +from cpython.ref cimport Py_INCREF + +from libc.stddef cimport size_t +from libc.stdint cimport uintptr_t +from libc.stdlib cimport malloc, free +from libc.string cimport memset as c_memset, memcpy as c_memcpy + +from libcpp.vector cimport vector + +from cuda.bindings cimport cydriver + +from cuda.core._event cimport Event +from cuda.core._kernel_arg_handler cimport ParamHolder +from cuda.core._launch_config cimport LaunchConfig +from cuda.core._module cimport Kernel +from cuda.core._resource_handles cimport ( + EventHandle, + GraphHandle, + KernelHandle, + GraphNodeHandle, + as_cu, + as_intptr, + as_py, + create_event_handle_ref, + create_graph_handle, + create_graph_handle_ref, + create_kernel_handle_ref, + create_graph_node_handle, + graph_node_get_graph, +) +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value + +from dataclasses import dataclass + +from cuda.core import Device +from cuda.core._utils.cuda_utils import driver, handle_return + +__all__ = [ + "Condition", + "GraphAllocOptions", + "GraphDef", + "Node", + "EmptyNode", + "KernelNode", + "AllocNode", + "FreeNode", + "MemsetNode", + "MemcpyNode", + "ChildGraphNode", + "EventRecordNode", + "EventWaitNode", + "HostCallbackNode", + "ConditionalNode", + "IfNode", + "IfElseNode", + "WhileNode", + "SwitchNode", +] + + +cdef bint _has_cuGraphNodeGetParams = False +cdef bint _version_checked = False + +cdef bint _check_node_get_params(): + global 
_has_cuGraphNodeGetParams, _version_checked + if not _version_checked: + ver = handle_return(driver.cuDriverGetVersion()) + _has_cuGraphNodeGetParams = ver >= 13020 + _version_checked = True + return _has_cuGraphNodeGetParams + + +cdef extern from "Python.h": + void _py_decref "Py_DECREF" (void*) + + +cdef void _py_host_trampoline(void* data) noexcept with gil: + (data)() + + +cdef void _py_host_destructor(void* data) noexcept with gil: + _py_decref(data) + + +cdef void _destroy_event_handle_copy(void* ptr) noexcept nogil: + cdef EventHandle* p = ptr + del p + + +cdef void _destroy_kernel_handle_copy(void* ptr) noexcept nogil: + cdef KernelHandle* p = ptr + del p + + +cdef void _attach_user_object( + cydriver.CUgraph graph, void* ptr, + cydriver.CUhostFn destroy) except *: + """Create a CUDA user object and transfer ownership to the graph. + + On success the graph owns the resource (via MOVE semantics). + On failure the destroy callback is invoked to clean up ptr, + then a CUDAError is raised — callers need no try/except. + """ + cdef cydriver.CUuserObject user_obj = NULL + cdef cydriver.CUresult ret + with nogil: + ret = cydriver.cuUserObjectCreate( + &user_obj, ptr, destroy, 1, + cydriver.CU_USER_OBJECT_NO_DESTRUCTOR_SYNC) + if ret == cydriver.CUDA_SUCCESS: + ret = cydriver.cuGraphRetainUserObject( + graph, user_obj, 1, cydriver.CU_GRAPH_USER_OBJECT_MOVE) + if ret != cydriver.CUDA_SUCCESS: + cydriver.cuUserObjectRelease(user_obj, 1) + if ret != cydriver.CUDA_SUCCESS: + if user_obj == NULL: + destroy(ptr) + HANDLE_RETURN(ret) + + +cdef class Condition: + """Wraps a CUgraphConditionalHandle. + + Created by :meth:`GraphDef.create_condition` and passed to + conditional-node builder methods (``if_cond``, ``if_else``, + ``while_loop``, ``switch``). The underlying value is set at + runtime by device code via ``cudaGraphSetConditional``. 
+ """ + + def __repr__(self) -> str: + return f"self._c_handle:x}>" + + def __eq__(self, other) -> bool: + if not isinstance(other, Condition): + return NotImplemented + return self._c_handle == (other)._c_handle + + def __hash__(self) -> int: + return hash(self._c_handle) + + @property + def handle(self) -> int: + """The raw CUgraphConditionalHandle as an int.""" + return self._c_handle + + +cdef ConditionalNode _make_conditional_node( + Node pred, + Condition condition, + cydriver.CUgraphConditionalNodeType cond_type, + unsigned int size, + type node_cls): + if not isinstance(condition, Condition): + raise TypeError( + f"condition must be a Condition object (from " + f"GraphDef.create_condition()), got {type(condition).__name__}") + cdef cydriver.CUgraphNodeParams params + cdef cydriver.CUgraphNode new_node = NULL + + c_memset(¶ms, 0, sizeof(params)) + params.type = cydriver.CU_GRAPH_NODE_TYPE_CONDITIONAL + params.conditional.handle = condition._c_handle + params.conditional.type = cond_type + params.conditional.size = size + + cdef cydriver.CUcontext ctx = NULL + cdef GraphHandle h_graph = graph_node_get_graph(pred._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(pred._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + with nogil: + HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx)) + params.conditional.ctx = ctx + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddNode( + &new_node, as_cu(h_graph), deps, NULL, num_deps, ¶ms)) + + # cuGraphAddNode sets phGraph_out to an internal array of body + # graphs (it replaces the pointer, not writing into a caller array). 
+ cdef list branch_list = [] + cdef unsigned int i + cdef cydriver.CUgraph bg + cdef GraphHandle h_branch + for i in range(size): + bg = params.conditional.phGraph_out[i] + h_branch = create_graph_handle_ref(bg, h_graph) + branch_list.append(GraphDef._from_handle(h_branch)) + cdef tuple branches = tuple(branch_list) + + cdef ConditionalNode n = node_cls.__new__(node_cls) + n._h_node = create_graph_node_handle(new_node, h_graph) + n._condition = condition + n._cond_type = cond_type + n._branches = branches + + pred._succ_cache = None + return n + + +@dataclass +class GraphAllocOptions: + """Options for graph memory allocation nodes. + + Attributes + ---------- + device : int or Device, optional + The device on which to allocate memory. If None (default), + uses the current CUDA context's device. + memory_type : str, optional + Type of memory to allocate. One of: + + - ``"device"`` (default): Pinned device memory, optimal for GPU kernels. + - ``"host"``: Pinned host memory, accessible from both host and device. + Useful for graphs containing host callback nodes. Note: may not be + supported on all systems/drivers. + - ``"managed"``: Managed/unified memory that automatically migrates + between host and device. Useful for mixed host/device access patterns. + + peer_access : list of int or Device, optional + List of devices that should have read-write access to the + allocated memory. If None (default), only the allocating + device has access. + + Notes + ----- + - IPC (inter-process communication) is not supported for graph + memory allocation nodes per CUDA documentation. + - The allocation uses the device's default memory pool. + """ + + device: int | Device | None = None + memory_type: str = "device" + peer_access: list | None = None + + +cdef class GraphDef: + """Represents a CUDA graph definition (CUgraph). + + A GraphDef is used to construct a graph explicitly by adding nodes + and specifying dependencies. 
Once construction is complete, call + instantiate() to obtain an executable Graph. + """ + + def __init__(self): + """Create a new empty graph definition.""" + cdef cydriver.CUgraph graph = NULL + with nogil: + HANDLE_RETURN(cydriver.cuGraphCreate(&graph, 0)) + self._h_graph = create_graph_handle(graph) + + @staticmethod + cdef GraphDef _from_handle(GraphHandle h_graph): + """Create a GraphDef from an existing GraphHandle (internal use).""" + cdef GraphDef g = GraphDef.__new__(GraphDef) + g._h_graph = h_graph + return g + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other) -> bool: + if not isinstance(other, GraphDef): + return NotImplemented + return as_intptr(self._h_graph) == as_intptr((other)._h_graph) + + def __hash__(self) -> int: + return hash(as_intptr(self._h_graph)) + + @property + def _entry(self) -> Node: + """Return the internal entry-point Node (no dependencies).""" + cdef Node n = Node.__new__(Node) + n._h_node = create_graph_node_handle(NULL, self._h_graph) + return n + + def alloc(self, size_t size, options: GraphAllocOptions | None = None) -> AllocNode: + """Add an entry-point memory allocation node (no dependencies). + + See :meth:`Node.alloc` for full documentation. + """ + return self._entry.alloc(size, options) + + def free(self, dptr) -> FreeNode: + """Add an entry-point memory free node (no dependencies). + + See :meth:`Node.free` for full documentation. + """ + return self._entry.free(dptr) + + def memset(self, dst, value, size_t width, size_t height=1, size_t pitch=0) -> MemsetNode: + """Add an entry-point memset node (no dependencies). + + See :meth:`Node.memset` for full documentation. + """ + return self._entry.memset(dst, value, width, height, pitch) + + def launch(self, config, kernel, *args) -> KernelNode: + """Add an entry-point kernel launch node (no dependencies). + + See :meth:`Node.launch` for full documentation. 
+ """ + return self._entry.launch(config, kernel, *args) + + def join(self, *nodes) -> EmptyNode: + """Create an empty node that depends on all given nodes. + + Parameters + ---------- + *nodes : Node + Nodes to merge. + + Returns + ------- + EmptyNode + A new EmptyNode that depends on all input nodes. + """ + return self._entry.join(*nodes) + + def memcpy(self, dst, src, size_t size) -> MemcpyNode: + """Add an entry-point memcpy node (no dependencies). + + See :meth:`Node.memcpy` for full documentation. + """ + return self._entry.memcpy(dst, src, size) + + def embed(self, child: GraphDef) -> ChildGraphNode: + """Add an entry-point child graph node (no dependencies). + + See :meth:`Node.embed` for full documentation. + """ + return self._entry.embed(child) + + def record_event(self, event: Event) -> EventRecordNode: + """Add an entry-point event record node (no dependencies). + + See :meth:`Node.record_event` for full documentation. + """ + return self._entry.record_event(event) + + def wait_event(self, event: Event) -> EventWaitNode: + """Add an entry-point event wait node (no dependencies). + + See :meth:`Node.wait_event` for full documentation. + """ + return self._entry.wait_event(event) + + def callback(self, fn, *, user_data=None) -> HostCallbackNode: + """Add an entry-point host callback node (no dependencies). + + See :meth:`Node.callback` for full documentation. + """ + return self._entry.callback(fn, user_data=user_data) + + def create_condition(self, default_value: int | None = None) -> Condition: + """Create a condition variable for use with conditional nodes. + + The returned :class:`Condition` object is passed to conditional-node + builder methods. Its value is controlled at runtime by device code + via ``cudaGraphSetConditional``. + + Parameters + ---------- + default_value : int, optional + The default value to assign to the condition. + If None, no default is assigned. 
+ + Returns + ------- + Condition + A condition variable for controlling conditional execution. + """ + cdef cydriver.CUgraphConditionalHandle c_handle + cdef unsigned int flags = 0 + cdef unsigned int default_val = 0 + + if default_value is not None: + default_val = default_value + flags = cydriver.CU_GRAPH_COND_ASSIGN_DEFAULT + + cdef cydriver.CUcontext ctx = NULL + with nogil: + HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx)) + HANDLE_RETURN(cydriver.cuGraphConditionalHandleCreate( + &c_handle, as_cu(self._h_graph), ctx, default_val, flags)) + + cdef Condition cond = Condition.__new__(Condition) + cond._c_handle = c_handle + return cond + + def if_cond(self, condition: Condition) -> IfNode: + """Add an entry-point if-conditional node (no dependencies). + + See :meth:`Node.if_cond` for full documentation. + """ + return self._entry.if_cond(condition) + + def if_else(self, condition: Condition) -> IfElseNode: + """Add an entry-point if-else conditional node (no dependencies). + + See :meth:`Node.if_else` for full documentation. + """ + return self._entry.if_else(condition) + + def while_loop(self, condition: Condition) -> WhileNode: + """Add an entry-point while-loop conditional node (no dependencies). + + See :meth:`Node.while_loop` for full documentation. + """ + return self._entry.while_loop(condition) + + def switch(self, condition: Condition, unsigned int count) -> SwitchNode: + """Add an entry-point switch conditional node (no dependencies). + + See :meth:`Node.switch` for full documentation. + """ + return self._entry.switch(condition, count) + + def instantiate(self): + """Instantiate the graph definition into an executable Graph. + + Returns + ------- + Graph + An executable graph that can be launched on a stream. 
+ """ + from cuda.core._graph import Graph + from cuda.core._utils.cuda_utils import handle_return + + graph_exec = handle_return(driver.cuGraphInstantiate( + driver.CUgraph(as_intptr(self._h_graph)), 0)) + return Graph._init(graph_exec) + + def debug_dot_print(self, path: str, options=None) -> None: + """Write a GraphViz DOT representation of the graph to a file. + + Parameters + ---------- + path : str + File path for the DOT output. + options : GraphDebugPrintOptions, optional + Customizable options for the debug print. + """ + from cuda.core._graph import GraphDebugPrintOptions + + cdef unsigned int flags = 0 + if options is not None: + if not isinstance(options, GraphDebugPrintOptions): + raise TypeError("options must be a GraphDebugPrintOptions instance") + flags = options._to_flags() + + cdef bytes path_bytes = path.encode('utf-8') + cdef const char* c_path = path_bytes + with nogil: + HANDLE_RETURN(cydriver.cuGraphDebugDotPrint(as_cu(self._h_graph), c_path, flags)) + + def nodes(self) -> tuple: + """Return all nodes in the graph. + + Returns + ------- + tuple of Node + All nodes in the graph. + """ + cdef size_t num_nodes = 0 + + with nogil: + HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), NULL, &num_nodes)) + + if num_nodes == 0: + return () + + cdef vector[cydriver.CUgraphNode] nodes_vec + nodes_vec.resize(num_nodes) + with nogil: + HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes)) + + return tuple(Node._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes)) + + def edges(self) -> tuple: + """Return all edges in the graph as (from_node, to_node) pairs. + + Returns + ------- + tuple of tuple + Each element is a (from_node, to_node) pair representing + a dependency edge in the graph. 
+ """ + cdef size_t num_edges = 0 + + with nogil: + HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, NULL, &num_edges)) + + if num_edges == 0: + return () + + cdef vector[cydriver.CUgraphNode] from_nodes + cdef vector[cydriver.CUgraphNode] to_nodes + from_nodes.resize(num_edges) + to_nodes.resize(num_edges) + with nogil: + HANDLE_RETURN(cydriver.cuGraphGetEdges( + as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), NULL, &num_edges)) + + return tuple( + (Node._create(self._h_graph, from_nodes[i]), + Node._create(self._h_graph, to_nodes[i])) + for i in range(num_edges) + ) + + @property + def handle(self) -> int: + """Return the underlying CUgraph handle.""" + return as_py(self._h_graph) + + +cdef class Node: + """Base class for all graph nodes. + + Nodes are created by calling builder methods on GraphDef (for + entry-point nodes with no dependencies) or on other Nodes (for + nodes that depend on a predecessor). + """ + + @staticmethod + cdef Node _create(GraphHandle h_graph, cydriver.CUgraphNode node): + """Factory: dispatch to the right subclass based on node type.""" + if node == NULL: + n = Node.__new__(Node) + (n)._h_node = create_graph_node_handle(node, h_graph) + return n + + cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph) + cdef cydriver.CUgraphNodeType node_type + with nogil: + HANDLE_RETURN(cydriver.cuGraphNodeGetType(node, &node_type)) + + if node_type == cydriver.CU_GRAPH_NODE_TYPE_EMPTY: + return EmptyNode._create_impl(h_node) + elif node_type == cydriver.CU_GRAPH_NODE_TYPE_KERNEL: + return KernelNode._create_from_driver(h_node) + elif node_type == cydriver.CU_GRAPH_NODE_TYPE_MEM_ALLOC: + return AllocNode._create_from_driver(h_node) + elif node_type == cydriver.CU_GRAPH_NODE_TYPE_MEM_FREE: + return FreeNode._create_from_driver(h_node) + elif node_type == cydriver.CU_GRAPH_NODE_TYPE_MEMSET: + return MemsetNode._create_from_driver(h_node) + elif node_type == cydriver.CU_GRAPH_NODE_TYPE_MEMCPY: + return 
MemcpyNode._create_from_driver(h_node) + elif node_type == cydriver.CU_GRAPH_NODE_TYPE_GRAPH: + return ChildGraphNode._create_from_driver(h_node) + elif node_type == cydriver.CU_GRAPH_NODE_TYPE_EVENT_RECORD: + return EventRecordNode._create_from_driver(h_node) + elif node_type == cydriver.CU_GRAPH_NODE_TYPE_WAIT_EVENT: + return EventWaitNode._create_from_driver(h_node) + elif node_type == cydriver.CU_GRAPH_NODE_TYPE_HOST: + return HostCallbackNode._create_from_driver(h_node) + elif node_type == cydriver.CU_GRAPH_NODE_TYPE_CONDITIONAL: + return ConditionalNode._create_from_driver(h_node) + else: + n = Node.__new__(Node) + (n)._h_node = h_node + return n + + def __repr__(self) -> str: + cdef cydriver.CUgraphNode node = as_cu(self._h_node) + if node == NULL: + return "" + return f"node:x}>" + + def __eq__(self, other) -> bool: + if not isinstance(other, Node): + return NotImplemented + cdef Node o = other + return as_intptr(self._h_node) == as_intptr(o._h_node) + + def __hash__(self) -> int: + return hash(as_intptr(self._h_node)) + + @property + def type(self): + """Return the CUDA graph node type. + + Returns + ------- + CUgraphNodeType or None + The node type enum value, or None for the entry node. + """ + cdef cydriver.CUgraphNode node = as_cu(self._h_node) + if node == NULL: + return None + cdef cydriver.CUgraphNodeType node_type + with nogil: + HANDLE_RETURN(cydriver.cuGraphNodeGetType(node, &node_type)) + return driver.CUgraphNodeType(node_type) + + @property + def graph(self) -> GraphDef: + """Return the GraphDef this node belongs to.""" + return GraphDef._from_handle(graph_node_get_graph(self._h_node)) + + @property + def handle(self) -> int | None: + """Return the underlying CUgraphNode handle as an int. + + Returns None for the entry node. + """ + return as_py(self._h_node) + + @property + def pred(self) -> tuple: + """Return the predecessor nodes (dependencies) of this node. + + Results are cached since a node's dependencies are immutable + once created. 
+ + Returns + ------- + tuple of Node + The nodes that this node depends on. + """ + if self._pred_cache is not None: + return self._pred_cache + + cdef cydriver.CUgraphNode node = as_cu(self._h_node) + if node == NULL: + self._pred_cache = () + return self._pred_cache + + cdef size_t num_deps = 0 + + with nogil: + HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, NULL, NULL, &num_deps)) + + if num_deps == 0: + self._pred_cache = () + return self._pred_cache + + cdef vector[cydriver.CUgraphNode] deps + deps.resize(num_deps) + with nogil: + HANDLE_RETURN(cydriver.cuGraphNodeGetDependencies(node, deps.data(), NULL, &num_deps)) + + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + self._pred_cache = tuple(Node._create(h_graph, deps[i]) for i in range(num_deps)) + return self._pred_cache + + @property + def succ(self) -> tuple: + """Return the successor nodes (dependents) of this node. + + Results are cached and automatically invalidated when new + dependent nodes are added via builder methods. + + Returns + ------- + tuple of Node + The nodes that depend on this node. + """ + if self._succ_cache is not None: + return self._succ_cache + + cdef cydriver.CUgraphNode node = as_cu(self._h_node) + if node == NULL: + self._succ_cache = () + return self._succ_cache + + cdef size_t num_deps = 0 + + with nogil: + HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, NULL, NULL, &num_deps)) + + if num_deps == 0: + self._succ_cache = () + return self._succ_cache + + cdef vector[cydriver.CUgraphNode] deps + deps.resize(num_deps) + with nogil: + HANDLE_RETURN(cydriver.cuGraphNodeGetDependentNodes(node, deps.data(), NULL, &num_deps)) + + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + self._succ_cache = tuple(Node._create(h_graph, deps[i]) for i in range(num_deps)) + return self._succ_cache + + def launch(self, config: LaunchConfig, kernel: Kernel, *args) -> KernelNode: + """Add a kernel launch node depending on this node. 
+ + Parameters + ---------- + config : LaunchConfig + Launch configuration (grid, block, shared memory, etc.) + kernel : Kernel + The kernel to launch. + *args + Kernel arguments. + + Returns + ------- + KernelNode + A new KernelNode representing the kernel launch. + """ + cdef LaunchConfig conf = config + cdef Kernel ker = kernel + cdef ParamHolder ker_args = ParamHolder(args) + + cdef cydriver.CUDA_KERNEL_NODE_PARAMS node_params + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + node_params.kern = as_cu(ker._h_kernel) + node_params.func = NULL + node_params.gridDimX = conf.grid[0] + node_params.gridDimY = conf.grid[1] + node_params.gridDimZ = conf.grid[2] + node_params.blockDimX = conf.block[0] + node_params.blockDimY = conf.block[1] + node_params.blockDimZ = conf.block[2] + node_params.sharedMemBytes = conf.shmem_size + node_params.kernelParams = (ker_args.ptr) + node_params.extra = NULL + node_params.ctx = NULL + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddKernelNode( + &new_node, as_cu(h_graph), deps, num_deps, &node_params)) + + _attach_user_object(as_cu(h_graph), new KernelHandle(ker._h_kernel), + _destroy_kernel_handle_copy) + + self._succ_cache = None + return KernelNode._create_with_params( + create_graph_node_handle(new_node, h_graph), + conf.grid, conf.block, conf.shmem_size, + ker._h_kernel) + + def join(self, *nodes: Node) -> EmptyNode: + """Create an empty node that depends on this node and all given nodes. + + This is used to synchronize multiple branches of execution. + + Parameters + ---------- + *nodes : Node + Additional nodes to depend on. + + Returns + ------- + EmptyNode + A new EmptyNode that depends on all input nodes. 
+ """ + cdef vector[cydriver.CUgraphNode] deps + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef Node other + cdef cydriver.CUgraphNode* deps_ptr = NULL + cdef size_t num_deps = 0 + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + + if pred_node != NULL: + deps.push_back(pred_node) + for other in nodes: + if as_cu((other)._h_node) != NULL: + deps.push_back(as_cu((other)._h_node)) + + num_deps = deps.size() + if num_deps > 0: + deps_ptr = deps.data() + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddEmptyNode( + &new_node, as_cu(h_graph), deps_ptr, num_deps)) + + self._succ_cache = None + for other in nodes: + (other)._succ_cache = None + return EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph)) + + def alloc(self, size_t size, options: GraphAllocOptions | None = None) -> AllocNode: + """Add a memory allocation node depending on this node. + + Parameters + ---------- + size : int + Number of bytes to allocate. + options : GraphAllocOptions, optional + Allocation options. If None, allocates on the current device. + + Returns + ------- + AllocNode + A new AllocNode representing the allocation. Access the allocated + device pointer via the dptr property. 
+ """ + cdef int device_id + cdef cydriver.CUdevice dev + + if options is None or options.device is None: + with nogil: + HANDLE_RETURN(cydriver.cuCtxGetDevice(&dev)) + device_id = dev + else: + device_id = getattr(options.device, 'device_id', options.device) + + cdef cydriver.CUDA_MEM_ALLOC_NODE_PARAMS alloc_params + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + cdef vector[cydriver.CUmemAccessDesc] access_descs + cdef int peer_id + cdef list peer_ids = [] + + if options is not None and options.peer_access is not None: + for peer_dev in options.peer_access: + peer_id = getattr(peer_dev, 'device_id', peer_dev) + peer_ids.append(peer_id) + access_descs.push_back(cydriver.CUmemAccessDesc_st( + cydriver.CUmemLocation_st( + cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE, + peer_id + ), + cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE + )) + + cdef str memory_type = "device" + if options is not None and options.memory_type is not None: + memory_type = options.memory_type + + c_memset(&alloc_params, 0, sizeof(alloc_params)) + alloc_params.poolProps.handleTypes = cydriver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_NONE + alloc_params.bytesize = size + + if memory_type == "device": + alloc_params.poolProps.allocType = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED + alloc_params.poolProps.location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + alloc_params.poolProps.location.id = device_id + elif memory_type == "host": + alloc_params.poolProps.allocType = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED + alloc_params.poolProps.location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_HOST + alloc_params.poolProps.location.id = 0 + elif memory_type == 
"managed": + alloc_params.poolProps.allocType = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED + alloc_params.poolProps.location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + alloc_params.poolProps.location.id = device_id + else: + raise ValueError(f"Invalid memory_type: {memory_type!r}. " + "Must be 'device', 'host', or 'managed'.") + + if access_descs.size() > 0: + alloc_params.accessDescs = access_descs.data() + alloc_params.accessDescCount = access_descs.size() + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode( + &new_node, as_cu(h_graph), deps, num_deps, &alloc_params)) + + self._succ_cache = None + return AllocNode._create_with_params( + create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size, + device_id, memory_type, tuple(peer_ids)) + + def free(self, dptr: int) -> FreeNode: + """Add a memory free node depending on this node. + + Parameters + ---------- + dptr : int + Device pointer to free (typically from AllocNode.dptr). + + Returns + ------- + FreeNode + A new FreeNode representing the free operation. + """ + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + cdef cydriver.CUdeviceptr c_dptr = dptr + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode( + &new_node, as_cu(h_graph), deps, num_deps, c_dptr)) + + self._succ_cache = None + return FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr) + + def memset(self, dst: int, value, size_t width, size_t height=1, size_t pitch=0) -> MemsetNode: + """Add a memset node depending on this node. + + Parameters + ---------- + dst : int + Destination device pointer. + value : int or buffer-protocol object + Fill value. 
int for 1-byte fill (range [0, 256)), + or buffer-protocol object of 1, 2, or 4 bytes. + width : int + Width of the row in elements. + height : int, optional + Number of rows (default 1). + pitch : int, optional + Pitch of destination in bytes (default 0, unused if height is 1). + + Returns + ------- + MemsetNode + A new MemsetNode representing the memset operation. + """ + cdef unsigned int val + cdef unsigned int elem_size + val, elem_size = _parse_fill_value(value) + + cdef cydriver.CUDA_MEMSET_NODE_PARAMS memset_params + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + cdef cydriver.CUdeviceptr c_dst = dst + cdef cydriver.CUcontext ctx = NULL + with nogil: + HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx)) + + c_memset(&memset_params, 0, sizeof(memset_params)) + memset_params.dst = c_dst + memset_params.value = val + memset_params.elementSize = elem_size + memset_params.width = width + memset_params.height = height + memset_params.pitch = pitch + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddMemsetNode( + &new_node, as_cu(h_graph), deps, num_deps, + &memset_params, ctx)) + + self._succ_cache = None + return MemsetNode._create_with_params( + create_graph_node_handle(new_node, h_graph), c_dst, + val, elem_size, width, height, pitch) + + def memcpy(self, dst: int, src: int, size_t size) -> MemcpyNode: + """Add a memcpy node depending on this node. + + Copies ``size`` bytes from ``src`` to ``dst``. Memory types are + auto-detected via the driver, so both device and pinned host + pointers are supported. + + Parameters + ---------- + dst : int + Destination pointer (device or pinned host). + src : int + Source pointer (device or pinned host). + size : int + Number of bytes to copy. 
+ + Returns + ------- + MemcpyNode + A new MemcpyNode representing the copy operation. + """ + cdef cydriver.CUdeviceptr c_dst = dst + cdef cydriver.CUdeviceptr c_src = src + + cdef unsigned int dst_mem_type = cydriver.CU_MEMORYTYPE_DEVICE + cdef unsigned int src_mem_type = cydriver.CU_MEMORYTYPE_DEVICE + cdef cydriver.CUresult ret + with nogil: + ret = cydriver.cuPointerGetAttribute( + &dst_mem_type, + cydriver.CU_POINTER_ATTRIBUTE_MEMORY_TYPE, + c_dst) + if ret != cydriver.CUDA_SUCCESS and ret != cydriver.CUDA_ERROR_INVALID_VALUE: + HANDLE_RETURN(ret) + ret = cydriver.cuPointerGetAttribute( + &src_mem_type, + cydriver.CU_POINTER_ATTRIBUTE_MEMORY_TYPE, + c_src) + if ret != cydriver.CUDA_SUCCESS and ret != cydriver.CUDA_ERROR_INVALID_VALUE: + HANDLE_RETURN(ret) + + cdef cydriver.CUmemorytype c_dst_type = dst_mem_type + cdef cydriver.CUmemorytype c_src_type = src_mem_type + + cdef cydriver.CUDA_MEMCPY3D params + c_memset(¶ms, 0, sizeof(params)) + + params.srcMemoryType = c_src_type + params.dstMemoryType = c_dst_type + if c_src_type == cydriver.CU_MEMORYTYPE_HOST: + params.srcHost = c_src + else: + params.srcDevice = c_src + if c_dst_type == cydriver.CU_MEMORYTYPE_HOST: + params.dstHost = c_dst + else: + params.dstDevice = c_dst + params.WidthInBytes = size + params.Height = 1 + params.Depth = 1 + + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + cdef cydriver.CUcontext ctx = NULL + with nogil: + HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx)) + HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode( + &new_node, as_cu(h_graph), deps, num_deps, ¶ms, ctx)) + + self._succ_cache = None + return MemcpyNode._create_with_params( + create_graph_node_handle(new_node, h_graph), c_dst, c_src, size, + c_dst_type, c_src_type) + + def 
embed(self, child: GraphDef) -> ChildGraphNode: + """Add a child graph node depending on this node. + + Embeds a clone of the given graph definition as a sub-graph node. + The child graph must not contain allocation, free, or conditional + nodes. + + Parameters + ---------- + child : GraphDef + The graph definition to embed (will be cloned). + + Returns + ------- + ChildGraphNode + A new ChildGraphNode representing the embedded sub-graph. + """ + cdef GraphDef child_def = child + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddChildGraphNode( + &new_node, as_cu(h_graph), deps, num_deps, as_cu(child_def._h_graph))) + + cdef cydriver.CUgraph embedded_graph = NULL + with nogil: + HANDLE_RETURN(cydriver.cuGraphChildGraphNodeGetGraph( + new_node, &embedded_graph)) + + cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph) + + self._succ_cache = None + return ChildGraphNode._create_with_params( + create_graph_node_handle(new_node, h_graph), h_embedded) + + def record_event(self, event: Event) -> EventRecordNode: + """Add an event record node depending on this node. + + Parameters + ---------- + event : Event + The event to record. + + Returns + ------- + EventRecordNode + A new EventRecordNode representing the event record operation. 
+ """ + cdef Event ev = event + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddEventRecordNode( + &new_node, as_cu(h_graph), deps, num_deps, as_cu(ev._h_event))) + + _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), + _destroy_event_handle_copy) + + self._succ_cache = None + return EventRecordNode._create_with_params( + create_graph_node_handle(new_node, h_graph), ev._h_event) + + def wait_event(self, event: Event) -> EventWaitNode: + """Add an event wait node depending on this node. + + Parameters + ---------- + event : Event + The event to wait for. + + Returns + ------- + EventWaitNode + A new EventWaitNode representing the event wait operation. + """ + cdef Event ev = event + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddEventWaitNode( + &new_node, as_cu(h_graph), deps, num_deps, as_cu(ev._h_event))) + + _attach_user_object(as_cu(h_graph), new EventHandle(ev._h_event), + _destroy_event_handle_copy) + + self._succ_cache = None + return EventWaitNode._create_with_params( + create_graph_node_handle(new_node, h_graph), ev._h_event) + + def callback(self, fn, *, user_data=None) -> HostCallbackNode: + """Add a host callback node depending on this node. + + The callback runs on the host CPU when the graph reaches this node. + Two modes are supported: + + - **Python callable**: Pass any callable. The GIL is acquired + automatically. 
The callable must take no arguments; use closures + or ``functools.partial`` to bind state. + - **ctypes function pointer**: Pass a ``ctypes.CFUNCTYPE`` instance. + The function receives a single ``void*`` argument (the + ``user_data``). The caller must keep the ctypes wrapper alive + for the lifetime of the graph. + + .. warning:: + + Callbacks must not call CUDA API functions. Doing so may + deadlock or corrupt driver state. + + Parameters + ---------- + fn : callable or ctypes function pointer + The callback function. + user_data : int or bytes-like, optional + Only for ctypes function pointers. If ``int``, passed as a raw + pointer (caller manages lifetime). If bytes-like, the data is + copied and its lifetime is tied to the graph. + + Returns + ------- + HostCallbackNode + A new HostCallbackNode representing the callback. + """ + import ctypes as ct + + cdef cydriver.CUDA_HOST_NODE_PARAMS node_params + cdef cydriver.CUgraphNode new_node = NULL + cdef GraphHandle h_graph = graph_node_get_graph(self._h_node) + cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) + cdef cydriver.CUgraphNode* deps = NULL + cdef size_t num_deps = 0 + cdef void* c_user_data = NULL + cdef object callable_obj = None + cdef void* fn_pyobj = NULL + + if pred_node != NULL: + deps = &pred_node + num_deps = 1 + + if isinstance(fn, ct._CFuncPtr): + node_params.fn = ct.cast( + fn, ct.c_void_p).value + + if user_data is not None: + if isinstance(user_data, int): + c_user_data = user_data + else: + buf = bytes(user_data) + c_user_data = malloc(len(buf)) + if c_user_data == NULL: + raise MemoryError( + "failed to allocate user_data buffer") + c_memcpy(c_user_data, buf, len(buf)) + _attach_user_object( + as_cu(h_graph), c_user_data, + free) + + node_params.userData = c_user_data + else: + if user_data is not None: + raise ValueError( + "user_data is only supported with ctypes " + "function pointers") + callable_obj = fn + Py_INCREF(fn) + fn_pyobj = fn + node_params.fn = _py_host_trampoline 
+ node_params.userData = fn_pyobj + _attach_user_object( + as_cu(h_graph), fn_pyobj, + _py_host_destructor) + + with nogil: + HANDLE_RETURN(cydriver.cuGraphAddHostNode( + &new_node, as_cu(h_graph), deps, num_deps, &node_params)) + + self._succ_cache = None + return HostCallbackNode._create_with_params( + create_graph_node_handle(new_node, h_graph), callable_obj, + node_params.fn, node_params.userData) + + def if_cond(self, condition: Condition) -> IfNode: + """Add an if-conditional node depending on this node. + + The body graph executes only when the condition evaluates to + a non-zero value at runtime. + + Parameters + ---------- + condition : Condition + Condition from :meth:`GraphDef.create_condition`. + + Returns + ------- + IfNode + A new IfNode with one branch accessible via ``.then``. + """ + return _make_conditional_node( + self, condition, + cydriver.CU_GRAPH_COND_TYPE_IF, 1, IfNode) + + def if_else(self, condition: Condition) -> IfElseNode: + """Add an if-else conditional node depending on this node. + + Two body graphs: the first executes when the condition is + non-zero, the second when it is zero. + + Parameters + ---------- + condition : Condition + Condition from :meth:`GraphDef.create_condition`. + + Returns + ------- + IfElseNode + A new IfElseNode with branches accessible via + ``.then`` and ``.else_``. + """ + return _make_conditional_node( + self, condition, + cydriver.CU_GRAPH_COND_TYPE_IF, 2, IfElseNode) + + def while_loop(self, condition: Condition) -> WhileNode: + """Add a while-loop conditional node depending on this node. + + The body graph executes repeatedly while the condition + evaluates to a non-zero value. + + Parameters + ---------- + condition : Condition + Condition from :meth:`GraphDef.create_condition`. + + Returns + ------- + WhileNode + A new WhileNode with body accessible via ``.body``. 
+ """ + return _make_conditional_node( + self, condition, + cydriver.CU_GRAPH_COND_TYPE_WHILE, 1, WhileNode) + + def switch(self, condition: Condition, unsigned int count) -> SwitchNode: + """Add a switch conditional node depending on this node. + + The condition value selects which branch to execute. If the + value is out of range, no branch executes. + + Parameters + ---------- + condition : Condition + Condition from :meth:`GraphDef.create_condition`. + count : int + Number of switch cases (branches). + + Returns + ------- + SwitchNode + A new SwitchNode with branches accessible via ``.branches``. + """ + return _make_conditional_node( + self, condition, + cydriver.CU_GRAPH_COND_TYPE_SWITCH, count, SwitchNode) + + +# ============================================================================= +# Node subclasses +# ============================================================================= + + +cdef class EmptyNode(Node): + """A synchronization / join node with no operation.""" + + @staticmethod + cdef EmptyNode _create_impl(GraphNodeHandle h_node): + cdef EmptyNode n = EmptyNode.__new__(EmptyNode) + n._h_node = h_node + return n + + def __repr__(self) -> str: + cdef Py_ssize_t n = len(self.pred) + return f"" + + +cdef class KernelNode(Node): + """A kernel launch node. + + Properties + ---------- + grid : tuple of int + Grid dimensions (gridDimX, gridDimY, gridDimZ). + block : tuple of int + Block dimensions (blockDimX, blockDimY, blockDimZ). + shmem_size : int + Dynamic shared memory size in bytes. + kernel : Kernel + The kernel object for this launch node. + config : LaunchConfig + A LaunchConfig reconstructed from this node's parameters. 
+ """ + + @staticmethod + cdef KernelNode _create_with_params(GraphNodeHandle h_node, + tuple grid, tuple block, unsigned int shmem_size, + KernelHandle h_kernel): + """Create from known params (called by launch() builder).""" + cdef KernelNode n = KernelNode.__new__(KernelNode) + n._h_node = h_node + n._grid = grid + n._block = block + n._shmem_size = shmem_size + n._h_kernel = h_kernel + return n + + @staticmethod + cdef KernelNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUDA_KERNEL_NODE_PARAMS params + with nogil: + HANDLE_RETURN(cydriver.cuGraphKernelNodeGetParams(node, ¶ms)) + cdef KernelHandle h_kernel = create_kernel_handle_ref(params.kern) + return KernelNode._create_with_params( + h_node, + (params.gridDimX, params.gridDimY, params.gridDimZ), + (params.blockDimX, params.blockDimY, params.blockDimZ), + params.sharedMemBytes, + h_kernel) + + def __repr__(self) -> str: + return (f"") + + @property + def grid(self) -> tuple: + """Grid dimensions as a 3-tuple (gridDimX, gridDimY, gridDimZ).""" + return self._grid + + @property + def block(self) -> tuple: + """Block dimensions as a 3-tuple (blockDimX, blockDimY, blockDimZ).""" + return self._block + + @property + def shmem_size(self) -> int: + """Dynamic shared memory size in bytes.""" + return self._shmem_size + + @property + def kernel(self) -> Kernel: + """The Kernel object for this launch node.""" + return Kernel._from_handle(self._h_kernel) + + @property + def config(self) -> LaunchConfig: + """A LaunchConfig reconstructed from this node's grid, block, and shmem_size. + + Note: cluster dimensions and cooperative_launch are not preserved + by the CUDA driver's kernel node params, so they are not included. + """ + return LaunchConfig(grid=self._grid, block=self._block, + shmem_size=self._shmem_size) + + +cdef class AllocNode(Node): + """A memory allocation node. 
+ + Properties + ---------- + dptr : int + The device pointer for the allocation. + bytesize : int + The number of bytes allocated. + device_id : int + The device on which the allocation was made. + memory_type : str + The type of memory allocated (``"device"``, ``"host"``, or ``"managed"``). + peer_access : tuple of int + Device IDs that have read-write access to this allocation. + options : GraphAllocOptions + A GraphAllocOptions reconstructed from this node's parameters. + """ + + @staticmethod + cdef AllocNode _create_with_params(GraphNodeHandle h_node, + cydriver.CUdeviceptr dptr, size_t bytesize, + int device_id, str memory_type, tuple peer_access): + """Create from known params (called by alloc() builder).""" + cdef AllocNode n = AllocNode.__new__(AllocNode) + n._h_node = h_node + n._dptr = dptr + n._bytesize = bytesize + n._device_id = device_id + n._memory_type = memory_type + n._peer_access = peer_access + return n + + @staticmethod + cdef AllocNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUDA_MEM_ALLOC_NODE_PARAMS params + with nogil: + HANDLE_RETURN(cydriver.cuGraphMemAllocNodeGetParams(node, ¶ms)) + + cdef str memory_type + if params.poolProps.allocType == cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED: + if params.poolProps.location.type == cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_HOST: + memory_type = "host" + else: + memory_type = "device" + elif params.poolProps.allocType == cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED: + memory_type = "managed" + else: + memory_type = "device" + + cdef list peer_ids = [] + cdef size_t i + for i in range(params.accessDescCount): + peer_ids.append(params.accessDescs[i].location.id) + + return AllocNode._create_with_params( + h_node, params.dptr, params.bytesize, + params.poolProps.location.id, memory_type, tuple(peer_ids)) + + def 
__repr__(self) -> str: + return f"" + + @property + def dptr(self) -> int: + """The device pointer for the allocation.""" + return self._dptr + + @property + def bytesize(self) -> int: + """The number of bytes allocated.""" + return self._bytesize + + @property + def device_id(self) -> int: + """The device on which the allocation was made.""" + return self._device_id + + @property + def memory_type(self) -> str: + """The type of memory: ``"device"``, ``"host"``, or ``"managed"``.""" + return self._memory_type + + @property + def peer_access(self) -> tuple: + """Device IDs with read-write access to this allocation.""" + return self._peer_access + + @property + def options(self) -> GraphAllocOptions: + """A GraphAllocOptions reconstructed from this node's parameters.""" + return GraphAllocOptions( + device=self._device_id, + memory_type=self._memory_type, + peer_access=list(self._peer_access) if self._peer_access else None, + ) + + +cdef class FreeNode(Node): + """A memory free node. + + Properties + ---------- + dptr : int + The device pointer being freed. + """ + + @staticmethod + cdef FreeNode _create_with_params(GraphNodeHandle h_node, + cydriver.CUdeviceptr dptr): + """Create from known params (called by free() builder).""" + cdef FreeNode n = FreeNode.__new__(FreeNode) + n._h_node = h_node + n._dptr = dptr + return n + + @staticmethod + cdef FreeNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUdeviceptr dptr + with nogil: + HANDLE_RETURN(cydriver.cuGraphMemFreeNodeGetParams(node, &dptr)) + return FreeNode._create_with_params(h_node, dptr) + + def __repr__(self) -> str: + return f"" + + @property + def dptr(self) -> int: + """The device pointer being freed.""" + return self._dptr + + +cdef class MemsetNode(Node): + """A memory set node. + + Properties + ---------- + dptr : int + The destination device pointer. 
+ value : int + The fill value. + element_size : int + Element size in bytes (1, 2, or 4). + width : int + Width of the row in elements. + height : int + Number of rows. + pitch : int + Pitch in bytes (unused if height is 1). + """ + + @staticmethod + cdef MemsetNode _create_with_params(GraphNodeHandle h_node, + cydriver.CUdeviceptr dptr, unsigned int value, + unsigned int element_size, size_t width, + size_t height, size_t pitch): + """Create from known params (called by memset() builder).""" + cdef MemsetNode n = MemsetNode.__new__(MemsetNode) + n._h_node = h_node + n._dptr = dptr + n._value = value + n._element_size = element_size + n._width = width + n._height = height + n._pitch = pitch + return n + + @staticmethod + cdef MemsetNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUDA_MEMSET_NODE_PARAMS params + with nogil: + HANDLE_RETURN(cydriver.cuGraphMemsetNodeGetParams(node, ¶ms)) + return MemsetNode._create_with_params( + h_node, params.dst, params.value, + params.elementSize, params.width, params.height, params.pitch) + + def __repr__(self) -> str: + return (f"") + + @property + def dptr(self) -> int: + """The destination device pointer.""" + return self._dptr + + @property + def value(self) -> int: + """The fill value.""" + return self._value + + @property + def element_size(self) -> int: + """Element size in bytes (1, 2, or 4).""" + return self._element_size + + @property + def width(self) -> int: + """Width of the row in elements.""" + return self._width + + @property + def height(self) -> int: + """Number of rows.""" + return self._height + + @property + def pitch(self) -> int: + """Pitch in bytes (unused if height is 1).""" + return self._pitch + + +cdef class MemcpyNode(Node): + """A memory copy node. + + Properties + ---------- + dst : int + The destination pointer. + src : int + The source pointer. 
+ size : int + The number of bytes copied. + """ + + @staticmethod + cdef MemcpyNode _create_with_params(GraphNodeHandle h_node, + cydriver.CUdeviceptr dst, cydriver.CUdeviceptr src, + size_t size, cydriver.CUmemorytype dst_type, + cydriver.CUmemorytype src_type): + """Create from known params (called by memcpy() builder).""" + cdef MemcpyNode n = MemcpyNode.__new__(MemcpyNode) + n._h_node = h_node + n._dst = dst + n._src = src + n._size = size + n._dst_type = dst_type + n._src_type = src_type + return n + + @staticmethod + cdef MemcpyNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUDA_MEMCPY3D params + with nogil: + HANDLE_RETURN(cydriver.cuGraphMemcpyNodeGetParams(node, ¶ms)) + + cdef cydriver.CUdeviceptr dst + cdef cydriver.CUdeviceptr src + if params.dstMemoryType == cydriver.CU_MEMORYTYPE_HOST: + dst = params.dstHost + else: + dst = params.dstDevice + if params.srcMemoryType == cydriver.CU_MEMORYTYPE_HOST: + src = params.srcHost + else: + src = params.srcDevice + + return MemcpyNode._create_with_params( + h_node, dst, src, params.WidthInBytes, + params.dstMemoryType, params.srcMemoryType) + + def __repr__(self) -> str: + cdef str dt = "H" if self._dst_type == cydriver.CU_MEMORYTYPE_HOST else "D" + cdef str st = "H" if self._src_type == cydriver.CU_MEMORYTYPE_HOST else "D" + return (f"") + + @property + def dst(self) -> int: + """The destination pointer.""" + return self._dst + + @property + def src(self) -> int: + """The source pointer.""" + return self._src + + @property + def size(self) -> int: + """The number of bytes copied.""" + return self._size + + +cdef class ChildGraphNode(Node): + """A child graph (sub-graph) node. + + Properties + ---------- + child_graph : GraphDef + The embedded graph definition (non-owning wrapper). 
+ """ + + @staticmethod + cdef ChildGraphNode _create_with_params(GraphNodeHandle h_node, + GraphHandle h_child_graph): + """Create from known params (called by embed() builder).""" + cdef ChildGraphNode n = ChildGraphNode.__new__(ChildGraphNode) + n._h_node = h_node + n._h_child_graph = h_child_graph + return n + + @staticmethod + cdef ChildGraphNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUgraph child_graph = NULL + with nogil: + HANDLE_RETURN(cydriver.cuGraphChildGraphNodeGetGraph(node, &child_graph)) + cdef GraphHandle h_graph = graph_node_get_graph(h_node) + cdef GraphHandle h_child = create_graph_handle_ref(child_graph, h_graph) + return ChildGraphNode._create_with_params(h_node, h_child) + + def __repr__(self) -> str: + cdef cydriver.CUgraph g = as_cu(self._h_child_graph) + cdef size_t num_nodes = 0 + with nogil: + HANDLE_RETURN(cydriver.cuGraphGetNodes(g, NULL, &num_nodes)) + cdef Py_ssize_t n = num_nodes + return f"" + + @property + def child_graph(self) -> GraphDef: + """The embedded graph definition (non-owning wrapper).""" + return GraphDef._from_handle(self._h_child_graph) + + +cdef class EventRecordNode(Node): + """An event record node. + + Properties + ---------- + event : Event + The event being recorded. 
+ """ + + @staticmethod + cdef EventRecordNode _create_with_params(GraphNodeHandle h_node, + EventHandle h_event): + """Create from known params (called by record_event() builder).""" + cdef EventRecordNode n = EventRecordNode.__new__(EventRecordNode) + n._h_node = h_node + n._h_event = h_event + return n + + @staticmethod + cdef EventRecordNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUevent event + with nogil: + HANDLE_RETURN(cydriver.cuGraphEventRecordNodeGetEvent(node, &event)) + cdef EventHandle h_event = create_event_handle_ref(event) + return EventRecordNode._create_with_params(h_node, h_event) + + def __repr__(self) -> str: + return f"" + + @property + def event(self) -> Event: + """The event being recorded.""" + return Event._from_handle(self._h_event) + + +cdef class EventWaitNode(Node): + """An event wait node. + + Properties + ---------- + event : Event + The event being waited on. 
+ """ + + @staticmethod + cdef EventWaitNode _create_with_params(GraphNodeHandle h_node, + EventHandle h_event): + """Create from known params (called by wait_event() builder).""" + cdef EventWaitNode n = EventWaitNode.__new__(EventWaitNode) + n._h_node = h_node + n._h_event = h_event + return n + + @staticmethod + cdef EventWaitNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUevent event + with nogil: + HANDLE_RETURN(cydriver.cuGraphEventWaitNodeGetEvent(node, &event)) + cdef EventHandle h_event = create_event_handle_ref(event) + return EventWaitNode._create_with_params(h_node, h_event) + + def __repr__(self) -> str: + return f"" + + @property + def event(self) -> Event: + """The event being waited on.""" + return Event._from_handle(self._h_event) + + +cdef class HostCallbackNode(Node): + """A host callback node. + + Properties + ---------- + callback_fn : callable or None + The Python callable (None for ctypes function pointer callbacks). 
+ """ + + @staticmethod + cdef HostCallbackNode _create_with_params(GraphNodeHandle h_node, + object callable_obj, cydriver.CUhostFn fn, + void* user_data): + """Create from known params (called by callback() builder).""" + cdef HostCallbackNode n = HostCallbackNode.__new__(HostCallbackNode) + n._h_node = h_node + n._callable = callable_obj + n._fn = fn + n._user_data = user_data + return n + + @staticmethod + cdef HostCallbackNode _create_from_driver(GraphNodeHandle h_node): + """Create by fetching params from the driver (called by _create factory).""" + cdef cydriver.CUgraphNode node = as_cu(h_node) + cdef cydriver.CUDA_HOST_NODE_PARAMS params + with nogil: + HANDLE_RETURN(cydriver.cuGraphHostNodeGetParams(node, ¶ms)) + + cdef object callable_obj = None + if params.fn == _py_host_trampoline: + callable_obj = params.userData + + return HostCallbackNode._create_with_params( + h_node, callable_obj, params.fn, params.userData) + + def __repr__(self) -> str: + if self._callable is not None: + name = getattr(self._callable, '__name__', '?') + return f"" + return f"self._fn:x}>" + + @property + def callback_fn(self): + """The Python callable, or None for ctypes function pointer callbacks.""" + return self._callable + + +cdef class ConditionalNode(Node): + """Base class for conditional graph nodes. + + When created via builder methods (if_cond, if_else, while_loop, switch), + a specific subclass (IfNode, IfElseNode, WhileNode, SwitchNode) is + returned. When reconstructed from the driver on CUDA 13.2+, the + correct subclass is determined via cuGraphNodeGetParams. On older + drivers, this base class is used as a fallback. + + Properties + ---------- + condition : Condition or None + The condition variable controlling execution (None pre-13.2). + cond_type : str or None + The conditional type ("if", "while", or "switch"; None pre-13.2). + branches : tuple of GraphDef + The body graphs for each branch (empty pre-13.2). 
+ """ + + @staticmethod + cdef ConditionalNode _create_from_driver(GraphNodeHandle h_node): + cdef ConditionalNode n + if not _check_node_get_params(): + n = ConditionalNode.__new__(ConditionalNode) + n._h_node = h_node + n._condition = None + n._cond_type = cydriver.CU_GRAPH_COND_TYPE_IF + n._branches = () + return n + + cdef cydriver.CUgraphNode node = as_cu(h_node) + params = handle_return(driver.cuGraphNodeGetParams( + node)) + cond_params = params.conditional + cdef int cond_type_int = int(cond_params.type) + cdef unsigned int size = int(cond_params.size) + + cdef Condition condition = Condition.__new__(Condition) + condition._c_handle = ( + int(cond_params.handle)) + + cdef GraphHandle h_graph = graph_node_get_graph(h_node) + cdef list branch_list = [] + cdef unsigned int i + cdef GraphHandle h_branch + if cond_params.phGraph_out is not None: + for i in range(size): + h_branch = create_graph_handle_ref( + int(cond_params.phGraph_out[i]), + h_graph) + branch_list.append(GraphDef._from_handle(h_branch)) + cdef tuple branches = tuple(branch_list) + + cdef type cls + if cond_type_int == cydriver.CU_GRAPH_COND_TYPE_IF: + if size == 1: + cls = IfNode + else: + cls = IfElseNode + elif cond_type_int == cydriver.CU_GRAPH_COND_TYPE_WHILE: + cls = WhileNode + else: + cls = SwitchNode + + n = cls.__new__(cls) + n._h_node = h_node + n._condition = condition + n._cond_type = cond_type_int + n._branches = branches + return n + + def __repr__(self) -> str: + return "" + + @property + def condition(self) -> Condition | None: + """The condition variable controlling execution.""" + return self._condition + + @property + def cond_type(self) -> str | None: + """The conditional type as a string: 'if', 'while', or 'switch'. + + Returns None when reconstructed from the driver pre-CUDA 13.2, + as the conditional type cannot be determined. 
+ """ + if self._condition is None: + return None + if self._cond_type == cydriver.CU_GRAPH_COND_TYPE_IF: + return "if" + elif self._cond_type == cydriver.CU_GRAPH_COND_TYPE_WHILE: + return "while" + else: + return "switch" + + @property + def branches(self) -> tuple: + """The body graphs for each branch as a tuple of GraphDef. + + Returns an empty tuple when reconstructed from the driver + pre-CUDA 13.2. + """ + return self._branches + + +cdef class IfNode(ConditionalNode): + """An if-conditional node (1 branch, executes when condition is non-zero).""" + + def __repr__(self) -> str: + return f"self._condition._c_handle:x}>" + + @property + def then(self) -> GraphDef: + """The 'then' branch graph.""" + return self._branches[0] + + +cdef class IfElseNode(ConditionalNode): + """An if-else conditional node (2 branches).""" + + def __repr__(self) -> str: + return f"self._condition._c_handle:x}>" + + @property + def then(self) -> GraphDef: + """The 'then' branch graph (executed when condition is non-zero).""" + return self._branches[0] + + @property + def else_(self) -> GraphDef: + """The 'else' branch graph (executed when condition is zero).""" + return self._branches[1] + + +cdef class WhileNode(ConditionalNode): + """A while-loop conditional node (1 branch, repeats while condition is non-zero).""" + + def __repr__(self) -> str: + return f"self._condition._c_handle:x}>" + + @property + def body(self) -> GraphDef: + """The loop body graph.""" + return self._branches[0] + + +cdef class SwitchNode(ConditionalNode): + """A switch conditional node (N branches, selected by condition value).""" + + def __repr__(self) -> str: + cdef Py_ssize_t n = len(self._branches) + return (f"self._condition._c_handle:x}" + f" with {n} {'branch' if n == 1 else 'branches'}>") diff --git a/cuda_core/cuda/core/_memory/_buffer.pyx b/cuda_core/cuda/core/_memory/_buffer.pyx index 83009f74ae..a688c2065c 100644 --- a/cuda_core/cuda/core/_memory/_buffer.pyx +++ 
b/cuda_core/cuda/core/_memory/_buffer.pyx @@ -5,8 +5,7 @@ from __future__ import annotations cimport cython -from libc.stdint cimport uint8_t, uint16_t, uint32_t, uintptr_t -from cpython.buffer cimport PyObject_GetBuffer, PyBuffer_Release, Py_buffer, PyBUF_SIMPLE +from libc.stdint cimport uintptr_t from cuda.bindings cimport cydriver from cuda.core._memory._device_memory_resource import DeviceMemoryResource @@ -25,7 +24,7 @@ from cuda.core._resource_handles cimport ( ) from cuda.core._stream cimport Stream, Stream_accept -from cuda.core._utils.cuda_utils cimport HANDLE_RETURN +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value import sys from typing import TypeVar @@ -271,27 +270,27 @@ cdef class Buffer: """ cdef Stream s_stream = Stream_accept(stream) - - # Handle int case: 1-byte fill with automatic overflow checking. - if isinstance(value, int): - Buffer_fill_uint8(self, value, s_stream._h_stream) - return - - # Handle bytes case: direct pointer access without intermediate objects. - if isinstance(value, bytes): - Buffer_fill_from_ptr(self, value, len(value), s_stream._h_stream) - return - - # General buffer protocol path using C buffer API. 
- cdef Py_buffer buf - if PyObject_GetBuffer(value, &buf, PyBUF_SIMPLE) != 0: - raise TypeError( - f"value must be an int or support the buffer protocol, got {type(value).__name__}" - ) - try: - Buffer_fill_from_ptr(self, buf.buf, buf.len, s_stream._h_stream) - finally: - PyBuffer_Release(&buf) + cdef unsigned int val + cdef unsigned int elem_size + val, elem_size = _parse_fill_value(value) + + cdef size_t buffer_size = self._size + cdef cydriver.CUdeviceptr dst = as_cu(self._h_ptr) + cdef cydriver.CUstream s = as_cu(s_stream._h_stream) + + if elem_size == 1: + with nogil: + HANDLE_RETURN(cydriver.cuMemsetD8Async(dst, val, buffer_size, s)) + elif elem_size == 2: + if buffer_size & 0x1: + raise ValueError(f"buffer size ({buffer_size}) must be divisible by 2") + with nogil: + HANDLE_RETURN(cydriver.cuMemsetD16Async(dst, val, buffer_size // 2, s)) + elif elem_size == 4: + if buffer_size & 0x3: + raise ValueError(f"buffer size ({buffer_size}) must be divisible by 4") + with nogil: + HANDLE_RETURN(cydriver.cuMemsetD32Async(dst, val, buffer_size // 4, s)) def __dlpack__( self, @@ -569,36 +568,3 @@ cdef inline void Buffer_close(Buffer self, object stream): self._memory_resource = None self._ipc_data = None self._owner = None - - -cdef inline int Buffer_fill_uint8(Buffer self, uint8_t value, StreamHandle h_stream) except? -1: - cdef cydriver.CUdeviceptr ptr = as_cu(self._h_ptr) - cdef cydriver.CUstream s = as_cu(h_stream) - with nogil: - HANDLE_RETURN(cydriver.cuMemsetD8Async(ptr, value, self._size, s)) - return 0 - - -cdef inline int Buffer_fill_from_ptr( - Buffer self, const char* ptr, size_t width, StreamHandle h_stream -) except? 
-1: - cdef size_t buffer_size = self._size - cdef cydriver.CUdeviceptr dst = as_cu(self._h_ptr) - cdef cydriver.CUstream s = as_cu(h_stream) - - if width == 1: - with nogil: - HANDLE_RETURN(cydriver.cuMemsetD8Async(dst, (ptr)[0], buffer_size, s)) - elif width == 2: - if buffer_size & 0x1: - raise ValueError(f"buffer size ({buffer_size}) must be divisible by 2") - with nogil: - HANDLE_RETURN(cydriver.cuMemsetD16Async(dst, (ptr)[0], buffer_size // 2, s)) - elif width == 4: - if buffer_size & 0x3: - raise ValueError(f"buffer size ({buffer_size}) must be divisible by 4") - with nogil: - HANDLE_RETURN(cydriver.cuMemsetD32Async(dst, (ptr)[0], buffer_size // 4, s)) - else: - raise ValueError(f"value must be 1, 2, or 4 bytes, got {width}") - return 0 diff --git a/cuda_core/cuda/core/_module.pxd b/cuda_core/cuda/core/_module.pxd index 9468de3dff..1d3a0772c3 100644 --- a/cuda_core/cuda/core/_module.pxd +++ b/cuda_core/cuda/core/_module.pxd @@ -16,10 +16,11 @@ cdef class Kernel: KernelHandle _h_kernel KernelAttributes _attributes # lazy KernelOccupancy _occupancy # lazy + object _keepalive object __weakref__ @staticmethod - cdef Kernel _from_obj(KernelHandle h_kernel) + cdef Kernel _from_handle(KernelHandle h_kernel) cdef tuple _get_arguments_info(self, bint param_info=*) diff --git a/cuda_core/cuda/core/_module.pyx b/cuda_core/cuda/core/_module.pyx index ca5562f990..4e8f810619 100644 --- a/cuda_core/cuda/core/_module.pyx +++ b/cuda_core/cuda/core/_module.pyx @@ -19,9 +19,9 @@ from cuda.core._resource_handles cimport ( KernelHandle, create_library_handle_from_file, create_library_handle_from_data, - create_library_handle_ref, create_kernel_handle, create_kernel_handle_ref, + get_kernel_library, get_last_error, as_cu, as_py, @@ -493,7 +493,7 @@ cdef class Kernel: raise RuntimeError("Kernel objects cannot be instantiated directly. 
Please use ObjectCode APIs.") @staticmethod - cdef Kernel _from_obj(KernelHandle h_kernel): + cdef Kernel _from_handle(KernelHandle h_kernel): cdef Kernel ker = Kernel.__new__(Kernel) ker._h_kernel = h_kernel ker._attributes = None @@ -567,9 +567,7 @@ cdef class Kernel: @staticmethod def from_handle(handle, mod: ObjectCode = None) -> Kernel: - """Creates a new :obj:`Kernel` object from a foreign kernel handle. - - Uses a CUkernel pointer address to create a new :obj:`Kernel` object. + """Creates a new :obj:`Kernel` object from a kernel handle. Parameters ---------- @@ -577,37 +575,37 @@ cdef class Kernel: Kernel handle representing the address of a foreign kernel object (CUkernel). mod : :obj:`ObjectCode`, optional - The ObjectCode object associated with this kernel. If not provided, - a placeholder ObjectCode will be created. Note that without a proper - ObjectCode, certain operations may be limited. + The ObjectCode object associated with this kernel. Provides + library lifetime for foreign kernels not created by + cuda.core. 
""" - # Validate that handle is an integer if not isinstance(handle, int): raise TypeError(f"handle must be an integer, got {type(handle).__name__}") - # Convert the integer handle to CUkernel cdef cydriver.CUkernel cu_kernel = handle - cdef KernelHandle h_kernel - cdef cydriver.CUlibrary cu_library - cdef cydriver.CUresult err - - # If no module provided, create a placeholder and try to get the library - if mod is None: - mod = ObjectCode._init(b"", "cubin") - if _is_cukernel_get_library_supported(): - # Try to get the owning library via cuKernelGetLibrary - with nogil: - err = cydriver.cuKernelGetLibrary(&cu_library, cu_kernel) - if err == cydriver.CUDA_SUCCESS: - mod._h_library = create_library_handle_ref(cu_library) - - # Create kernel handle with library dependency - h_kernel = create_kernel_handle_ref(cu_kernel, mod._h_library) + cdef KernelHandle h_kernel = create_kernel_handle_ref(cu_kernel) if not h_kernel: HANDLE_RETURN(get_last_error()) - return Kernel._from_obj(h_kernel) + cdef LibraryHandle h_existing_lib = get_kernel_library(h_kernel) + cdef LibraryHandle h_caller_lib + + if mod is not None: + h_caller_lib = (mod)._h_library + if h_existing_lib and h_caller_lib: + if as_cu(h_existing_lib) != as_cu(h_caller_lib): + import warnings + warnings.warn( + "The library from the provided ObjectCode does not match " + "the library associated with this kernel.", + stacklevel=2, + ) + + cdef Kernel k = Kernel._from_handle(h_kernel) + if mod is not None and not h_existing_lib: + k._keepalive = mod + return k def __eq__(self, other) -> bool: if not isinstance(other, Kernel): @@ -825,7 +823,7 @@ cdef class ObjectCode: cdef KernelHandle h_kernel = create_kernel_handle(self._h_library, name) if not h_kernel: HANDLE_RETURN(get_last_error()) - return Kernel._from_obj(h_kernel) + return Kernel._from_handle(h_kernel) @property def code(self) -> CodeTypeT: diff --git a/cuda_core/cuda/core/_resource_handles.pxd b/cuda_core/cuda/core/_resource_handles.pxd index 
c5a1ab36a6..9b4baf11da 100644 --- a/cuda_core/cuda/core/_resource_handles.pxd +++ b/cuda_core/cuda/core/_resource_handles.pxd @@ -26,6 +26,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": ctypedef shared_ptr[const cydriver.CUdeviceptr] DevicePtrHandle ctypedef shared_ptr[const cydriver.CUlibrary] LibraryHandle ctypedef shared_ptr[const cydriver.CUkernel] KernelHandle + ctypedef shared_ptr[const cydriver.CUgraph] GraphHandle + ctypedef shared_ptr[const cydriver.CUgraphNode] GraphNodeHandle ctypedef shared_ptr[const cydriver.CUgraphicsResource] GraphicsResourceHandle ctypedef shared_ptr[const cynvrtc.nvrtcProgram] NvrtcProgramHandle @@ -48,6 +50,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": cydriver.CUdeviceptr as_cu(DevicePtrHandle h) noexcept nogil cydriver.CUlibrary as_cu(LibraryHandle h) noexcept nogil cydriver.CUkernel as_cu(KernelHandle h) noexcept nogil + cydriver.CUgraph as_cu(GraphHandle h) noexcept nogil + cydriver.CUgraphNode as_cu(GraphNodeHandle h) noexcept nogil cydriver.CUgraphicsResource as_cu(GraphicsResourceHandle h) noexcept nogil cynvrtc.nvrtcProgram as_cu(NvrtcProgramHandle h) noexcept nogil cynvvm.nvvmProgram as_cu(NvvmProgramHandle h) noexcept nogil @@ -62,6 +66,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": intptr_t as_intptr(DevicePtrHandle h) noexcept nogil intptr_t as_intptr(LibraryHandle h) noexcept nogil intptr_t as_intptr(KernelHandle h) noexcept nogil + intptr_t as_intptr(GraphHandle h) noexcept nogil + intptr_t as_intptr(GraphNodeHandle h) noexcept nogil intptr_t as_intptr(GraphicsResourceHandle h) noexcept nogil intptr_t as_intptr(NvrtcProgramHandle h) noexcept nogil intptr_t as_intptr(NvvmProgramHandle h) noexcept nogil @@ -76,6 +82,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": object as_py(DevicePtrHandle h) object as_py(LibraryHandle h) object as_py(KernelHandle h) + object as_py(GraphHandle h) + object as_py(GraphNodeHandle h) 
object as_py(GraphicsResourceHandle h) object as_py(NvrtcProgramHandle h) object as_py(NvvmProgramHandle h) @@ -108,10 +116,21 @@ cdef StreamHandle get_legacy_stream() except+ nogil cdef StreamHandle get_per_thread_stream() except+ nogil # Event handles -cdef EventHandle create_event_handle(const ContextHandle& h_ctx, unsigned int flags) except+ nogil +cdef EventHandle create_event_handle( + const ContextHandle& h_ctx, unsigned int flags, + bint timing_disabled, bint busy_waited, + bint ipc_enabled, int device_id) except+ nogil cdef EventHandle create_event_handle_noctx(unsigned int flags) except+ nogil +cdef EventHandle create_event_handle_ref(cydriver.CUevent event) except+ nogil cdef EventHandle create_event_handle_ipc( - const cydriver.CUipcEventHandle& ipc_handle) except+ nogil + const cydriver.CUipcEventHandle& ipc_handle, bint busy_waited) except+ nogil + +# Event metadata getters +cdef bint get_event_timing_disabled(const EventHandle& h) noexcept nogil +cdef bint get_event_busy_waited(const EventHandle& h) noexcept nogil +cdef bint get_event_ipc_enabled(const EventHandle& h) noexcept nogil +cdef int get_event_device_id(const EventHandle& h) noexcept nogil +cdef ContextHandle get_event_context(const EventHandle& h) noexcept nogil # Memory pool handles cdef MemoryPoolHandle create_mempool_handle( @@ -150,8 +169,16 @@ cdef LibraryHandle create_library_handle_ref(cydriver.CUlibrary library) except+ # Kernel handles cdef KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* name) except+ nogil -cdef KernelHandle create_kernel_handle_ref( - cydriver.CUkernel kernel, const LibraryHandle& h_library) except+ nogil +cdef KernelHandle create_kernel_handle_ref(cydriver.CUkernel kernel) except+ nogil +cdef LibraryHandle get_kernel_library(const KernelHandle& h) noexcept nogil + +# Graph handles +cdef GraphHandle create_graph_handle(cydriver.CUgraph graph) except+ nogil +cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const 
GraphHandle& h_parent) except+ nogil + +# Graph node handles +cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil +cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil # Graphics resource handles cdef GraphicsResourceHandle create_graphics_resource_handle( diff --git a/cuda_core/cuda/core/_resource_handles.pyx b/cuda_core/cuda/core/_resource_handles.pyx index eebaed2e28..d4d60d6192 100644 --- a/cuda_core/cuda/core/_resource_handles.pyx +++ b/cuda_core/cuda/core/_resource_handles.pyx @@ -26,6 +26,7 @@ from ._resource_handles cimport ( DevicePtrHandle, LibraryHandle, KernelHandle, + GraphHandle, GraphicsResourceHandle, NvrtcProgramHandle, NvvmProgramHandle, @@ -70,11 +71,27 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": # Event handles (note: _create_event_handle* are internal due to C++ overloading) EventHandle create_event_handle "cuda_core::create_event_handle" ( - const ContextHandle& h_ctx, unsigned int flags) except+ nogil + const ContextHandle& h_ctx, unsigned int flags, + bint timing_disabled, bint busy_waited, + bint ipc_enabled, int device_id) except+ nogil EventHandle create_event_handle_noctx "cuda_core::create_event_handle_noctx" ( unsigned int flags) except+ nogil + EventHandle create_event_handle_ref "cuda_core::create_event_handle_ref" ( + cydriver.CUevent event) except+ nogil EventHandle create_event_handle_ipc "cuda_core::create_event_handle_ipc" ( - const cydriver.CUipcEventHandle& ipc_handle) except+ nogil + const cydriver.CUipcEventHandle& ipc_handle, bint busy_waited) except+ nogil + + # Event metadata getters + bint get_event_timing_disabled "cuda_core::get_event_timing_disabled" ( + const EventHandle& h) noexcept nogil + bint get_event_busy_waited "cuda_core::get_event_busy_waited" ( + const EventHandle& h) noexcept nogil + bint get_event_ipc_enabled "cuda_core::get_event_ipc_enabled" ( + const EventHandle& h) noexcept nogil + int 
get_event_device_id "cuda_core::get_event_device_id" ( + const EventHandle& h) noexcept nogil + ContextHandle get_event_context "cuda_core::get_event_context" ( + const EventHandle& h) noexcept nogil # Memory pool handles MemoryPoolHandle create_mempool_handle "cuda_core::create_mempool_handle" ( @@ -126,7 +143,21 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": KernelHandle create_kernel_handle "cuda_core::create_kernel_handle" ( const LibraryHandle& h_library, const char* name) except+ nogil KernelHandle create_kernel_handle_ref "cuda_core::create_kernel_handle_ref" ( - cydriver.CUkernel kernel, const LibraryHandle& h_library) except+ nogil + cydriver.CUkernel kernel) except+ nogil + LibraryHandle get_kernel_library "cuda_core::get_kernel_library" ( + const KernelHandle& h) noexcept nogil + + # Graph handles + GraphHandle create_graph_handle "cuda_core::create_graph_handle" ( + cydriver.CUgraph graph) except+ nogil + GraphHandle create_graph_handle_ref "cuda_core::create_graph_handle_ref" ( + cydriver.CUgraph graph, const GraphHandle& h_parent) except+ nogil + + # Graph node handles + GraphNodeHandle create_graph_node_handle "cuda_core::create_graph_node_handle" ( + cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil + GraphHandle graph_node_get_graph "cuda_core::graph_node_get_graph" ( + const GraphNodeHandle& h) noexcept nogil # Graphics resource handles GraphicsResourceHandle create_graphics_resource_handle "cuda_core::create_graphics_resource_handle" ( @@ -223,6 +254,9 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": void* p_cuLibraryUnload "reinterpret_cast(cuda_core::p_cuLibraryUnload)" void* p_cuLibraryGetKernel "reinterpret_cast(cuda_core::p_cuLibraryGetKernel)" + # Graph + void* p_cuGraphDestroy "reinterpret_cast(cuda_core::p_cuGraphDestroy)" + # Linker void* p_cuLinkDestroy "reinterpret_cast(cuda_core::p_cuLinkDestroy)" @@ -288,6 +322,9 @@ p_cuLibraryLoadData = 
_get_driver_fn("cuLibraryLoadData") p_cuLibraryUnload = _get_driver_fn("cuLibraryUnload") p_cuLibraryGetKernel = _get_driver_fn("cuLibraryGetKernel") +# Graph +p_cuGraphDestroy = _get_driver_fn("cuGraphDestroy") + # Linker p_cuLinkDestroy = _get_driver_fn("cuLinkDestroy") diff --git a/cuda_core/cuda/core/_utils/cuda_utils.pxd b/cuda_core/cuda/core/_utils/cuda_utils.pxd index 478ce705af..4562cd7135 100644 --- a/cuda_core/cuda/core/_utils/cuda_utils.pxd +++ b/cuda_core/cuda/core/_utils/cuda_utils.pxd @@ -4,7 +4,7 @@ cimport cpython from cpython.object cimport PyObject -from libc.stdint cimport int64_t, int32_t +from libc.stdint cimport int64_t, int32_t, uint8_t, uint16_t, uint32_t from cuda.bindings cimport cydriver, cynvrtc, cynvvm, cynvjitlink @@ -33,6 +33,8 @@ cpdef int _check_nvrtc_error(error) except?-1 cpdef check_or_create_options(type cls, options, str options_description=*, bint keep_none=*) +cpdef tuple _parse_fill_value(value) + # Create low-level externs so Cython won't "helpfully" handle reference counting # for us. 
Prefixing with an underscore to distinguish it from the definition in diff --git a/cuda_core/cuda/core/_utils/cuda_utils.pyx b/cuda_core/cuda/core/_utils/cuda_utils.pyx index 3134308b55..999b3be325 100644 --- a/cuda_core/cuda/core/_utils/cuda_utils.pyx +++ b/cuda_core/cuda/core/_utils/cuda_utils.pyx @@ -23,6 +23,8 @@ except ImportError: from cuda.bindings.nvvm import nvvmError from cuda.bindings.nvjitlink import nvJitLinkError +from cpython.buffer cimport PyObject_GetBuffer, PyBuffer_Release, Py_buffer, PyBUF_SIMPLE + from cuda.bindings cimport cynvrtc, cynvvm, cynvjitlink from cuda.core._utils.driver_cu_result_explanations import DRIVER_CU_RESULT_EXPLANATIONS @@ -368,6 +370,64 @@ def reset_fork_warning(): _fork_warning_checked = False +cdef inline tuple _read_fill_ptr(const char* ptr, Py_ssize_t width): + """Extract (value, element_size) from a raw pointer of known width.""" + cdef unsigned int val + if width == 1: + val = (ptr)[0] + elif width == 2: + val = (ptr)[0] + elif width == 4: + val = (ptr)[0] + else: + raise ValueError(f"value must be 1, 2, or 4 bytes, got {width}") + return (val, width) + + +cpdef tuple _parse_fill_value(value): + """Parse a fill/memset value into (raw_value, element_size). + + Parameters + ---------- + value : int or buffer-protocol object + - int: Must be in range [0, 256). Treated as 1-byte fill. + - bytes or buffer-protocol: Must be 1, 2, or 4 bytes. + + Returns + ------- + tuple of (int, int) + (raw_value, element_size) where element_size is 1, 2, or 4. + + Raises + ------ + OverflowError + If int value is outside [0, 256). + TypeError + If value is not an int and does not support the buffer protocol. + ValueError + If value byte length is not 1, 2, or 4. 
+ """ + cdef uint8_t byte_val + cdef Py_buffer buf + + if isinstance(value, int): + byte_val = value + return (byte_val, 1) + + if isinstance(value, bytes): + return _read_fill_ptr(value, len(value)) + + if PyObject_GetBuffer(value, &buf, PyBUF_SIMPLE) != 0: + raise TypeError( + f"value must be an int or support the buffer protocol, " + f"got {type(value).__name__}" + ) + try: + return _read_fill_ptr(buf.buf, buf.len) + finally: + PyBuffer_Release(&buf) + + def check_multiprocessing_start_method(): """Check if multiprocessing start method is 'fork' and warn if so.""" global _fork_warning_checked diff --git a/cuda_core/tests/graph/test_explicit.py b/cuda_core/tests/graph/test_explicit.py new file mode 100644 index 0000000000..b9a16974b9 --- /dev/null +++ b/cuda_core/tests/graph/test_explicit.py @@ -0,0 +1,1110 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +"""Tests for explicit CUDA graph construction (GraphDef and Node).""" + +from collections.abc import Callable +from dataclasses import dataclass, field + +import pytest +from helpers.graph_kernels import compile_common_kernels +from helpers.misc import try_create_condition + +from cuda.core import Device, LaunchConfig +from cuda.core._graph import GraphDebugPrintOptions +from cuda.core._graph._graphdef import ( + AllocNode, + ChildGraphNode, + ConditionalNode, + EmptyNode, + EventRecordNode, + EventWaitNode, + FreeNode, + GraphAllocOptions, + GraphDef, + HostCallbackNode, + IfElseNode, + IfNode, + KernelNode, + MemcpyNode, + MemsetNode, + Node, + SwitchNode, + WhileNode, +) + +ALLOC_SIZE = 1024 + + +def _driver_has_node_get_params(): + from cuda.bindings import driver as drv + + return drv.cuDriverGetVersion()[1] >= 13020 + + +_HAS_NODE_GET_PARAMS = _driver_has_node_get_params() + + +# ============================================================================= +# GraphSpec — 
representative graph topologies +# ============================================================================= + + +@dataclass +class GraphSpec: + """Describes a graph topology with expected structural properties.""" + + name: str + graphdef: GraphDef + named_nodes: dict = field(default_factory=dict) + expected_edges: set = field(default_factory=set) + expected_pred: dict = field(default_factory=dict) + expected_succ: dict = field(default_factory=dict) + + +def _build_empty(): + """No nodes, no edges.""" + return GraphSpec("empty", GraphDef()) + + +def _build_single(): + """One alloc node, no edges.""" + g = GraphDef() + a = g.alloc(ALLOC_SIZE) + return GraphSpec( + "single", + g, + named_nodes={"a": a}, + expected_edges=set(), + expected_pred={"a": set()}, + expected_succ={"a": set()}, + ) + + +def _build_chain(): + """Linear chain: a -> b -> c.""" + g = GraphDef() + a = g.alloc(ALLOC_SIZE) + b = a.alloc(ALLOC_SIZE) + c = b.alloc(ALLOC_SIZE) + return GraphSpec( + "chain", + g, + named_nodes={"a": a, "b": b, "c": c}, + expected_edges={("a", "b"), ("b", "c")}, + expected_pred={"a": set(), "b": {"a"}, "c": {"b"}}, + expected_succ={"a": {"b"}, "b": {"c"}, "c": set()}, + ) + + +def _build_fan_out(): + """One node feeds three: a -> {b, c, d}.""" + g = GraphDef() + a = g.alloc(ALLOC_SIZE) + b = a.alloc(ALLOC_SIZE) + c = a.alloc(ALLOC_SIZE) + d = a.alloc(ALLOC_SIZE) + return GraphSpec( + "fan_out", + g, + named_nodes={"a": a, "b": b, "c": c, "d": d}, + expected_edges={("a", "b"), ("a", "c"), ("a", "d")}, + expected_pred={"a": set(), "b": {"a"}, "c": {"a"}, "d": {"a"}}, + expected_succ={"a": {"b", "c", "d"}, "b": set(), "c": set(), "d": set()}, + ) + + +def _build_fan_in(): + """Three entry nodes merge: {a, b, c} -> d (join).""" + g = GraphDef() + a = g.alloc(ALLOC_SIZE) + b = g.alloc(ALLOC_SIZE) + c = g.alloc(ALLOC_SIZE) + d = g.join(a, b, c) + return GraphSpec( + "fan_in", + g, + named_nodes={"a": a, "b": b, "c": c, "d": d}, + expected_edges={("a", "d"), ("b", "d"), 
("c", "d")}, + expected_pred={"a": set(), "b": set(), "c": set(), "d": {"a", "b", "c"}}, + expected_succ={"a": {"d"}, "b": {"d"}, "c": {"d"}, "d": set()}, + ) + + +def _build_diamond(): + """Diamond: a -> {b, c} -> d (join).""" + g = GraphDef() + a = g.alloc(ALLOC_SIZE) + b = a.alloc(ALLOC_SIZE) + c = a.alloc(ALLOC_SIZE) + d = b.join(c) + return GraphSpec( + "diamond", + g, + named_nodes={"a": a, "b": b, "c": c, "d": d}, + expected_edges={("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")}, + expected_pred={"a": set(), "b": {"a"}, "c": {"a"}, "d": {"b", "c"}}, + expected_succ={"a": {"b", "c"}, "b": {"d"}, "c": {"d"}, "d": set()}, + ) + + +def _build_disconnected(): + """Two independent entry nodes: a, b.""" + g = GraphDef() + a = g.alloc(ALLOC_SIZE) + b = g.alloc(ALLOC_SIZE) + return GraphSpec( + "disconnected", + g, + named_nodes={"a": a, "b": b}, + expected_edges=set(), + expected_pred={"a": set(), "b": set()}, + expected_succ={"a": set(), "b": set()}, + ) + + +_ALL_BUILDERS = [ + pytest.param(_build_empty, id="empty"), + pytest.param(_build_single, id="single"), + pytest.param(_build_chain, id="chain"), + pytest.param(_build_fan_out, id="fan_out"), + pytest.param(_build_fan_in, id="fan_in"), + pytest.param(_build_diamond, id="diamond"), + pytest.param(_build_disconnected, id="disconnected"), +] + +_NONEMPTY_BUILDERS = [p for p in _ALL_BUILDERS if p.values[0] is not _build_empty] + + +@pytest.fixture(params=_ALL_BUILDERS) +def graph_spec(request, init_cuda): + return request.param() + + +@pytest.fixture(params=_NONEMPTY_BUILDERS) +def nonempty_graph_spec(request, init_cuda): + return request.param() + + +# ============================================================================= +# NodeSpec — representative node types +# ============================================================================= + + +@dataclass +class NodeSpec: + """Describes a node type with expected properties. 
+ + The builder returns (node, expected_attrs) where expected_attrs maps + property names to expected values. Callable values are treated as + predicates (e.g., ``lambda v: v != 0``). + """ + + name: str + expected_class: type + expected_type_name: str + builder: Callable[[GraphDef], tuple[Node, dict]] + reconstructed_class: type | None = None + + @property + def roundtrip_class(self): + """Class expected after reconstruction from the driver.""" + return self.reconstructed_class or self.expected_class + + +def _build_empty_node(g): + a = g.alloc(ALLOC_SIZE) + b = g.alloc(ALLOC_SIZE) + return g.join(a, b), {} + + +def _build_kernel_node(g): + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + config = LaunchConfig(grid=(2, 3, 1), block=(32, 4, 1), shmem_size=128) + entry = g.alloc(ALLOC_SIZE) + node = entry.launch(config, kernel) + return node, { + "grid": (2, 3, 1), + "block": (32, 4, 1), + "shmem_size": 128, + "kernel": kernel, + "config": config, + } + + +def _build_alloc_node(g): + device_id = Device().device_id + entry = g.alloc(ALLOC_SIZE) + node = entry.alloc(ALLOC_SIZE) + return node, { + "dptr": lambda v: v != 0, + "bytesize": ALLOC_SIZE, + "device_id": device_id, + "memory_type": "device", + "peer_access": (), + "options": GraphAllocOptions(device=device_id, memory_type="device"), + } + + +def _build_alloc_managed_node(g): + device_id = Device().device_id + options = GraphAllocOptions(memory_type="managed") + entry = g.alloc(ALLOC_SIZE) + node = entry.alloc(ALLOC_SIZE, options) + return node, { + "dptr": lambda v: v != 0, + "bytesize": ALLOC_SIZE, + "device_id": device_id, + "memory_type": "managed", + "peer_access": (), + "options": GraphAllocOptions(device=device_id, memory_type="managed"), + } + + +def _build_free_node(g): + alloc = g.alloc(ALLOC_SIZE) + node = alloc.free(alloc.dptr) + return node, { + "dptr": alloc.dptr, + } + + +def _build_memset_node(g): + alloc = g.alloc(ALLOC_SIZE) + node = alloc.memset(alloc.dptr, 42, 
ALLOC_SIZE) + return node, { + "dptr": alloc.dptr, + "value": 42, + "element_size": 1, + "width": ALLOC_SIZE, + "height": 1, + "pitch": 0, + } + + +def _build_memset_node_u16(g): + alloc = g.alloc(ALLOC_SIZE) + node = alloc.memset(alloc.dptr, b"\xab\xcd", ALLOC_SIZE // 2) + return node, { + "dptr": alloc.dptr, + "value": int.from_bytes(b"\xab\xcd", byteorder="little"), + "element_size": 2, + "width": ALLOC_SIZE // 2, + "height": 1, + "pitch": 0, + } + + +def _build_memset_node_u32(g): + alloc = g.alloc(ALLOC_SIZE) + node = alloc.memset(alloc.dptr, b"\x01\x02\x03\x04", ALLOC_SIZE // 4) + return node, { + "dptr": alloc.dptr, + "value": int.from_bytes(b"\x01\x02\x03\x04", byteorder="little"), + "element_size": 4, + "width": ALLOC_SIZE // 4, + "height": 1, + "pitch": 0, + } + + +def _build_memset_node_2d(g): + rows = 4 + cols = ALLOC_SIZE // rows + alloc = g.alloc(ALLOC_SIZE) + node = alloc.memset(alloc.dptr, 0xFF, cols, height=rows, pitch=cols) + return node, { + "dptr": alloc.dptr, + "value": 0xFF, + "element_size": 1, + "width": cols, + "height": rows, + "pitch": cols, + } + + +def _build_event_record_node(g): + event = Device().create_event() + entry = g.alloc(ALLOC_SIZE) + node = entry.record_event(event) + return node, { + "event": event, + } + + +def _build_event_wait_node(g): + event = Device().create_event() + entry = g.alloc(ALLOC_SIZE) + node = entry.wait_event(event) + return node, { + "event": event, + } + + +def _build_memcpy_node(g): + src_alloc = g.alloc(ALLOC_SIZE) + dst_alloc = g.alloc(ALLOC_SIZE) + dep = g.join(src_alloc, dst_alloc) + node = dep.memcpy(dst_alloc.dptr, src_alloc.dptr, ALLOC_SIZE) + return node, { + "dst": dst_alloc.dptr, + "src": src_alloc.dptr, + "size": ALLOC_SIZE, + } + + +def _build_host_callback_node(g): + def my_callback(): + pass + + node = g.callback(my_callback) + return node, { + "callback_fn": lambda v: v is my_callback, + } + + +def _build_host_callback_cfunc_node(g): + import ctypes + + CALLBACK = ctypes.CFUNCTYPE(None, 
ctypes.c_void_p) + + @CALLBACK + def noop(data): + pass + + node = g.callback(noop) + return node, {} + + +def _build_child_graph_node(g): + child = GraphDef() + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + config = LaunchConfig(grid=1, block=1) + child.launch(config, kernel) + child.launch(config, kernel) + node = g.embed(child) + return node, { + "child_graph": lambda v: isinstance(v, GraphDef) and len(v.nodes()) == 2, + } + + +def _build_if_cond_node(g): + condition = try_create_condition(g) + node = g.if_cond(condition) + return node, { + "condition": condition, + "cond_type": "if", + "branches": lambda v: isinstance(v, tuple) and len(v) == 1, + "then": lambda v: isinstance(v, GraphDef), + } + + +def _build_if_else_node(g): + condition = try_create_condition(g) + node = g.if_else(condition) + return node, { + "condition": condition, + "cond_type": "if", + "branches": lambda v: isinstance(v, tuple) and len(v) == 2, + "then": lambda v: isinstance(v, GraphDef), + "else_": lambda v: isinstance(v, GraphDef), + } + + +def _build_while_loop_node(g): + condition = try_create_condition(g) + node = g.while_loop(condition) + return node, { + "condition": condition, + "cond_type": "while", + "branches": lambda v: isinstance(v, tuple) and len(v) == 1, + "body": lambda v: isinstance(v, GraphDef), + } + + +def _build_switch_node(g): + condition = try_create_condition(g) + node = g.switch(condition, 3) + return node, { + "condition": condition, + "cond_type": "switch", + "branches": lambda v: isinstance(v, tuple) and len(v) == 3, + } + + +_NODE_SPECS = [ + pytest.param(NodeSpec("empty", EmptyNode, "CU_GRAPH_NODE_TYPE_EMPTY", _build_empty_node), id="empty"), + pytest.param(NodeSpec("kernel", KernelNode, "CU_GRAPH_NODE_TYPE_KERNEL", _build_kernel_node), id="kernel"), + pytest.param(NodeSpec("alloc", AllocNode, "CU_GRAPH_NODE_TYPE_MEM_ALLOC", _build_alloc_node), id="alloc"), + pytest.param( + NodeSpec("alloc_managed", AllocNode, 
"CU_GRAPH_NODE_TYPE_MEM_ALLOC", _build_alloc_managed_node), + id="alloc_managed", + ), + pytest.param(NodeSpec("free", FreeNode, "CU_GRAPH_NODE_TYPE_MEM_FREE", _build_free_node), id="free"), + pytest.param(NodeSpec("memset", MemsetNode, "CU_GRAPH_NODE_TYPE_MEMSET", _build_memset_node), id="memset"), + pytest.param( + NodeSpec("memset_u16", MemsetNode, "CU_GRAPH_NODE_TYPE_MEMSET", _build_memset_node_u16), id="memset_u16" + ), + pytest.param( + NodeSpec("memset_u32", MemsetNode, "CU_GRAPH_NODE_TYPE_MEMSET", _build_memset_node_u32), id="memset_u32" + ), + pytest.param(NodeSpec("memset_2d", MemsetNode, "CU_GRAPH_NODE_TYPE_MEMSET", _build_memset_node_2d), id="memset_2d"), + pytest.param( + NodeSpec("memcpy", MemcpyNode, "CU_GRAPH_NODE_TYPE_MEMCPY", _build_memcpy_node), + id="memcpy", + ), + pytest.param( + NodeSpec("child_graph", ChildGraphNode, "CU_GRAPH_NODE_TYPE_GRAPH", _build_child_graph_node), + id="child_graph", + ), + pytest.param( + NodeSpec("host_callback", HostCallbackNode, "CU_GRAPH_NODE_TYPE_HOST", _build_host_callback_node), + id="host_callback", + ), + pytest.param( + NodeSpec("host_callback_cfunc", HostCallbackNode, "CU_GRAPH_NODE_TYPE_HOST", _build_host_callback_cfunc_node), + id="host_callback_cfunc", + ), + pytest.param( + NodeSpec("event_record", EventRecordNode, "CU_GRAPH_NODE_TYPE_EVENT_RECORD", _build_event_record_node), + id="event_record", + ), + pytest.param( + NodeSpec("event_wait", EventWaitNode, "CU_GRAPH_NODE_TYPE_WAIT_EVENT", _build_event_wait_node), + id="event_wait", + ), + pytest.param( + NodeSpec( + "if_cond", + IfNode, + "CU_GRAPH_NODE_TYPE_CONDITIONAL", + _build_if_cond_node, + reconstructed_class=IfNode if _HAS_NODE_GET_PARAMS else ConditionalNode, + ), + id="if_cond", + ), + pytest.param( + NodeSpec( + "if_else", + IfElseNode, + "CU_GRAPH_NODE_TYPE_CONDITIONAL", + _build_if_else_node, + reconstructed_class=IfElseNode if _HAS_NODE_GET_PARAMS else ConditionalNode, + ), + id="if_else", + ), + pytest.param( + NodeSpec( + "while_loop", + 
WhileNode, + "CU_GRAPH_NODE_TYPE_CONDITIONAL", + _build_while_loop_node, + reconstructed_class=WhileNode if _HAS_NODE_GET_PARAMS else ConditionalNode, + ), + id="while_loop", + ), + pytest.param( + NodeSpec( + "switch", + SwitchNode, + "CU_GRAPH_NODE_TYPE_CONDITIONAL", + _build_switch_node, + reconstructed_class=SwitchNode if _HAS_NODE_GET_PARAMS else ConditionalNode, + ), + id="switch", + ), +] + + +@pytest.fixture(params=_NODE_SPECS) +def node_spec(request, init_cuda): + spec = request.param + g = GraphDef() + node, expected_attrs = spec.builder(g) + return spec, g, node, expected_attrs + + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def sample_graphdef(init_cuda): + """A sample GraphDef for standalone tests.""" + return GraphDef() + + +@pytest.fixture +def dot_file(tmp_path): + """Temporary DOT file path, cleaned up after test.""" + path = tmp_path / "graph.dot" + yield path + path.unlink(missing_ok=True) + + +# ============================================================================= +# Topology tests (parameterized over graph specs) +# ============================================================================= + + +def test_node_count(graph_spec): + """Graph contains the expected number of nodes.""" + assert len(graph_spec.graphdef.nodes()) == len(graph_spec.named_nodes) + + +def test_nodes_match(nonempty_graph_spec): + """nodes() returns exactly the expected nodes.""" + spec = nonempty_graph_spec + assert set(spec.graphdef.nodes()) == set(spec.named_nodes.values()) + + +def test_edges(graph_spec): + """edges() returns exactly the expected edges.""" + spec = graph_spec + node_to_name = {v: k for k, v in spec.named_nodes.items()} + actual = {(node_to_name[a], node_to_name[b]) for a, b in spec.graphdef.edges()} + assert actual == spec.expected_edges + + +def test_pred(nonempty_graph_spec): + """Each node 
has the expected predecessors.""" + spec = nonempty_graph_spec + node_to_name = {v: k for k, v in spec.named_nodes.items()} + for name, node in spec.named_nodes.items(): + actual = {node_to_name[p] for p in node.pred} + assert actual == spec.expected_pred[name], f"pred mismatch for node {name}" + + +def test_succ(nonempty_graph_spec): + """Each node has the expected successors.""" + spec = nonempty_graph_spec + node_to_name = {v: k for k, v in spec.named_nodes.items()} + for name, node in spec.named_nodes.items(): + actual = {node_to_name[s] for s in node.succ} + assert actual == spec.expected_succ[name], f"succ mismatch for node {name}" + + +def test_node_graph_property(nonempty_graph_spec): + """Every node's .graph property returns the parent GraphDef.""" + spec = nonempty_graph_spec + for name, node in spec.named_nodes.items(): + assert node.graph == spec.graphdef, f"graph mismatch for node {name}" + + +# ============================================================================= +# Node type tests (parameterized over node specs) +# ============================================================================= + + +def test_node_isinstance(node_spec): + """Node is an instance of the expected subclass.""" + spec, g, node, _ = node_spec + assert isinstance(node, spec.expected_class) + assert isinstance(node, Node) + + +def test_node_type_property(node_spec): + """Node.type returns the expected CUgraphNodeType.""" + spec, g, node, _ = node_spec + assert node.type.name == spec.expected_type_name + + +def test_node_type_preserved_by_nodes(node_spec): + """Node type is preserved when retrieved via graphdef.nodes().""" + spec, g, node, _ = node_spec + all_nodes = g.nodes() + matched = [n for n in all_nodes if n == node] + assert len(matched) == 1 + assert isinstance(matched[0], spec.roundtrip_class) + + +def test_node_type_preserved_by_pred_succ(node_spec): + """Node type is preserved when retrieved via pred/succ traversal.""" + spec, g, node, _ = node_spec + for 
predecessor in node.pred: + matched = [s for s in predecessor.succ if s == node] + assert len(matched) == 1 + assert isinstance(matched[0], spec.roundtrip_class) + + +def test_node_attrs(node_spec): + """Type-specific attributes have expected values after construction.""" + spec, g, node, expected_attrs = node_spec + if not expected_attrs: + pytest.skip("no type-specific attributes") + for attr, expected in expected_attrs.items(): + actual = getattr(node, attr) + if callable(expected): + assert expected(actual), f"{spec.name}.{attr}: check failed (got {actual})" + else: + assert actual == expected, f"{spec.name}.{attr}: expected {expected}, got {actual}" + + +def test_node_attrs_preserved_by_nodes(node_spec): + """Type-specific attributes survive round-trip through graphdef.nodes().""" + spec, g, node, expected_attrs = node_spec + if not expected_attrs: + pytest.skip("no type-specific attributes") + if spec.roundtrip_class != spec.expected_class: + pytest.skip("reconstructed type differs — attrs not preserved") + retrieved = next(n for n in g.nodes() if n == node) + for attr in expected_attrs: + assert getattr(retrieved, attr) == getattr(node, attr), f"{spec.name}.{attr} not preserved by nodes()" + + +# ============================================================================= +# GraphDef basics +# ============================================================================= + + +def test_graphdef_handle_valid(sample_graphdef): + """GraphDef has a valid non-null handle.""" + assert sample_graphdef.handle is not None + assert int(sample_graphdef.handle) != 0 + + +def test_graphdef_entry_is_virtual(sample_graphdef): + """Internal entry node is virtual (no pred/succ, type is None).""" + entry = sample_graphdef._entry + assert isinstance(entry, Node) + assert entry.pred == () + assert entry.succ == () + assert entry.type is None + + +# ============================================================================= +# Alloc/free API +# 
============================================================================= + + +def test_alloc_zero_size_fails(sample_graphdef): + """Alloc with zero size raises error (CUDA limitation).""" + from cuda.core._utils.cuda_utils import CUDAError + + with pytest.raises(CUDAError): + sample_graphdef.alloc(0) + + +def test_free_creates_dependency(sample_graphdef): + """Free node depends on its predecessor.""" + alloc = sample_graphdef.alloc(ALLOC_SIZE) + free = alloc.free(alloc.dptr) + assert alloc in free.pred + + +def test_alloc_free_chain(sample_graphdef): + """Alloc and free can be chained.""" + a1 = sample_graphdef.alloc(ALLOC_SIZE) + a2 = a1.alloc(ALLOC_SIZE) + f2 = a2.free(a2.dptr) + f1 = f2.free(a1.dptr) + assert a1 in a2.pred + assert a2 in f2.pred + assert f2 in f1.pred + + +# ============================================================================= +# Allocation options (error cases, input variants, multi-GPU) +# ============================================================================= + + +def test_alloc_memory_type_invalid(sample_graphdef): + """Invalid memory type raises ValueError.""" + options = GraphAllocOptions(memory_type="invalid") + with pytest.raises(ValueError, match="Invalid memory_type"): + sample_graphdef.alloc(ALLOC_SIZE, options) + + +@pytest.mark.parametrize( + "device_spec", + [ + pytest.param(lambda d: d.device_id, id="device_id"), + pytest.param(lambda d: d, id="Device_object"), + ], +) +def test_alloc_device_option(sample_graphdef, device_spec): + """Device can be specified as int or Device object.""" + device = Device() + options = GraphAllocOptions(device=device_spec(device)) + node = sample_graphdef.alloc(ALLOC_SIZE, options) + assert node.dptr != 0 + + +def test_alloc_peer_access(mempool_device_x2): + """AllocNode.peer_access reflects requested peers.""" + d0, d1 = mempool_device_x2 + g = GraphDef() + options = GraphAllocOptions(device=d0.device_id, peer_access=[d1.device_id]) + node = g.alloc(ALLOC_SIZE, options) + assert 
d1.device_id in node.peer_access + + +# ============================================================================= +# Join API +# ============================================================================= + + +@pytest.mark.parametrize("num_branches", [2, 3, 5]) +def test_join_merges_branches(sample_graphdef, num_branches): + """join() with multiple branches creates correct dependencies.""" + branches = [sample_graphdef.alloc(ALLOC_SIZE) for _ in range(num_branches)] + joined = sample_graphdef.join(*branches) + assert isinstance(joined, EmptyNode) + assert set(joined.pred) == set(branches) + + +# ============================================================================= +# Kernel launch +# ============================================================================= + + +def test_launch_creates_node(sample_graphdef): + """launch() creates a KernelNode.""" + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + config = LaunchConfig(grid=1, block=1) + node = sample_graphdef.launch(config, kernel) + assert isinstance(node, KernelNode) + + +def test_launch_chain_dependencies(sample_graphdef): + """Chained launches create correct dependencies.""" + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + config = LaunchConfig(grid=1, block=1) + n1 = sample_graphdef.launch(config, kernel) + n2 = n1.launch(config, kernel) + n3 = n2.launch(config, kernel) + assert n1 in n2.pred + assert n2 in n3.pred + assert n1 not in n3.pred + + +# ============================================================================= +# Instantiation and execution +# ============================================================================= + + +def test_instantiate_empty_graph(sample_graphdef): + """Empty graph can be instantiated.""" + graph = sample_graphdef.instantiate() + assert graph is not None + + +def test_instantiate_with_nodes(sample_graphdef): + """Graph with nodes can be instantiated.""" + sample_graphdef.alloc(ALLOC_SIZE) + 
sample_graphdef.alloc(ALLOC_SIZE) + graph = sample_graphdef.instantiate() + assert graph is not None + + +def test_instantiate_and_execute_kernel(sample_graphdef): + """Graph with kernel can be instantiated and executed.""" + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + config = LaunchConfig(grid=1, block=1) + sample_graphdef.launch(config, kernel) + graph = sample_graphdef.instantiate() + + stream = Device().create_stream() + graph.upload(stream) + graph.launch(stream) + stream.sync() + + +def test_instantiate_and_execute_alloc_free(sample_graphdef): + """Graph with alloc/free can be executed.""" + alloc = sample_graphdef.alloc(ALLOC_SIZE) + alloc.free(alloc.dptr) + graph = sample_graphdef.instantiate() + + stream = Device().create_stream() + graph.upload(stream) + graph.launch(stream) + stream.sync() + + +def test_instantiate_and_execute_memset(sample_graphdef): + """Graph with alloc/memset/free can be executed.""" + alloc = sample_graphdef.alloc(ALLOC_SIZE) + ms = alloc.memset(alloc.dptr, 0xAB, ALLOC_SIZE) + ms.free(alloc.dptr) + graph = sample_graphdef.instantiate() + + stream = Device().create_stream() + graph.upload(stream) + graph.launch(stream) + stream.sync() + + +def test_instantiate_and_execute_memcpy(sample_graphdef): + """Graph with alloc/memset/memcpy/free can be executed and data is copied.""" + import ctypes + + src_alloc = sample_graphdef.alloc(ALLOC_SIZE) + dst_alloc = sample_graphdef.alloc(ALLOC_SIZE) + dep = sample_graphdef.join(src_alloc, dst_alloc) + ms = dep.memset(src_alloc.dptr, 0xAB, ALLOC_SIZE) + cp = ms.memcpy(dst_alloc.dptr, src_alloc.dptr, ALLOC_SIZE) + cp.free(src_alloc.dptr) + + graph = sample_graphdef.instantiate() + stream = Device().create_stream() + graph.upload(stream) + graph.launch(stream) + stream.sync() + + host_buf = (ctypes.c_ubyte * ALLOC_SIZE)() + from cuda.bindings import driver as drv + + drv.cuMemcpyDtoH(host_buf, dst_alloc.dptr, ALLOC_SIZE) + assert all(b == 0xAB for b in host_buf) + + 
+def test_instantiate_and_execute_child_graph(sample_graphdef): + """Graph with embedded child graph can be executed.""" + child = GraphDef() + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + config = LaunchConfig(grid=1, block=1) + child.launch(config, kernel) + + sample_graphdef.embed(child) + graph = sample_graphdef.instantiate() + + stream = Device().create_stream() + graph.upload(stream) + graph.launch(stream) + stream.sync() + + +def test_instantiate_and_execute_host_callback(sample_graphdef): + """Graph with host callback can be executed and callback is invoked.""" + results = [] + + def my_callback(): + results.append(42) + + sample_graphdef.callback(my_callback) + graph = sample_graphdef.instantiate() + + stream = Device().create_stream() + graph.upload(stream) + graph.launch(stream) + stream.sync() + + assert results == [42] + + +def test_instantiate_and_execute_host_callback_cfunc(sample_graphdef): + """Graph with ctypes function pointer callback can be executed.""" + import ctypes + + CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_void_p) + called = [False] + + @CALLBACK + def raw_fn(data): + called[0] = True + + sample_graphdef.callback(raw_fn) + graph = sample_graphdef.instantiate() + + stream = Device().create_stream() + graph.upload(stream) + graph.launch(stream) + stream.sync() + + assert called[0] + + +def test_host_callback_cfunc_with_user_data(sample_graphdef): + """Host callback with bytes user_data passes data to C function.""" + import ctypes + + CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_void_p) + result = [0] + + @CALLBACK + def read_byte(data): + result[0] = ctypes.cast(data, ctypes.POINTER(ctypes.c_uint8))[0] + + sample_graphdef.callback(read_byte, user_data=bytes([0xAB])) + graph = sample_graphdef.instantiate() + + stream = Device().create_stream() + graph.upload(stream) + graph.launch(stream) + stream.sync() + + assert result[0] == 0xAB + + +def 
test_host_callback_user_data_rejected_for_python_callable(sample_graphdef): + """user_data is rejected for Python callables.""" + with pytest.raises(ValueError, match="user_data is only supported"): + sample_graphdef.callback(lambda: None, user_data=b"hello") + + +def test_instantiate_and_execute_event_record_wait(sample_graphdef): + """Graph with event record and wait nodes can be executed.""" + event = Device().create_event() + rec = sample_graphdef.record_event(event) + rec.wait_event(event) + graph = sample_graphdef.instantiate() + + stream = Device().create_stream() + graph.upload(stream) + graph.launch(stream) + stream.sync() + + +# ============================================================================= +# Conditional nodes +# ============================================================================= + + +def _skip_unless_cc_90(): + if Device(0).compute_capability < (9, 0): + pytest.skip("Conditional node execution requires CC >= 9.0 (Hopper)") + + +def test_instantiate_and_execute_if_cond(sample_graphdef): + """If-conditional node: body executes only when condition is non-zero.""" + _skip_unless_cc_90() + import ctypes + + from helpers.graph_kernels import compile_conditional_kernels + + condition = sample_graphdef.create_condition(default_value=0) + mod = compile_conditional_kernels(int) + set_handle = mod.get_kernel("set_handle") + add_one = mod.get_kernel("add_one") + + alloc = sample_graphdef.alloc(ctypes.sizeof(ctypes.c_int)) + ms = alloc.memset(alloc.dptr, 0, ctypes.sizeof(ctypes.c_int)) + setter = ms.launch(LaunchConfig(grid=1, block=1), set_handle, condition.handle, 1) + if_node = setter.if_cond(condition) + if_node.then.launch(LaunchConfig(grid=1, block=1), add_one, alloc.dptr) + + graph = sample_graphdef.instantiate() + stream = Device().create_stream() + graph.upload(stream) + graph.launch(stream) + stream.sync() + + result = (ctypes.c_int * 1)() + from cuda.bindings import driver as drv + + drv.cuMemcpyDtoH(result, alloc.dptr, 
ctypes.sizeof(ctypes.c_int)) + assert result[0] == 1 + + +def test_instantiate_and_execute_if_else(sample_graphdef): + """If-else node: then or else branch executes based on condition.""" + _skip_unless_cc_90() + import ctypes + + from helpers.graph_kernels import compile_conditional_kernels + + condition = sample_graphdef.create_condition(default_value=0) + mod = compile_conditional_kernels(int) + set_handle = mod.get_kernel("set_handle") + add_one = mod.get_kernel("add_one") + + alloc = sample_graphdef.alloc(ctypes.sizeof(ctypes.c_int)) + ms = alloc.memset(alloc.dptr, 0, ctypes.sizeof(ctypes.c_int)) + setter = ms.launch(LaunchConfig(grid=1, block=1), set_handle, condition.handle, 0) + ie_node = setter.if_else(condition) + ie_node.then.launch(LaunchConfig(grid=1, block=1), add_one, alloc.dptr) + n1 = ie_node.else_.launch(LaunchConfig(grid=1, block=1), add_one, alloc.dptr) + n1.launch(LaunchConfig(grid=1, block=1), add_one, alloc.dptr) + + graph = sample_graphdef.instantiate() + stream = Device().create_stream() + graph.upload(stream) + graph.launch(stream) + stream.sync() + + result = (ctypes.c_int * 1)() + from cuda.bindings import driver as drv + + drv.cuMemcpyDtoH(result, alloc.dptr, ctypes.sizeof(ctypes.c_int)) + assert result[0] == 2 + + +def test_instantiate_and_execute_switch(sample_graphdef): + """Switch node: selected branch executes based on condition value.""" + _skip_unless_cc_90() + import ctypes + + from helpers.graph_kernels import compile_conditional_kernels + + condition = sample_graphdef.create_condition(default_value=0) + mod = compile_conditional_kernels(int) + set_handle = mod.get_kernel("set_handle") + add_one = mod.get_kernel("add_one") + + alloc = sample_graphdef.alloc(ctypes.sizeof(ctypes.c_int)) + ms = alloc.memset(alloc.dptr, 0, ctypes.sizeof(ctypes.c_int)) + setter = ms.launch(LaunchConfig(grid=1, block=1), set_handle, condition.handle, 2) + sw_node = setter.switch(condition, 4) + for branch in sw_node.branches: + 
branch.launch(LaunchConfig(grid=1, block=1), add_one, alloc.dptr) + + graph = sample_graphdef.instantiate() + stream = Device().create_stream() + graph.upload(stream) + graph.launch(stream) + stream.sync() + + result = (ctypes.c_int * 1)() + from cuda.bindings import driver as drv + + drv.cuMemcpyDtoH(result, alloc.dptr, ctypes.sizeof(ctypes.c_int)) + assert result[0] == 1 + + +def test_conditional_node_type_preserved_by_nodes(sample_graphdef): + """Conditional nodes appear as ConditionalNode base when read back from graph.""" + condition = try_create_condition(sample_graphdef) + if_node = sample_graphdef.if_cond(condition) + assert isinstance(if_node, IfNode) + + all_nodes = sample_graphdef.nodes() + matched = [n for n in all_nodes if n == if_node] + assert len(matched) == 1 + assert isinstance(matched[0], ConditionalNode) + + +# ============================================================================= +# Debug output +# ============================================================================= + + +def test_debug_dot_print_creates_file(sample_graphdef, dot_file): + """debug_dot_print writes a DOT file.""" + sample_graphdef.alloc(ALLOC_SIZE) + sample_graphdef.debug_dot_print(str(dot_file)) + assert dot_file.exists() + content = dot_file.read_text() + assert "digraph" in content + + +def test_debug_dot_print_with_options(sample_graphdef, dot_file): + """debug_dot_print accepts GraphDebugPrintOptions.""" + sample_graphdef.alloc(ALLOC_SIZE) + options = GraphDebugPrintOptions(verbose=True, handles=True) + sample_graphdef.debug_dot_print(str(dot_file), options) + assert dot_file.exists() + + +def test_debug_dot_print_invalid_options(sample_graphdef, dot_file): + """debug_dot_print rejects invalid options type.""" + sample_graphdef.alloc(ALLOC_SIZE) + with pytest.raises(TypeError, match="options must be a GraphDebugPrintOptions"): + sample_graphdef.debug_dot_print(str(dot_file), "invalid") diff --git a/cuda_core/tests/graph/test_explicit_errors.py 
b/cuda_core/tests/graph/test_explicit_errors.py new file mode 100644 index 0000000000..e65dbe31d7 --- /dev/null +++ b/cuda_core/tests/graph/test_explicit_errors.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +"""Tests for error handling, input validation, and edge cases in explicit graphs. + +These tests verify that the explicit graph API properly validates inputs, +raises appropriate exceptions for misuse, and handles boundary conditions +correctly. +""" + +import ctypes + +import pytest +from helpers.graph_kernels import compile_common_kernels +from helpers.misc import try_create_condition + +from cuda.core import Device, LaunchConfig +from cuda.core._graph._graphdef import ( + Condition, + EmptyNode, + GraphDef, +) +from cuda.core._utils.cuda_utils import CUDAError + +SIZEOF_INT = ctypes.sizeof(ctypes.c_int) + + +# ============================================================================= +# Type validation — wrong types for conditional node methods +# ============================================================================= + + +@pytest.mark.parametrize( + "method, args", + [ + pytest.param("if_cond", (42,), id="if_cond_int"), + pytest.param("if_else", ("not a condition",), id="if_else_str"), + pytest.param("while_loop", (None,), id="while_loop_none"), + pytest.param("switch", ([1, 2, 3], 4), id="switch_list"), + ], +) +def test_conditional_rejects_non_condition(init_cuda, method, args): + """Conditional node methods reject non-Condition arguments.""" + g = GraphDef() + with pytest.raises(TypeError, match="Condition"): + getattr(g, method)(*args) + + +def test_embed_rejects_non_graphdef(init_cuda): + """embed() rejects non-GraphDef arguments.""" + g = GraphDef() + with pytest.raises((TypeError, AttributeError)): + g.embed("not a graph") + + +# 
============================================================================= +# Value validation — invalid parameter values +# ============================================================================= + + +def test_free_null_pointer(init_cuda): + """free(0) raises a CUDA error.""" + g = GraphDef() + with pytest.raises(CUDAError): + g.free(0) + + +def test_memset_invalid_value_size(init_cuda): + """memset with 3-byte value (not 1, 2, or 4) raises ValueError.""" + g = GraphDef() + alloc = g.alloc(1024) + with pytest.raises(ValueError): + alloc.memset(alloc.dptr, b"\x01\x02\x03", 100) + + +def test_switch_zero_branches(init_cuda): + """switch with count=0 raises an error.""" + g = GraphDef() + condition = try_create_condition(g) + with pytest.raises(CUDAError): + g.switch(condition, 0) + + +# ============================================================================= +# Cross-graph misuse +# ============================================================================= + + +def test_condition_from_different_graph(init_cuda): + """Using a condition created for graph A in graph B raises an error.""" + g1 = GraphDef() + g2 = GraphDef() + condition = try_create_condition(g1) + with pytest.raises(CUDAError): + g2.if_cond(condition) + + +# ============================================================================= +# Edge cases — valid but unusual usage patterns +# ============================================================================= + + +def test_join_no_extra_nodes(init_cuda): + """join() from entry with no extra nodes creates a single empty node.""" + g = GraphDef() + joined = g.join() + assert isinstance(joined, EmptyNode) + assert len(g.nodes()) == 1 + + +def test_join_single_predecessor(init_cuda): + """node.join() with no extra args creates a single-dep empty node.""" + g = GraphDef() + a = g.alloc(1024) + joined = a.join() + assert isinstance(joined, EmptyNode) + assert set(joined.pred) == {a} + + +def test_multiple_instantiation(init_cuda): + 
"""Same GraphDef can be instantiated multiple times independently.""" + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + cfg = LaunchConfig(grid=1, block=1) + + g = GraphDef() + g.launch(cfg, kernel) + g1 = g.instantiate() + g2 = g.instantiate() + assert g1 is not g2 + + +def test_unmatched_alloc_succeeds(init_cuda): + """Alloc without corresponding free is valid (graph-scoped lifetime).""" + g = GraphDef() + g.alloc(1024) + graph = g.instantiate() + stream = Device().create_stream() + graph.launch(stream) + stream.sync() + + +def test_create_condition_no_default_value(init_cuda): + """create_condition with no default_value succeeds.""" + g = GraphDef() + try: + condition = g.create_condition() + except CUDAError: + pytest.skip("Conditional nodes not supported (requires CC >= 9.0)") + assert isinstance(condition, Condition) + + +# ============================================================================= +# Boundary condition execution — conditional nodes with extreme values +# ============================================================================= + + +def _skip_unless_cc_90(): + if Device(0).compute_capability < (9, 0): + pytest.skip("Conditional node execution requires CC >= 9.0") + + +def test_while_loop_zero_iterations(init_cuda): + """While loop with default_value=0 never executes its body.""" + _skip_unless_cc_90() + + mod = compile_common_kernels() + add_one = mod.get_kernel("add_one") + cfg = LaunchConfig(grid=1, block=1) + + g = GraphDef() + condition = g.create_condition(default_value=0) + alloc = g.alloc(SIZEOF_INT) + ms = alloc.memset(alloc.dptr, 0, SIZEOF_INT) + loop = ms.while_loop(condition) + loop.body.launch(cfg, add_one, alloc.dptr) + + graph = g.instantiate() + stream = Device().create_stream() + graph.launch(stream) + stream.sync() + + result = (ctypes.c_int * 1)() + from cuda.bindings import driver as drv + + drv.cuMemcpyDtoH(result, alloc.dptr, SIZEOF_INT) + assert result[0] == 0, "Body should not have 
executed" + + +def test_if_cond_false_skips_body(init_cuda): + """If conditional with default_value=0 does not execute its body.""" + _skip_unless_cc_90() + + mod = compile_common_kernels() + add_one = mod.get_kernel("add_one") + cfg = LaunchConfig(grid=1, block=1) + + g = GraphDef() + condition = g.create_condition(default_value=0) + alloc = g.alloc(SIZEOF_INT) + ms = alloc.memset(alloc.dptr, 0, SIZEOF_INT) + if_node = ms.if_cond(condition) + if_node.then.launch(cfg, add_one, alloc.dptr) + + graph = g.instantiate() + stream = Device().create_stream() + graph.launch(stream) + stream.sync() + + result = (ctypes.c_int * 1)() + from cuda.bindings import driver as drv + + drv.cuMemcpyDtoH(result, alloc.dptr, SIZEOF_INT) + assert result[0] == 0, "Body should not have executed" + + +def test_switch_oob_skips_all_branches(init_cuda): + """Switch with out-of-range condition value does not execute any branch.""" + _skip_unless_cc_90() + + mod = compile_common_kernels() + add_one = mod.get_kernel("add_one") + cfg = LaunchConfig(grid=1, block=1) + + g = GraphDef() + condition = g.create_condition(default_value=99) + alloc = g.alloc(SIZEOF_INT) + ms = alloc.memset(alloc.dptr, 0, SIZEOF_INT) + sw = ms.switch(condition, 3) + for branch in sw.branches: + branch.launch(cfg, add_one, alloc.dptr) + + graph = g.instantiate() + stream = Device().create_stream() + graph.launch(stream) + stream.sync() + + result = (ctypes.c_int * 1)() + from cuda.bindings import driver as drv + + drv.cuMemcpyDtoH(result, alloc.dptr, SIZEOF_INT) + assert result[0] == 0, "No branch should have executed" diff --git a/cuda_core/tests/graph/test_explicit_integration.py b/cuda_core/tests/graph/test_explicit_integration.py new file mode 100644 index 0000000000..2595f4097b --- /dev/null +++ b/cuda_core/tests/graph/test_explicit_integration.py @@ -0,0 +1,465 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +"""Integration tests for explicit CUDA graph construction. + +Three test scenarios exercise complementary subsets of node types: + +test_heat_diffusion + 1D heat bar evolving toward steady state via finite differences. + Exercises: AllocNode, FreeNode, MemsetNode, ChildGraphNode, + EmptyNode, EventRecordNode, EventWaitNode, WhileNode, KernelNode, + MemcpyNode, HostCallbackNode. + +test_bisection_root + Find sqrt(2) by bisecting f(x) = x^2 - 2 on [0, 2], with an + optional Newton polish step. + Exercises: IfElseNode (interval halving), IfNode (refinement + guard), WhileNode, KernelNode, AllocNode, MemsetNode, MemcpyNode, + HostCallbackNode, FreeNode, EmptyNode. + +test_switch_dispatch + Apply one of four element-wise transforms selected at graph + creation time via a switch condition. + Exercises: SwitchNode, KernelNode, AllocNode, MemsetNode, + MemcpyNode, FreeNode. + +Together the three tests cover all 14 explicit-graph node types. 
+""" + +import ctypes + +import numpy as np +import pytest + +from cuda.core import Device, EventOptions, LaunchConfig, Program, ProgramOptions +from cuda.core._graph._graphdef import GraphDef +from cuda.core._utils.cuda_utils import driver, handle_return + +SIZEOF_FLOAT = 4 +SIZEOF_INT = 4 + +# =================================================================== +# Kernel sources +# =================================================================== + +_COND_PREAMBLE = r""" +extern "C" __device__ __cudart_builtin__ void CUDARTAPI +cudaGraphSetConditional(cudaGraphConditionalHandle handle, + unsigned int value); +""" + +_HEAT_KERNEL_SOURCE = ( + _COND_PREAMBLE + + r""" +extern "C" __global__ +void heat_step(float* u_next, const float* u_curr, int N, float alpha) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= N) return; + if (i == 0 || i == N - 1) + u_next[i] = u_curr[i]; + else + u_next[i] = u_curr[i] + + alpha * (u_curr[i-1] - 2.0f * u_curr[i] + u_curr[i+1]); +} + +extern "C" __global__ +void countdown(cudaGraphConditionalHandle handle, int* counter) { + int c = atomicSub(counter, 1); + cudaGraphSetConditional(handle, (c > 1) ? 1u : 0u); +} +""" +) + +_BISECT_KERNEL_SOURCE = ( + _COND_PREAMBLE + + r""" +extern "C" __global__ +void bisect_eval(float* a, float* b, + cudaGraphConditionalHandle ie_cond) { + float mid = (*a + *b) * 0.5f; + float fm = mid * mid - 2.0f; + cudaGraphSetConditional(ie_cond, (fm > 0.0f) ? 1u : 0u); +} + +extern "C" __global__ +void update_hi(float* a, float* b) { + *b = (*a + *b) * 0.5f; +} + +extern "C" __global__ +void update_lo(float* a, float* b) { + *a = (*a + *b) * 0.5f; +} + +extern "C" __global__ +void countdown(cudaGraphConditionalHandle handle, int* counter) { + int c = atomicSub(counter, 1); + cudaGraphSetConditional(handle, (c > 1) ? 
1u : 0u); +} + +extern "C" __global__ +void check_refine(float* a, float* b, + cudaGraphConditionalHandle if_cond) { + float mid = (*a + *b) * 0.5f; + float fm = mid * mid - 2.0f; + float abs_fm = fm < 0.0f ? -fm : fm; + cudaGraphSetConditional(if_cond, (abs_fm > 1e-10f) ? 1u : 0u); +} + +extern "C" __global__ +void newton_refine(float* a, float* b) { + float mid = (*a + *b) * 0.5f; + float refined = mid - (mid * mid - 2.0f) / (2.0f * mid); + *a = refined; + *b = refined; +} +""" +) + +_SWITCH_KERNEL_SOURCE = r""" +extern "C" __global__ +void negate_it(int* x) { *x = -(*x); } + +extern "C" __global__ +void double_it(int* x) { *x = 2 * (*x); } + +extern "C" __global__ +void square_it(int* x) { *x = (*x) * (*x); } +""" + +# =================================================================== +# Compilation helpers +# =================================================================== + + +def _nvrtc_opts(): + arch = "".join(f"{i}" for i in Device().compute_capability) + return ProgramOptions(std="c++17", arch=f"sm_{arch}") + + +def _compile_heat_kernels(): + prog = Program(_HEAT_KERNEL_SOURCE, code_type="c++", options=_nvrtc_opts()) + try: + mod = prog.compile( + "cubin", + name_expressions=("heat_step", "countdown"), + ) + except Exception: + pytest.skip("NVRTC does not support cudaGraphConditionalHandle") + return mod.get_kernel("heat_step"), mod.get_kernel("countdown") + + +def _compile_bisect_kernels(): + names = ( + "bisect_eval", + "update_hi", + "update_lo", + "countdown", + "check_refine", + "newton_refine", + ) + prog = Program(_BISECT_KERNEL_SOURCE, code_type="c++", options=_nvrtc_opts()) + try: + mod = prog.compile("cubin", name_expressions=names) + except Exception: + pytest.skip("NVRTC does not support cudaGraphConditionalHandle") + return tuple(mod.get_kernel(n) for n in names) + + +def _compile_switch_kernels(): + names = ("negate_it", "double_it", "square_it") + prog = Program(_SWITCH_KERNEL_SOURCE, code_type="c++", options=_nvrtc_opts()) + mod = 
prog.compile("cubin", name_expressions=names) + return tuple(mod.get_kernel(n) for n in names) + + +# =================================================================== +# Test 1 — Heat diffusion (WhileNode, ChildGraphNode, EventNodes, …) +# +# alloc(curr) ─ memset(0) ──┐ +# alloc(next) ─ memset(0) ──┼─ join ─ embed(bc) ─ rec(start) ─ WHILE ──┐ +# alloc(ctr) ─ memset(50) ─┘ │ +# ┌─────────────────────────────────────────────────────────────────────┘ +# └─ wait(start) ─ rec(end) ─ memcpy(→host) ─ callback +# ─ free(curr) ─ free(next) ─ free(ctr) +# +# bc graph: memset(T_LEFT) ─ memset(T_RIGHT) +# while body: heat_step ─ memcpy(curr ← next) ─ countdown +# =================================================================== + +_HEAT_N = 32 +_HEAT_T_LEFT = np.float32(100.0) +_HEAT_T_RIGHT = np.float32(0.0) +_HEAT_ALPHA = np.float32(0.4) +_HEAT_ITERS = 50 + + +def _heat_reference(): + """Compute the reference heat solution on the host (NumPy).""" + u = np.zeros(_HEAT_N, dtype=np.float32) + u[0] = _HEAT_T_LEFT + u[-1] = _HEAT_T_RIGHT + u_next = np.empty_like(u) + for _ in range(_HEAT_ITERS): + u_next[0] = u[0] + u_next[-1] = u[-1] + u_next[1:-1] = u[1:-1] + _HEAT_ALPHA * (u[:-2] - 2.0 * u[1:-1] + u[2:]) + u, u_next = u_next, u + return u + + +def test_heat_diffusion(init_cuda): + """1D heat-bar simulation exercising most explicit-graph node types.""" + dev = Device() + + if dev.compute_capability < (9, 0): + pytest.skip("Conditional nodes require compute capability >= 9.0") + + k_heat, k_countdown = _compile_heat_kernels() + + host_ptr = handle_return(driver.cuMemAllocHost(_HEAT_N * SIZEOF_FLOAT)) + + try: + _run_heat_graph(dev, k_heat, k_countdown, host_ptr) + finally: + handle_return(driver.cuMemFreeHost(host_ptr)) + + +def _run_heat_graph(dev, k_heat, k_countdown, host_ptr): + """Build, instantiate, launch, and verify the heat-diffusion graph.""" + + # Definitions + g = GraphDef() + condition = g.create_condition(default_value=1) + event_start = 
dev.create_event(EventOptions(enable_timing=True)) + event_end = dev.create_event(EventOptions(enable_timing=True)) + results = {} + + def capture_result(): + arr = (ctypes.c_float * _HEAT_N).from_address(host_ptr) + results["data"] = np.array(arr, copy=True) + + block = min(_HEAT_N, 256) + grid = (_HEAT_N + block - 1) // block + heat_cfg = LaunchConfig(grid=grid, block=block) + tick_cfg = LaunchConfig(grid=1, block=1) + + # fmt: off + # Phase 1 — Allocate device memory + a_curr = g.alloc(_HEAT_N * SIZEOF_FLOAT) + a_next = g.alloc(_HEAT_N * SIZEOF_FLOAT) + a_ctr = g.alloc(SIZEOF_INT) + + # Phase 2 — Initialise buffers + m_curr = a_curr.memset(a_curr.dptr, 0, _HEAT_N * SIZEOF_FLOAT) + m_next = a_next.memset(a_next.dptr, 0, _HEAT_N * SIZEOF_FLOAT) + m_ctr = a_ctr.memset(a_ctr.dptr, np.int32(_HEAT_ITERS), 1) + + # Phase 3 — Boundary conditions (child graph) + bc = GraphDef() \ + .memset(a_curr.dptr, np.float32(_HEAT_T_LEFT), 1) \ + .memset(a_curr.dptr + (_HEAT_N - 1) * SIZEOF_FLOAT, + np.float32(_HEAT_T_RIGHT), 1) \ + .graph + p = g.join(m_curr, m_next, m_ctr) \ + .embed(bc) \ + .record_event(event_start) + + # Phase 4 — Iterate + loop = p.while_loop(condition) + loop.body.launch(heat_cfg, k_heat, a_next.dptr, a_curr.dptr, + np.int32(_HEAT_N), _HEAT_ALPHA) \ + .memcpy(a_curr.dptr, a_next.dptr, _HEAT_N * SIZEOF_FLOAT) \ + .launch(tick_cfg, k_countdown, condition.handle, a_ctr.dptr) + + # Phase 5 — After loop: timing end, readback, verify, free memory + loop.wait_event(event_start) \ + .record_event(event_end) \ + .memcpy(host_ptr, a_curr.dptr, _HEAT_N * SIZEOF_FLOAT) \ + .callback(capture_result) \ + .free(a_curr.dptr) \ + .free(a_next.dptr) \ + .free(a_ctr.dptr) + # fmt: on + + # Phase 6 — Instantiate, launch, verify + graph = g.instantiate() + stream = dev.create_stream() + graph.launch(stream) + stream.sync() + + assert "data" in results, "Host callback did not execute" + np.testing.assert_allclose(results["data"], _heat_reference(), rtol=1e-5) + + +# 
=================================================================== +# Test 2 — Bisection root finder (IfElseNode, IfNode) +# +# Find sqrt(2) by bisecting f(x) = x^2 - 2 on [0, 2]. +# +# alloc(a) ─ memset(0.0) ──┐ +# alloc(b) ─ memset(2.0) ──┼─ join ─ WHILE(while_cond) ──────────────────┐ +# alloc(ctr) ─ memset(20) ─┘ │ +# ┌───────────────────────────────────────────────────────────────────────┘ +# └─ check_refine ─ IF(if_cond) ─ memcpy(→host) ─ callback +# └─ body: newton_refine +# ─ free(a) ─ free(b) ─ free(ctr) +# +# while body: +# bisect_eval ─ IF_ELSE(ie_cond) ─ countdown +# ├─ then: update_hi (b = mid) [f(mid) > 0] +# └─ else: update_lo (a = mid) [f(mid) ≤ 0] +# =================================================================== + +_BISECT_ITERS = 20 + + +def test_bisection_root(init_cuda): + """Bisection search for sqrt(2) with optional Newton refinement. + + Exercises IfElseNode (interval halving) and IfNode (refinement guard). + """ + dev = Device() + + if dev.compute_capability < (9, 0): + pytest.skip("Conditional nodes require compute capability >= 9.0") + + k_eval, k_hi, k_lo, k_cd, k_check, k_newton = _compile_bisect_kernels() + + host_ptr = handle_return(driver.cuMemAllocHost(SIZEOF_FLOAT)) + + try: + _run_bisection_graph(dev, k_eval, k_hi, k_lo, k_cd, k_check, k_newton, host_ptr) + finally: + handle_return(driver.cuMemFreeHost(host_ptr)) + + +def _run_bisection_graph(dev, k_eval, k_hi, k_lo, k_cd, k_check, k_newton, host_ptr): + """Build, instantiate, launch, and verify the bisection graph.""" + + # Definitions + g = GraphDef() + cfg = LaunchConfig(grid=1, block=1) + results = {} + + def capture_result(): + results["root"] = ctypes.c_float.from_address(host_ptr).value + + # fmt: off + # Allocate and initialise: a = 0.0, b = 2.0, counter = ITERS + a = g.alloc(SIZEOF_FLOAT) + b = g.alloc(SIZEOF_FLOAT) + ctr = g.alloc(SIZEOF_INT) + + p = g.join(a.memset(a.dptr, np.float32(0.0), 1), + b.memset(b.dptr, np.float32(2.0), 1), + ctr.memset(ctr.dptr, 
np.int32(_BISECT_ITERS), 1)) + + # While loop: bisection iterations + while_cond = g.create_condition(default_value=1) + ie_cond = g.create_condition(default_value=0) + loop = p.while_loop(while_cond) + + ie = loop.body.launch(cfg, k_eval, a.dptr, b.dptr, ie_cond.handle) \ + .if_else(ie_cond) + ie.then.launch(cfg, k_hi, a.dptr, b.dptr) + ie.else_.launch(cfg, k_lo, a.dptr, b.dptr) + ie.launch(cfg, k_cd, while_cond.handle, ctr.dptr) + + # Post-loop: Newton refinement (IfNode), readback, free + if_cond = g.create_condition(default_value=0) + if_node = loop.launch(cfg, k_check, a.dptr, b.dptr, if_cond.handle) \ + .if_cond(if_cond) + if_node.then.launch(cfg, k_newton, a.dptr, b.dptr) + + if_node.memcpy(host_ptr, a.dptr, SIZEOF_FLOAT) \ + .callback(capture_result) \ + .free(a.dptr) \ + .free(b.dptr) \ + .free(ctr.dptr) + # fmt: on + + # Instantiate, launch, verify + graph = g.instantiate() + stream = dev.create_stream() + graph.launch(stream) + stream.sync() + + assert "root" in results, "Host callback did not execute" + np.testing.assert_allclose( + results["root"], + np.sqrt(np.float32(2.0)), + rtol=1e-6, + ) + + +# =================================================================== +# Test 3 — Switch dispatch (SwitchNode) +# +# A mode value (0-3) selects one of four transforms on a scalar: +# +# alloc(x) ─ memset(42) ─ SWITCH(mode, 4) +# ├─ 0: negate(x) +# ├─ 1: double(x) +# ├─ 2: square(x) +# └─ 3: (identity) +# ─ memcpy(→host) ─ free(x) +# =================================================================== + +_SWITCH_VALUE = 42 + + +@pytest.mark.parametrize( + "mode, expected", + [ + (0, -_SWITCH_VALUE), + (1, 2 * _SWITCH_VALUE), + (2, _SWITCH_VALUE * _SWITCH_VALUE), + (3, _SWITCH_VALUE), + ], +) +def test_switch_dispatch(init_cuda, mode, expected): + """Runtime kernel selection via SwitchNode.""" + dev = Device() + + if dev.compute_capability < (9, 0): + pytest.skip("Conditional nodes require compute capability >= 9.0") + + k_negate, k_double, k_square = 
_compile_switch_kernels() + + host_ptr = handle_return(driver.cuMemAllocHost(SIZEOF_INT)) + + try: + _run_switch_graph(dev, mode, k_negate, k_double, k_square, host_ptr) + + result = ctypes.c_int.from_address(host_ptr).value + assert result == expected + finally: + handle_return(driver.cuMemFreeHost(host_ptr)) + + +def _run_switch_graph(dev, mode, k_negate, k_double, k_square, host_ptr): + """Build, instantiate, launch, and verify the switch-dispatch graph.""" + g = GraphDef() + cfg = LaunchConfig(grid=1, block=1) + + # fmt: off + x = g.alloc(SIZEOF_INT) + sw_cond = g.create_condition(default_value=mode) + sw = x.memset(x.dptr, np.int32(_SWITCH_VALUE), 1) \ + .switch(sw_cond, 4) + + sw.branches[0].launch(cfg, k_negate, x.dptr) + sw.branches[1].launch(cfg, k_double, x.dptr) + sw.branches[2].launch(cfg, k_square, x.dptr) + # branch 3: identity (no kernel — value unchanged) + + sw.memcpy(host_ptr, x.dptr, SIZEOF_INT) \ + .free(x.dptr) + # fmt: on + + graph = g.instantiate() + stream = dev.create_stream() + graph.launch(stream) + stream.sync() diff --git a/cuda_core/tests/graph/test_explicit_lifetime.py b/cuda_core/tests/graph/test_explicit_lifetime.py new file mode 100644 index 0000000000..f355fa821d --- /dev/null +++ b/cuda_core/tests/graph/test_explicit_lifetime.py @@ -0,0 +1,404 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +"""Tests for resource lifetime management in explicit CUDA graphs. + +These tests verify that the RAII mechanism in GraphHandle correctly +prevents dangling references when parent Python objects are deleted +while child/body graph references remain alive. 
+""" + +import gc + +import pytest +from helpers.graph_kernels import compile_common_kernels +from helpers.misc import try_create_condition + +from cuda.core import Device, EventOptions, Kernel, LaunchConfig +from cuda.core._graph._graphdef import ( + ChildGraphNode, + ConditionalNode, + GraphDef, + KernelNode, +) + +# ============================================================================= +# Conditional body graph lifetime +# ============================================================================= + + +def _make_if(g, cond): + node = g.if_cond(cond) + return [node.then] + + +def _make_if_else(g, cond): + node = g.if_else(cond) + return [node.then, node.else_] + + +def _make_while(g, cond): + node = g.while_loop(cond) + return [node.body] + + +def _make_switch(g, cond): + node = g.switch(cond, 4) + return list(node.branches) + + +_COND_BUILDERS = [ + pytest.param(_make_if, 1, id="if"), + pytest.param(_make_if_else, 2, id="if_else"), + pytest.param(_make_while, 1, id="while"), + pytest.param(_make_switch, 4, id="switch"), +] + + +@pytest.mark.parametrize("builder, expected_count", _COND_BUILDERS) +def test_branches_survive_parent_deletion(init_cuda, builder, expected_count): + """All branch graphs remain valid after parent GraphDef is deleted.""" + g = GraphDef() + condition = try_create_condition(g) + branches = builder(g, condition) + assert len(branches) == expected_count + + del g, condition + gc.collect() + + for branch in branches: + assert branch.nodes() == () + + +@pytest.mark.parametrize("builder, expected_count", _COND_BUILDERS) +def test_branches_usable_after_parent_deletion(init_cuda, builder, expected_count): + """Nodes can be added to branch graphs after parent GraphDef is deleted.""" + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + config = LaunchConfig(grid=1, block=1) + + g = GraphDef() + condition = try_create_condition(g) + branches = builder(g, condition) + + del g, condition + gc.collect() + + for branch in 
branches: + branch.launch(config, kernel) + assert len(branch.nodes()) == 1 + + +def test_reconstructed_body_survives_parent_deletion(init_cuda): + """Body graph obtained via nodes() reconstruction survives parent deletion.""" + g = GraphDef() + condition = try_create_condition(g) + g.while_loop(condition) + + all_nodes = g.nodes() + cond_nodes = [n for n in all_nodes if isinstance(n, ConditionalNode)] + assert len(cond_nodes) == 1 + + branches = cond_nodes[0].branches + if not branches: + pytest.skip("Body reconstruction requires CUDA 13.2+") + body = branches[0] + + del g, condition, all_nodes, cond_nodes, branches + gc.collect() + + assert body.nodes() == () + + +# ============================================================================= +# Child graph (embed) lifetime +# ============================================================================= + + +def test_child_graph_survives_parent_deletion(init_cuda): + """Embedded child graph remains valid after parent GraphDef is deleted.""" + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + config = LaunchConfig(grid=1, block=1) + + child_def = GraphDef() + child_def.launch(config, kernel) + child_def.launch(config, kernel) + + g = GraphDef() + node = g.embed(child_def) + child_ref = node.child_graph + + del g, node, child_def + gc.collect() + + assert len(child_ref.nodes()) == 2 + + +def test_nested_child_graph_lifetime(init_cuda): + """Grandchild graph keeps entire ancestor chain alive.""" + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + config = LaunchConfig(grid=1, block=1) + + inner = GraphDef() + inner.launch(config, kernel) + + middle = GraphDef() + middle.embed(inner) + + outer = GraphDef() + outer_node = outer.embed(middle) + + middle_ref = outer_node.child_graph + middle_nodes = middle_ref.nodes() + child_node = next(n for n in middle_nodes if isinstance(n, ChildGraphNode)) + grandchild = child_node.child_graph + + del outer, outer_node, middle, 
inner, middle_ref, middle_nodes, child_node + gc.collect() + + assert len(grandchild.nodes()) == 1 + + +# ============================================================================= +# Event lifetime — event nodes should keep the Event alive +# ============================================================================= + + +def test_event_record_node_keeps_event_alive(init_cuda): + """EventRecordNode should keep the Event alive after original is deleted.""" + dev = Device() + g = GraphDef() + alloc = g.alloc(1024) + + event = dev.create_event(EventOptions(enable_timing=False)) + node = alloc.record_event(event) + + del event + gc.collect() + + retrieved = node.event + assert retrieved.is_done is True + + +def test_event_wait_node_keeps_event_alive(init_cuda): + """EventWaitNode should keep the Event alive after original is deleted.""" + dev = Device() + g = GraphDef() + alloc = g.alloc(1024) + + event = dev.create_event(EventOptions(enable_timing=False)) + node = alloc.wait_event(event) + + del event + gc.collect() + + retrieved = node.event + assert retrieved.is_done is True + + +def test_event_record_node_preserves_metadata(init_cuda): + """Reconstructed EventRecordNode recovers full Event metadata via reverse lookup.""" + dev = Device() + g = GraphDef() + + event = dev.create_event(EventOptions(enable_timing=True, busy_waited_sync=True)) + node = g.record_event(event) + + reconstructed = node.event + assert reconstructed.is_timing_disabled is False + assert reconstructed.is_sync_busy_waited is True + assert reconstructed.is_ipc_enabled is False + assert reconstructed.device is not None + + +def test_event_wait_node_preserves_metadata(init_cuda): + """Reconstructed EventWaitNode recovers full Event metadata via reverse lookup.""" + dev = Device() + g = GraphDef() + + event = dev.create_event(EventOptions(enable_timing=False)) + node = g.wait_event(event) + + reconstructed = node.event + assert reconstructed.is_timing_disabled is True + assert 
reconstructed.is_sync_busy_waited is False + assert reconstructed.device is not None + + +def test_event_metadata_survives_gc(init_cuda): + """Event metadata is preserved through reverse lookup even after original is GC'd.""" + dev = Device() + g = GraphDef() + + event = dev.create_event(EventOptions(enable_timing=True, busy_waited_sync=True)) + node = g.record_event(event) + + del event + gc.collect() + + retrieved = node.event + assert retrieved.is_timing_disabled is False + assert retrieved.is_sync_busy_waited is True + assert retrieved.is_done is True + + +def test_event_survives_graph_instantiation_and_execution(init_cuda): + """Graph with event nodes executes correctly after original Event is deleted.""" + dev = Device() + g = GraphDef() + + event = dev.create_event(EventOptions(enable_timing=False)) + rec = g.record_event(event) + rec.wait_event(event) + + del event + gc.collect() + + graph = g.instantiate() + stream = dev.create_stream() + graph.launch(stream) + stream.sync() + + +def test_event_survives_graph_clone_and_execution(init_cuda): + """Cloned graph with event nodes executes after original Event is deleted. + + This is the critical test for CUDA User Objects: a graph clone does + not inherit Python-level references, so only user objects (which + propagate through cuGraphClone) can keep the event alive. 
+ """ + from cuda.core._utils.cuda_utils import driver, handle_return + + dev = Device() + g = GraphDef() + + event = dev.create_event(EventOptions(enable_timing=False)) + rec = g.record_event(event) + rec.wait_event(event) + + cloned_cu_graph = handle_return(driver.cuGraphClone(driver.CUgraph(g.handle))) + + del event, g, rec + gc.collect() + + graph_exec = handle_return(driver.cuGraphInstantiate(cloned_cu_graph, 0)) + stream = dev.create_stream() + handle_return(driver.cuGraphLaunch(graph_exec, driver.CUstream(int(stream.handle)))) + stream.sync() + + +# ============================================================================= +# Kernel lifetime — kernel nodes should keep the Kernel/Module alive +# ============================================================================= + + +def test_kernel_node_keeps_kernel_alive(init_cuda): + """KernelNode should keep the Kernel alive after original is deleted.""" + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + config = LaunchConfig(grid=1, block=1) + + g = GraphDef() + node = g.launch(config, kernel) + + del kernel, mod + gc.collect() + + retrieved = node.kernel + assert retrieved.attributes.max_threads_per_block() > 0 + + +def test_kernel_survives_graph_instantiation_and_execution(init_cuda): + """Graph with kernel node executes correctly after Kernel/Module is deleted.""" + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + config = LaunchConfig(grid=1, block=1) + + g = GraphDef() + g.launch(config, kernel) + + del kernel, mod + gc.collect() + + graph = g.instantiate() + stream = Device().create_stream() + graph.launch(stream) + stream.sync() + + +def test_kernel_survives_graph_clone_and_execution(init_cuda): + """Cloned graph with kernel node executes after Kernel/Module is deleted. + + Validates that CUDA User Objects keep the kernel's library alive + through graph cloning (where Python-level references are lost). 
+ """ + from cuda.core._utils.cuda_utils import driver, handle_return + + dev = Device() + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + config = LaunchConfig(grid=1, block=1) + + g = GraphDef() + g.launch(config, kernel) + + cloned_cu_graph = handle_return(driver.cuGraphClone(driver.CUgraph(g.handle))) + + del kernel, mod, g + gc.collect() + + graph_exec = handle_return(driver.cuGraphInstantiate(cloned_cu_graph, 0)) + stream = dev.create_stream() + handle_return(driver.cuGraphLaunch(graph_exec, driver.CUstream(int(stream.handle)))) + stream.sync() + + +# ============================================================================= +# Kernel handle recovery — from_handle and graph node reconstruction +# ============================================================================= + + +def test_kernel_from_handle_recovers_library(init_cuda): + """Kernel.from_handle on a cuda.core-created kernel recovers the library + dependency, keeping it alive after the original objects are deleted.""" + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + handle = int(kernel.handle) + + reconstructed = Kernel.from_handle(handle) + + del kernel, mod + gc.collect() + + assert reconstructed.attributes.max_threads_per_block() > 0 + + +def test_kernel_node_reconstruction_preserves_validity(init_cuda): + """A KernelNode reconstructed via DAG traversal has a valid kernel, + kept alive by user objects and existing node references.""" + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + config = LaunchConfig(grid=1, block=1) + + g = GraphDef() + kernel_node = g.launch(config, kernel) + # Chain a second node so we can reconstruct the kernel node via pred + event = Device().create_event() + successor = kernel_node.record_event(event) + + del kernel, mod + gc.collect() + + # Reconstruct the kernel node through DAG traversal + # successor.pred -> Node._create -> KernelNode._create_from_driver + # -> 
create_kernel_handle_ref -> handle recovery + reconstructed = successor.pred[0] + assert isinstance(reconstructed, KernelNode) + assert reconstructed.kernel.attributes.max_threads_per_block() > 0 + + graph = g.instantiate() + stream = Device().create_stream() + graph.launch(stream) + stream.sync() diff --git a/cuda_core/tests/helpers/misc.py b/cuda_core/tests/helpers/misc.py index aa5757c4ce..6b83c751ab 100644 --- a/cuda_core/tests/helpers/misc.py +++ b/cuda_core/tests/helpers/misc.py @@ -1,6 +1,18 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import pytest + + +def try_create_condition(g, default_value=1): + """Create a Condition on graph *g*, skipping the test if unsupported.""" + from cuda.core._utils.cuda_utils import CUDAError + + try: + return g.create_condition(default_value=default_value) + except CUDAError: + pytest.skip("Conditional nodes not supported (requires CC >= 9.0)") + class StreamWrapper: """ diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index e74b1fc672..2bc7e25d21 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -511,6 +511,42 @@ def test_kernel_from_handle_multiple_instances(get_saxpy_kernel_cubin): assert int(kernel1.handle) == int(kernel2.handle) == int(kernel3.handle) == handle +def test_kernel_from_handle_library_mismatch_warning(init_cuda): + """Kernel.from_handle warns when caller-supplied module differs from the kernel's library.""" + prog1 = Program(SAXPY_KERNEL, code_type="c++") + mod1 = prog1.compile("cubin", name_expressions=("saxpy",)) + kernel = mod1.get_kernel("saxpy") + handle = int(kernel.handle) + + prog2 = Program(SAXPY_KERNEL, code_type="c++") + mod2 = prog2.compile("cubin", name_expressions=("saxpy",)) + mod2.get_kernel("saxpy") + + with 
warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + k = Kernel.from_handle(handle, mod2) + assert len(w) == 1 + assert "does not match" in str(w[0].message) + + assert k.attributes.max_threads_per_block() > 0 + + +def test_kernel_from_handle_foreign_kernel(init_cuda): + """Kernel.from_handle with a driver-level kernel not created by cuda.core.""" + prog = Program(SAXPY_KERNEL, code_type="c++") + mod = prog.compile("cubin", name_expressions=("saxpy",)) + cubin = mod.code + sym_map = mod.symbol_mapping + + cu_lib = handle_return(driver.cuLibraryLoadData(cubin, [], [], 0, [], [], 0)) + mangled = sym_map["saxpy"] + cu_kernel = handle_return(driver.cuLibraryGetKernel(cu_lib, mangled)) + handle = int(cu_kernel) + + k = Kernel.from_handle(handle) + assert k.attributes.max_threads_per_block() > 0 + + def test_kernel_keeps_library_alive(init_cuda): """Test that a Kernel keeps its underlying library alive after ObjectCode goes out of scope.""" import gc diff --git a/cuda_core/tests/test_object_protocols.py b/cuda_core/tests/test_object_protocols.py index fa35a3887e..82a7cff1d4 100644 --- a/cuda_core/tests/test_object_protocols.py +++ b/cuda_core/tests/test_object_protocols.py @@ -12,8 +12,11 @@ import weakref import pytest +from helpers.graph_kernels import compile_common_kernels +from helpers.misc import try_create_condition from cuda.core import Buffer, Device, Kernel, LaunchConfig, Program, Stream, system +from cuda.core._graph._graphdef import GraphDef from cuda.core._program import _can_load_generated_ptx # ============================================================================= @@ -199,6 +202,265 @@ def sample_kernel_alt(sample_object_code_alt): return sample_object_code_alt.get_kernel("test_kernel_alt") +# ============================================================================= +# Fixtures - Graph types (GraphDef and Node) +# ============================================================================= + +ALLOC_SIZE = 1024 + + 
@pytest.fixture
def sample_graphdef(init_cuda):
    """Provide a freshly constructed GraphDef."""
    graph_def = GraphDef()
    return graph_def


@pytest.fixture
def sample_graphdef_alt(init_cuda):
    """Provide a second, separately constructed GraphDef for inequality tests."""
    other_graph_def = GraphDef()
    return other_graph_def


@pytest.fixture
def sample_root_node(sample_graphdef):
    """Provide the entry Node of ``sample_graphdef`` (virtual, NULL handle)."""
    entry_node = sample_graphdef._entry
    return entry_node


@pytest.fixture
def sample_root_node_alt(sample_graphdef_alt):
    """Provide the entry Node of the alternate graph (a different graph)."""
    entry_node = sample_graphdef_alt._entry
    return entry_node


@pytest.fixture
def sample_empty_node(sample_graphdef):
    """Provide an EmptyNode produced by joining two allocation branches."""
    branches = [sample_graphdef.alloc(ALLOC_SIZE) for _ in range(2)]
    return sample_graphdef.join(*branches)


@pytest.fixture
def sample_empty_node_alt(sample_graphdef):
    """Provide another EmptyNode, built the same way on the same graph."""
    branches = [sample_graphdef.alloc(ALLOC_SIZE) for _ in range(2)]
    return sample_graphdef.join(*branches)


@pytest.fixture
def sample_alloc_node(sample_graphdef):
    """Provide an AllocNode of ``ALLOC_SIZE`` bytes."""
    alloc_node = sample_graphdef.alloc(ALLOC_SIZE)
    return alloc_node


@pytest.fixture
def sample_alloc_node_alt(sample_graphdef):
    """Provide another AllocNode created on the same graph."""
    alloc_node = sample_graphdef.alloc(ALLOC_SIZE)
    return alloc_node


@pytest.fixture
def sample_kernel_node(sample_graphdef, init_cuda):
    """Provide a KernelNode that launches ``empty_kernel`` with a 1x1 config."""
    module = compile_common_kernels()
    empty_kernel = module.get_kernel("empty_kernel")
    return sample_graphdef.launch(LaunchConfig(grid=1, block=1), empty_kernel)


@pytest.fixture
def sample_kernel_node_alt(sample_graphdef, init_cuda):
    """Provide another KernelNode created on the same graph."""
    module = compile_common_kernels()
    empty_kernel = module.get_kernel("empty_kernel")
    return sample_graphdef.launch(LaunchConfig(grid=1, block=1), empty_kernel)


@pytest.fixture
def sample_free_node(sample_graphdef):
    """Provide a FreeNode that frees the pointer of a preceding AllocNode."""
    alloc_node = sample_graphdef.alloc(ALLOC_SIZE)
    device_ptr = alloc_node.dptr
    return alloc_node.free(device_ptr)


+@pytest.fixture +def sample_free_node_alt(sample_graphdef): + """An alternate FreeNode from same graph.""" + alloc = sample_graphdef.alloc(ALLOC_SIZE) + return alloc.free(alloc.dptr) + + +@pytest.fixture +def sample_memset_node(sample_graphdef): + """A MemsetNode.""" + alloc = sample_graphdef.alloc(ALLOC_SIZE) + return alloc.memset(alloc.dptr, 0, ALLOC_SIZE) + + +@pytest.fixture +def sample_memset_node_alt(sample_graphdef): + """An alternate MemsetNode from same graph.""" + alloc = sample_graphdef.alloc(ALLOC_SIZE) + return alloc.memset(alloc.dptr, 0, ALLOC_SIZE) + + +@pytest.fixture +def sample_memcpy_node(sample_graphdef): + """A MemcpyNode.""" + src = sample_graphdef.alloc(ALLOC_SIZE) + dst = sample_graphdef.alloc(ALLOC_SIZE) + dep = sample_graphdef.join(src, dst) + return dep.memcpy(dst.dptr, src.dptr, ALLOC_SIZE) + + +@pytest.fixture +def sample_memcpy_node_alt(sample_graphdef): + """An alternate MemcpyNode from same graph.""" + src = sample_graphdef.alloc(ALLOC_SIZE) + dst = sample_graphdef.alloc(ALLOC_SIZE) + dep = sample_graphdef.join(src, dst) + return dep.memcpy(dst.dptr, src.dptr, ALLOC_SIZE) + + +@pytest.fixture +def sample_child_graph_node(sample_graphdef): + """A ChildGraphNode.""" + child = GraphDef() + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + child.launch(LaunchConfig(grid=1, block=1), kernel) + return sample_graphdef.embed(child) + + +@pytest.fixture +def sample_child_graph_node_alt(sample_graphdef): + """An alternate ChildGraphNode from same graph.""" + child = GraphDef() + mod = compile_common_kernels() + kernel = mod.get_kernel("empty_kernel") + child.launch(LaunchConfig(grid=1, block=1), kernel) + return sample_graphdef.embed(child) + + +@pytest.fixture +def sample_event_record_node(sample_graphdef, sample_device): + """An EventRecordNode.""" + event = sample_device.create_event() + return sample_graphdef.record_event(event) + + +@pytest.fixture +def sample_event_record_node_alt(sample_graphdef, 
sample_device): + """An alternate EventRecordNode from same graph.""" + event = sample_device.create_event() + return sample_graphdef.record_event(event) + + +@pytest.fixture +def sample_event_wait_node(sample_graphdef, sample_device): + """An EventWaitNode.""" + event = sample_device.create_event() + return sample_graphdef.wait_event(event) + + +@pytest.fixture +def sample_event_wait_node_alt(sample_graphdef, sample_device): + """An alternate EventWaitNode from same graph.""" + event = sample_device.create_event() + return sample_graphdef.wait_event(event) + + +@pytest.fixture +def sample_host_callback_node(sample_graphdef): + """A HostCallbackNode.""" + + def my_callback(): + pass + + return sample_graphdef.callback(my_callback) + + +@pytest.fixture +def sample_host_callback_node_alt(sample_graphdef): + """An alternate HostCallbackNode from same graph.""" + + def other_callback(): + pass + + return sample_graphdef.callback(other_callback) + + +@pytest.fixture +def sample_condition(sample_graphdef): + """A Condition object.""" + return try_create_condition(sample_graphdef) + + +@pytest.fixture +def sample_condition_alt(sample_graphdef): + """An alternate Condition from same graph.""" + return try_create_condition(sample_graphdef) + + +@pytest.fixture +def sample_if_node(sample_graphdef): + """An IfNode.""" + condition = try_create_condition(sample_graphdef) + return sample_graphdef.if_cond(condition) + + +@pytest.fixture +def sample_if_node_alt(sample_graphdef): + """An alternate IfNode from same graph.""" + condition = try_create_condition(sample_graphdef) + return sample_graphdef.if_cond(condition) + + +@pytest.fixture +def sample_if_else_node(sample_graphdef): + """An IfElseNode.""" + condition = try_create_condition(sample_graphdef) + return sample_graphdef.if_else(condition) + + +@pytest.fixture +def sample_if_else_node_alt(sample_graphdef): + """An alternate IfElseNode from same graph.""" + condition = try_create_condition(sample_graphdef) + return 
sample_graphdef.if_else(condition) + + +@pytest.fixture +def sample_while_node(sample_graphdef): + """A WhileNode.""" + condition = try_create_condition(sample_graphdef) + return sample_graphdef.while_loop(condition) + + +@pytest.fixture +def sample_while_node_alt(sample_graphdef): + """An alternate WhileNode from same graph.""" + condition = try_create_condition(sample_graphdef) + return sample_graphdef.while_loop(condition) + + +@pytest.fixture +def sample_switch_node(sample_graphdef): + """A SwitchNode.""" + condition = try_create_condition(sample_graphdef) + return sample_graphdef.switch(condition, 3) + + +@pytest.fixture +def sample_switch_node_alt(sample_graphdef): + """An alternate SwitchNode from same graph.""" + condition = try_create_condition(sample_graphdef) + return sample_graphdef.switch(condition, 3) + + # ============================================================================= # Type groupings # ============================================================================= @@ -213,6 +475,23 @@ def sample_kernel_alt(sample_object_code_alt): "sample_launch_config", "sample_object_code_cubin", "sample_kernel", + "sample_graphdef", + "sample_condition", + "sample_root_node", + "sample_empty_node", + "sample_alloc_node", + "sample_kernel_node", + "sample_free_node", + "sample_memset_node", + "sample_memcpy_node", + "sample_child_graph_node", + "sample_event_record_node", + "sample_event_wait_node", + "sample_host_callback_node", + "sample_if_node", + "sample_if_else_node", + "sample_while_node", + "sample_switch_node", ] # Types with __eq__ support @@ -225,6 +504,23 @@ def sample_kernel_alt(sample_object_code_alt): "sample_launch_config", "sample_object_code_cubin", "sample_kernel", + "sample_graphdef", + "sample_condition", + "sample_root_node", + "sample_empty_node", + "sample_alloc_node", + "sample_kernel_node", + "sample_free_node", + "sample_memset_node", + "sample_memcpy_node", + "sample_child_graph_node", + "sample_event_record_node", + 
"sample_event_wait_node", + "sample_host_callback_node", + "sample_if_node", + "sample_if_else_node", + "sample_while_node", + "sample_switch_node", ] # Types with __weakref__ support @@ -233,11 +529,28 @@ def sample_kernel_alt(sample_object_code_alt): "sample_stream", "sample_event", "sample_context", + "sample_condition", "sample_buffer", "sample_launch_config", "sample_object_code_cubin", "sample_kernel", "sample_program_nvrtc", + "sample_graphdef", + "sample_root_node", + "sample_empty_node", + "sample_alloc_node", + "sample_kernel_node", + "sample_free_node", + "sample_memset_node", + "sample_memcpy_node", + "sample_child_graph_node", + "sample_event_record_node", + "sample_event_wait_node", + "sample_host_callback_node", + "sample_if_node", + "sample_if_else_node", + "sample_while_node", + "sample_switch_node", ] # Pairs of distinct objects of the same type (for inequality testing) @@ -251,6 +564,23 @@ def sample_kernel_alt(sample_object_code_alt): ("sample_launch_config", "sample_launch_config_alt"), ("sample_object_code_cubin", "sample_object_code_alt"), ("sample_kernel", "sample_kernel_alt"), + ("sample_graphdef", "sample_graphdef_alt"), + ("sample_condition", "sample_condition_alt"), + ("sample_root_node", "sample_root_node_alt"), + ("sample_empty_node", "sample_empty_node_alt"), + ("sample_alloc_node", "sample_alloc_node_alt"), + ("sample_kernel_node", "sample_kernel_node_alt"), + ("sample_free_node", "sample_free_node_alt"), + ("sample_memset_node", "sample_memset_node_alt"), + ("sample_memcpy_node", "sample_memcpy_node_alt"), + ("sample_child_graph_node", "sample_child_graph_node_alt"), + ("sample_event_record_node", "sample_event_record_node_alt"), + ("sample_event_wait_node", "sample_event_wait_node_alt"), + ("sample_host_callback_node", "sample_host_callback_node_alt"), + ("sample_if_node", "sample_if_node_alt"), + ("sample_if_else_node", "sample_if_else_node_alt"), + ("sample_while_node", "sample_while_node_alt"), + ("sample_switch_node", 
"sample_switch_node_alt"), ] # Types with public from_handle methods and how to create a copy @@ -286,6 +616,24 @@ def sample_kernel_alt(sample_object_code_alt): ("sample_program_nvrtc", r""), ("sample_program_ptx", r""), ("sample_program_nvvm", r""), + # Graph types + ("sample_graphdef", r""), + ("sample_condition", r""), + ("sample_root_node", r""), + ("sample_empty_node", r""), + ("sample_alloc_node", r""), + ("sample_kernel_node", r""), + ("sample_free_node", r""), + ("sample_memset_node", r""), + ("sample_memcpy_node", r""), + ("sample_child_graph_node", r""), + ("sample_event_record_node", r""), + ("sample_event_wait_node", r""), + ("sample_host_callback_node", r""), + ("sample_if_node", r""), + ("sample_if_else_node", r""), + ("sample_while_node", r""), + ("sample_switch_node", r""), ]