Skip to content
23 changes: 23 additions & 0 deletions cuda_core/cuda/core/_cpp/resource_handles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel = nullptr;

// Graph
decltype(&cuGraphDestroy) p_cuGraphDestroy = nullptr;
decltype(&cuGraphExecDestroy) p_cuGraphExecDestroy = nullptr;

// Linker
decltype(&cuLinkDestroy) p_cuLinkDestroy = nullptr;
Expand Down Expand Up @@ -952,6 +953,28 @@ GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent)
return GraphHandle(box, &box->resource);
}

// ============================================================================
// Graph Exec Handles
// ============================================================================

namespace {
struct GraphExecBox {
CUgraphExec resource;
};
} // namespace

GraphExecHandle create_graph_exec_handle(CUgraphExec graph_exec) {
auto box = std::shared_ptr<const GraphExecBox>(
new GraphExecBox{graph_exec},
[](const GraphExecBox* b) {
GILReleaseGuard gil;
p_cuGraphExecDestroy(b->resource);
delete b;
}
);
return GraphExecHandle(box, &box->resource);
}

namespace {
struct GraphNodeBox {
mutable CUgraphNode resource;
Expand Down
22 changes: 22 additions & 0 deletions cuda_core/cuda/core/_cpp/resource_handles.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ extern decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel;

// Graph
extern decltype(&cuGraphDestroy) p_cuGraphDestroy;
extern decltype(&cuGraphExecDestroy) p_cuGraphExecDestroy;

// Linker
extern decltype(&cuLinkDestroy) p_cuLinkDestroy;
Expand Down Expand Up @@ -148,6 +149,7 @@ using MemoryPoolHandle = std::shared_ptr<const CUmemoryPool>;
using LibraryHandle = std::shared_ptr<const CUlibrary>;
using KernelHandle = std::shared_ptr<const CUkernel>;
using GraphHandle = std::shared_ptr<const CUgraph>;
using GraphExecHandle = std::shared_ptr<const CUgraphExec>;
using GraphNodeHandle = std::shared_ptr<const CUgraphNode>;
using GraphicsResourceHandle = std::shared_ptr<const CUgraphicsResource>;
using NvrtcProgramHandle = std::shared_ptr<const nvrtcProgram>;
Expand Down Expand Up @@ -403,6 +405,14 @@ GraphHandle create_graph_handle(CUgraph graph);
// but h_parent will be prevented from destruction while this handle exists.
GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent);

// ============================================================================
// Graph exec handle functions
// ============================================================================

// Wrap an externally-created CUgraphExec with RAII cleanup.
// When the last reference is released, cuGraphExecDestroy is called automatically.
GraphExecHandle create_graph_exec_handle(CUgraphExec graph_exec);

// ============================================================================
// Graph node handle functions
// ============================================================================
Expand Down Expand Up @@ -529,6 +539,10 @@ inline CUgraph as_cu(const GraphHandle& h) noexcept {
return h ? *h : nullptr;
}

inline CUgraphExec as_cu(const GraphExecHandle& h) noexcept {
return h ? *h : nullptr;
}

inline CUgraphNode as_cu(const GraphNodeHandle& h) noexcept {
return h ? *h : nullptr;
}
Expand Down Expand Up @@ -587,6 +601,10 @@ inline std::intptr_t as_intptr(const GraphHandle& h) noexcept {
return reinterpret_cast<std::intptr_t>(as_cu(h));
}

inline std::intptr_t as_intptr(const GraphExecHandle& h) noexcept {
return reinterpret_cast<std::intptr_t>(as_cu(h));
}

inline std::intptr_t as_intptr(const GraphNodeHandle& h) noexcept {
return reinterpret_cast<std::intptr_t>(as_cu(h));
}
Expand Down Expand Up @@ -677,6 +695,10 @@ inline PyObject* as_py(const GraphHandle& h) noexcept {
return detail::make_py("cuda.bindings.driver", "CUgraph", as_intptr(h));
}

inline PyObject* as_py(const GraphExecHandle& h) noexcept {
return detail::make_py("cuda.bindings.driver", "CUgraphExec", as_intptr(h));
}

inline PyObject* as_py(const GraphNodeHandle& h) noexcept {
if (!as_intptr(h)) {
Py_RETURN_NONE;
Expand Down
4 changes: 2 additions & 2 deletions cuda_core/cuda/core/_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1361,7 +1361,7 @@ class Device:
self._check_context_initialized()
handle_return(runtime.cudaDeviceSynchronize())

def create_graph_builder(self) -> "GraphBuilder":
def create_graph_builder(self) -> GraphBuilder:
"""Create a new :obj:`~graph.GraphBuilder` object.

Returns
Expand All @@ -1373,7 +1373,7 @@ class Device:
from cuda.core.graph._graph_builder import GraphBuilder

self._check_context_initialized()
return GraphBuilder._init(stream=self.create_stream(), is_stream_owner=True)
return GraphBuilder._init(self.create_stream())


cdef inline int Device_ensure_cuda_initialized() except? -1:
Expand Down
7 changes: 7 additions & 0 deletions cuda_core/cuda/core/_resource_handles.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
ctypedef shared_ptr[const cydriver.CUlibrary] LibraryHandle
ctypedef shared_ptr[const cydriver.CUkernel] KernelHandle
ctypedef shared_ptr[const cydriver.CUgraph] GraphHandle
ctypedef shared_ptr[const cydriver.CUgraphExec] GraphExecHandle
ctypedef shared_ptr[const cydriver.CUgraphNode] GraphNodeHandle
ctypedef shared_ptr[const cydriver.CUgraphicsResource] GraphicsResourceHandle
ctypedef shared_ptr[const cynvrtc.nvrtcProgram] NvrtcProgramHandle
Expand All @@ -52,6 +53,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
cydriver.CUlibrary as_cu(LibraryHandle h) noexcept nogil
cydriver.CUkernel as_cu(KernelHandle h) noexcept nogil
cydriver.CUgraph as_cu(GraphHandle h) noexcept nogil
cydriver.CUgraphExec as_cu(GraphExecHandle h) noexcept nogil
cydriver.CUgraphNode as_cu(GraphNodeHandle h) noexcept nogil
cydriver.CUgraphicsResource as_cu(GraphicsResourceHandle h) noexcept nogil
cynvrtc.nvrtcProgram as_cu(NvrtcProgramHandle h) noexcept nogil
Expand All @@ -68,6 +70,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
intptr_t as_intptr(LibraryHandle h) noexcept nogil
intptr_t as_intptr(KernelHandle h) noexcept nogil
intptr_t as_intptr(GraphHandle h) noexcept nogil
intptr_t as_intptr(GraphExecHandle h) noexcept nogil
intptr_t as_intptr(GraphNodeHandle h) noexcept nogil
intptr_t as_intptr(GraphicsResourceHandle h) noexcept nogil
intptr_t as_intptr(NvrtcProgramHandle h) noexcept nogil
Expand All @@ -85,6 +88,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
object as_py(LibraryHandle h)
object as_py(KernelHandle h)
object as_py(GraphHandle h)
object as_py(GraphExecHandle h)
object as_py(GraphNodeHandle h)
object as_py(GraphicsResourceHandle h)
object as_py(NvrtcProgramHandle h)
Expand Down Expand Up @@ -183,6 +187,9 @@ cdef LibraryHandle get_kernel_library(const KernelHandle& h) noexcept nogil
cdef GraphHandle create_graph_handle(cydriver.CUgraph graph) except+ nogil
cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const GraphHandle& h_parent) except+ nogil

# Graph exec handles
cdef GraphExecHandle create_graph_exec_handle(cydriver.CUgraphExec graph_exec) except+ nogil

# Graph node handles
cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil
cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil
Expand Down
7 changes: 7 additions & 0 deletions cuda_core/cuda/core/_resource_handles.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ from ._resource_handles cimport (
LibraryHandle,
KernelHandle,
GraphHandle,
GraphExecHandle,
GraphicsResourceHandle,
NvrtcProgramHandle,
NvvmProgramHandle,
Expand Down Expand Up @@ -154,6 +155,10 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
GraphHandle create_graph_handle_ref "cuda_core::create_graph_handle_ref" (
cydriver.CUgraph graph, const GraphHandle& h_parent) except+ nogil

# Graph exec handles
GraphExecHandle create_graph_exec_handle "cuda_core::create_graph_exec_handle" (
cydriver.CUgraphExec graph_exec) except+ nogil

# Graph node handles
GraphNodeHandle create_graph_node_handle "cuda_core::create_graph_node_handle" (
cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil
Expand Down Expand Up @@ -265,6 +270,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":

# Graph
void* p_cuGraphDestroy "reinterpret_cast<void*&>(cuda_core::p_cuGraphDestroy)"
void* p_cuGraphExecDestroy "reinterpret_cast<void*&>(cuda_core::p_cuGraphExecDestroy)"

# Linker
void* p_cuLinkDestroy "reinterpret_cast<void*&>(cuda_core::p_cuLinkDestroy)"
Expand Down Expand Up @@ -334,6 +340,7 @@ p_cuLibraryGetKernel = _get_driver_fn("cuLibraryGetKernel")

# Graph
p_cuGraphDestroy = _get_driver_fn("cuGraphDestroy")
p_cuGraphExecDestroy = _get_driver_fn("cuGraphExecDestroy")

# Linker
p_cuLinkDestroy = _get_driver_fn("cuLinkDestroy")
Expand Down
4 changes: 2 additions & 2 deletions cuda_core/cuda/core/_stream.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ cdef class Stream:

return Stream._init(obj=_stream_holder())

def create_graph_builder(self) -> "GraphBuilder":
def create_graph_builder(self) -> GraphBuilder:
"""Create a new :obj:`~graph.GraphBuilder` object.

The new graph builder will be associated with this stream.
Expand All @@ -373,7 +373,7 @@ cdef class Stream:
"""
from cuda.core.graph._graph_builder import GraphBuilder

return GraphBuilder._init(stream=self, is_stream_owner=False)
return GraphBuilder._init(self)


# c-only python objects, not public
Expand Down
27 changes: 27 additions & 0 deletions cuda_core/cuda/core/graph/_graph_builder.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from cuda.bindings cimport cydriver

from cuda.core._resource_handles cimport GraphExecHandle, GraphHandle, StreamHandle
from cuda.core._stream cimport Stream


cdef class GraphBuilder:
cdef:
GraphHandle _h_graph
StreamHandle _h_stream
int _kind
int _state
Stream _stream # cached to avoid reconstruction from _h_stream handle
object __weakref__


cdef class Graph:
cdef:
GraphExecHandle _h_graph_exec
object __weakref__

@staticmethod
cdef Graph _init(cydriver.CUgraphExec graph_exec)
Loading
Loading