Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cuda_core/cuda/core/_cpp/resource_handles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel = nullptr;

// Graph
decltype(&cuGraphDestroy) p_cuGraphDestroy = nullptr;
decltype(&cuGraphExecDestroy) p_cuGraphExecDestroy = nullptr;

// Linker
decltype(&cuLinkDestroy) p_cuLinkDestroy = nullptr;
Expand Down Expand Up @@ -1126,6 +1127,28 @@ GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent)
return GraphHandle(box, &box->resource);
}

// ============================================================================
// Graph Exec Handles
// ============================================================================

namespace {
struct GraphExecBox {
CUgraphExec resource;
};
} // namespace
Comment on lines +1134 to +1138

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Design question, not a blocker: this is the only one of the graph-family boxes that holds no extra state — GraphBox tracks h_parent, GraphNodeBox tracks h_graph, but GraphExecBox has just the raw CUgraphExec. That's semantically correct today: after cuGraphInstantiateWithParams, the exec is independent of its source CUgraph, and cuGraphExecUpdate takes the new source at call time, so nothing needs to be held.

But the PR description calls this "groundwork for step 3 of #1330 (graph updates)". If that work ends up wanting to remember which GraphBuilder / CUgraph an exec was last updated from (for diagnostics, error messages, or to extend the source's lifetime through the exec), this Box is the natural place to grow a GraphHandle h_source_graph;. Worth either a one-line // TODO: noting that, or a confirmation that no such reference is planned and removing the asymmetry from any reviewer's mental model.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed: no source-graph ref on GraphExecBox is planned. The exec is independent post-instantiate and Graph.update() supplies the source at call time. Clarification: step 3 lays the foundation for source graph updates; exec graph updates are step 5.


GraphExecHandle create_graph_exec_handle(CUgraphExec graph_exec) {
auto box = std::shared_ptr<const GraphExecBox>(
new GraphExecBox{graph_exec},
[](const GraphExecBox* b) {
GILReleaseGuard gil;
p_cuGraphExecDestroy(b->resource);
delete b;
}
);
return GraphExecHandle(box, &box->resource);
}

namespace {
struct GraphNodeBox {
mutable CUgraphNode resource;
Expand Down
22 changes: 22 additions & 0 deletions cuda_core/cuda/core/_cpp/resource_handles.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ extern decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel;

// Graph
extern decltype(&cuGraphDestroy) p_cuGraphDestroy;
extern decltype(&cuGraphExecDestroy) p_cuGraphExecDestroy;

// Linker
extern decltype(&cuLinkDestroy) p_cuLinkDestroy;
Expand Down Expand Up @@ -164,6 +165,7 @@ using MemoryPoolHandle = std::shared_ptr<const CUmemoryPool>;
using LibraryHandle = std::shared_ptr<const CUlibrary>;
using KernelHandle = std::shared_ptr<const CUkernel>;
using GraphHandle = std::shared_ptr<const CUgraph>;
using GraphExecHandle = std::shared_ptr<const CUgraphExec>;
using GraphNodeHandle = std::shared_ptr<const CUgraphNode>;
using GraphicsResourceHandle = std::shared_ptr<const CUgraphicsResource>;
using NvrtcProgramHandle = std::shared_ptr<const nvrtcProgram>;
Expand Down Expand Up @@ -441,6 +443,14 @@ GraphHandle create_graph_handle(CUgraph graph);
// but h_parent will be prevented from destruction while this handle exists.
GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent);

// ============================================================================
// Graph exec handle functions
// ============================================================================

// Wrap an externally-created CUgraphExec with RAII cleanup.
// When the last reference is released, cuGraphExecDestroy is called automatically.
GraphExecHandle create_graph_exec_handle(CUgraphExec graph_exec);

// ============================================================================
// Graph node handle functions
// ============================================================================
Expand Down Expand Up @@ -571,6 +581,10 @@ inline CUgraph as_cu(const GraphHandle& h) noexcept {
return h ? *h : nullptr;
}

inline CUgraphExec as_cu(const GraphExecHandle& h) noexcept {
return h ? *h : nullptr;
}

inline CUgraphNode as_cu(const GraphNodeHandle& h) noexcept {
return h ? *h : nullptr;
}
Expand Down Expand Up @@ -633,6 +647,10 @@ inline std::intptr_t as_intptr(const GraphHandle& h) noexcept {
return reinterpret_cast<std::intptr_t>(as_cu(h));
}

inline std::intptr_t as_intptr(const GraphExecHandle& h) noexcept {
return reinterpret_cast<std::intptr_t>(as_cu(h));
}

inline std::intptr_t as_intptr(const GraphNodeHandle& h) noexcept {
return reinterpret_cast<std::intptr_t>(as_cu(h));
}
Expand Down Expand Up @@ -743,6 +761,10 @@ inline PyObject* as_py(const GraphHandle& h) noexcept {
return detail::make_py("cuda.bindings.driver", "CUgraph", as_intptr(h));
}

inline PyObject* as_py(const GraphExecHandle& h) noexcept {
return detail::make_py("cuda.bindings.driver", "CUgraphExec", as_intptr(h));
}

inline PyObject* as_py(const GraphNodeHandle& h) noexcept {
if (!as_intptr(h)) {
Py_RETURN_NONE;
Expand Down
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1454,7 +1454,7 @@ class Device:
from cuda.core.graph._graph_builder import GraphBuilder

self._check_context_initialized()
return GraphBuilder._init(stream=self.create_stream(), is_stream_owner=True)
return GraphBuilder._init(self.create_stream())


cdef inline int Device_ensure_cuda_initialized() except? -1:
Expand Down
7 changes: 7 additions & 0 deletions cuda_core/cuda/core/_resource_handles.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
ctypedef shared_ptr[const cydriver.CUlibrary] LibraryHandle
ctypedef shared_ptr[const cydriver.CUkernel] KernelHandle
ctypedef shared_ptr[const cydriver.CUgraph] GraphHandle
ctypedef shared_ptr[const cydriver.CUgraphExec] GraphExecHandle
ctypedef shared_ptr[const cydriver.CUgraphNode] GraphNodeHandle
ctypedef shared_ptr[const cydriver.CUgraphicsResource] GraphicsResourceHandle
ctypedef shared_ptr[const cynvrtc.nvrtcProgram] NvrtcProgramHandle
Expand All @@ -54,6 +55,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
cydriver.CUlibrary as_cu(LibraryHandle h) noexcept nogil
cydriver.CUkernel as_cu(KernelHandle h) noexcept nogil
cydriver.CUgraph as_cu(GraphHandle h) noexcept nogil
cydriver.CUgraphExec as_cu(GraphExecHandle h) noexcept nogil
cydriver.CUgraphNode as_cu(GraphNodeHandle h) noexcept nogil
cydriver.CUgraphicsResource as_cu(GraphicsResourceHandle h) noexcept nogil
cynvrtc.nvrtcProgram as_cu(NvrtcProgramHandle h) noexcept nogil
Expand All @@ -71,6 +73,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
intptr_t as_intptr(LibraryHandle h) noexcept nogil
intptr_t as_intptr(KernelHandle h) noexcept nogil
intptr_t as_intptr(GraphHandle h) noexcept nogil
intptr_t as_intptr(GraphExecHandle h) noexcept nogil
intptr_t as_intptr(GraphNodeHandle h) noexcept nogil
intptr_t as_intptr(GraphicsResourceHandle h) noexcept nogil
intptr_t as_intptr(NvrtcProgramHandle h) noexcept nogil
Expand All @@ -89,6 +92,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
object as_py(LibraryHandle h)
object as_py(KernelHandle h)
object as_py(GraphHandle h)
object as_py(GraphExecHandle h)
object as_py(GraphNodeHandle h)
object as_py(GraphicsResourceHandle h)
object as_py(NvrtcProgramHandle h)
Expand Down Expand Up @@ -195,6 +199,9 @@ cdef LibraryHandle get_kernel_library(const KernelHandle& h) noexcept nogil
cdef GraphHandle create_graph_handle(cydriver.CUgraph graph) except+ nogil
cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const GraphHandle& h_parent) except+ nogil

# Graph exec handles
cdef GraphExecHandle create_graph_exec_handle(cydriver.CUgraphExec graph_exec) except+ nogil

# Graph node handles
cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil
cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil
Expand Down
1 change: 1 addition & 0 deletions cuda_core/cuda/core/_resource_handles.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ DevicePtrHandle = shared_ptr
LibraryHandle = shared_ptr
KernelHandle = shared_ptr
GraphHandle = shared_ptr
GraphExecHandle = shared_ptr
GraphNodeHandle = shared_ptr
GraphicsResourceHandle = shared_ptr
NvrtcProgramHandle = shared_ptr
Expand Down
8 changes: 7 additions & 1 deletion cuda_core/cuda/core/_resource_handles.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
GraphHandle create_graph_handle_ref "cuda_core::create_graph_handle_ref" (
cydriver.CUgraph graph, const GraphHandle& h_parent) except+ nogil

# Graph exec handles
GraphExecHandle create_graph_exec_handle "cuda_core::create_graph_exec_handle" (
cydriver.CUgraphExec graph_exec) except+ nogil

# Graph node handles
GraphNodeHandle create_graph_node_handle "cuda_core::create_graph_node_handle" (
cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil
Expand Down Expand Up @@ -276,6 +280,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":

# Graph
void* p_cuGraphDestroy "reinterpret_cast<void*&>(cuda_core::p_cuGraphDestroy)"
void* p_cuGraphExecDestroy "reinterpret_cast<void*&>(cuda_core::p_cuGraphExecDestroy)"

# Linker
void* p_cuLinkDestroy "reinterpret_cast<void*&>(cuda_core::p_cuLinkDestroy)"
Expand Down Expand Up @@ -324,7 +329,7 @@ cdef void _init_driver_fn_pointers() noexcept:
global p_cuMemFreeAsync, p_cuMemFree, p_cuMemFreeHost
global p_cuMemPoolImportPointer
global p_cuLibraryLoadFromFile, p_cuLibraryLoadData, p_cuLibraryUnload, p_cuLibraryGetKernel
global p_cuGraphDestroy
global p_cuGraphDestroy, p_cuGraphExecDestroy
global p_cuLinkDestroy
global p_cuGraphicsUnmapResources, p_cuGraphicsUnregisterResource
global p_cuDevSmResourceSplit
Expand Down Expand Up @@ -380,6 +385,7 @@ cdef void _init_driver_fn_pointers() noexcept:

# Graph
p_cuGraphDestroy = _get_driver_fn("cuGraphDestroy")
p_cuGraphExecDestroy = _get_driver_fn("cuGraphExecDestroy")

# Linker
p_cuLinkDestroy = _get_driver_fn("cuLinkDestroy")
Expand Down
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_stream.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ cdef class Stream:
"""
from cuda.core.graph._graph_builder import GraphBuilder

return GraphBuilder._init(stream=self, is_stream_owner=False)
return GraphBuilder._init(self)


LEGACY_DEFAULT_STREAM: Stream = Stream._legacy_default()
Expand Down
27 changes: 27 additions & 0 deletions cuda_core/cuda/core/graph/_graph_builder.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from cuda.bindings cimport cydriver

from cuda.core._resource_handles cimport GraphExecHandle, GraphHandle, StreamHandle
from cuda.core._stream cimport Stream


cdef class GraphBuilder:
cdef:
GraphHandle _h_graph
StreamHandle _h_stream
int _kind
int _state
Stream _stream # cached to avoid reconstruction from _h_stream handle
object __weakref__


cdef class Graph:
cdef:
GraphExecHandle _h_graph_exec
object __weakref__

@staticmethod
cdef Graph _init(cydriver.CUgraphExec graph_exec)
55 changes: 14 additions & 41 deletions cuda_core/cuda/core/graph/_graph_builder.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ from cuda.core._stream import Stream
from cuda.core._utils.cuda_utils import driver
from cuda.core.graph._graph_definition import GraphCondition, GraphDefinition

_BuilderKind = int
_CaptureState = int

@dataclass
class GraphDebugPrintOptions:
Expand Down Expand Up @@ -106,23 +108,19 @@ class GraphBuilder:

"""

class _MembersNeededForFinalize:
__slots__ = ('conditional_graph', 'graph', 'is_join_required', 'is_stream_owner', 'stream')

def __init__(self, graph_builder_obj: GraphBuilder, stream_obj: Stream | None, is_stream_owner: bool, conditional_graph, is_join_required: bool) -> None:
...

def close(self) -> None:
...
__slots__ = ('__weakref__', '_building_ended', '_mnff')
def __init__(self):
...

def __init__(self) -> None:
def __dealloc__(self):
...

@classmethod
def _init(cls, stream: Stream | None, is_stream_owner: bool, conditional_graph: object=None, is_join_required: bool=False) -> GraphBuilder:
@staticmethod
def _init(stream: Stream):
...

def close(self):
"""Destroy the graph builder."""

@property
def stream(self) -> Stream:
"""Returns the stream associated with the graph builder."""
Expand Down Expand Up @@ -155,7 +153,7 @@ class GraphBuilder:
def end_building(self) -> GraphBuilder:
"""Ends the building process."""

def complete(self, options: GraphCompleteOptions | None=None) -> 'Graph':
def complete(self, options: GraphCompleteOptions | None=None) -> Graph:
"""Completes the graph builder and returns the built :obj:`~graph.Graph` object.

Parameters
Expand Down Expand Up @@ -245,9 +243,6 @@ class GraphBuilder:
A condition variable for controlling conditional execution.
"""

def _cond_with_params(self, node_params: object) -> tuple[GraphBuilder, ...]:
...

def if_then(self, condition: GraphCondition) -> GraphBuilder:
"""Adds an if condition branch and returns a new graph builder for it.

Expand Down Expand Up @@ -335,15 +330,7 @@ class GraphBuilder:

"""

def close(self) -> None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive-by comment to test the theory that the .pyi files help us see API changes.

I think dropping close here is API breakage -- if a user is calling it directly, they will no longer be able to. If that isn't resolvable to make everything in this PR work, then we need a "breaking changes" changelog entry and a deprecation decorator etc.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I stand corrected, I see this just got moved a few lines above.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good. I was worried for a moment!

"""Destroy the graph builder.

Closes the associated stream if we own it. Borrowed stream
object will instead have their references released.

"""

def embed(self, child: GraphBuilder) -> None:
def embed(self, child: GraphBuilder):
"""Embed a previously-built :obj:`~graph.GraphBuilder` as a child node.

Parameters
Expand Down Expand Up @@ -392,21 +379,7 @@ class Graph:

"""

class _MembersNeededForFinalize:
__slots__ = 'graph'

def __init__(self, graph_obj: Graph, graph: driver.CUgraphExec) -> None:
...

def close(self) -> None:
...
__slots__ = ('__weakref__', '_mnff')

def __init__(self) -> None:
...

@classmethod
def _init(cls, graph: driver.CUgraphExec) -> Graph:
def __init__(self):
...

def close(self) -> None:
Expand Down Expand Up @@ -457,5 +430,5 @@ class Graph:
"""
__all__ = ['Graph', 'GraphBuilder', 'GraphCompleteOptions', 'GraphDebugPrintOptions']

def _instantiate_graph(h_graph, options: GraphCompleteOptions | None=None) -> 'Graph':
def _instantiate_graph(h_graph, options: GraphCompleteOptions | None=None) -> Graph:
...
Loading
Loading