diff --git a/cuda_core/cuda/core/_layout.pyx b/cuda_core/cuda/core/_layout.pyx index 796a6243fd4..3e2580d11d1 100644 --- a/cuda_core/cuda/core/_layout.pyx +++ b/cuda_core/cuda/core/_layout.pyx @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 diff --git a/cuda_core/cuda/core/_memory/_device_memory_resource.pxd b/cuda_core/cuda/core/_memory/_device_memory_resource.pxd index a7f3bfd9585..0b7cd941e18 100644 --- a/cuda_core/cuda/core/_memory/_device_memory_resource.pxd +++ b/cuda_core/cuda/core/_memory/_device_memory_resource.pxd @@ -9,7 +9,6 @@ from cuda.core._memory._ipc cimport IPCDataForMR cdef class DeviceMemoryResource(_MemPool): cdef: int _dev_id - object _peer_accessible_by cpdef DMR_mempool_get_access(DeviceMemoryResource, int) diff --git a/cuda_core/cuda/core/_memory/_device_memory_resource.pyx b/cuda_core/cuda/core/_memory/_device_memory_resource.pyx index c19d7358b00..b7b8b247a92 100644 --- a/cuda_core/cuda/core/_memory/_device_memory_resource.pyx +++ b/cuda_core/cuda/core/_memory/_device_memory_resource.pyx @@ -18,14 +18,12 @@ from cuda.core._utils.cuda_utils cimport ( check_or_create_options, HANDLE_RETURN, ) -from cpython.mem cimport PyMem_Malloc, PyMem_Free - from dataclasses import dataclass import multiprocessing import platform # no-cython-lint import uuid -from cuda.core._memory._peer_access_utils import plan_peer_access_update +from cuda.core._memory._peer_access_utils import PeerAccessibleBySetProxy, replace_peer_accessible_by from cuda.core._utils.cuda_utils import check_multiprocessing_start_method __all__ = ['DeviceMemoryResource', 'DeviceMemoryResourceOptions'] @@ -131,7 +129,6 @@ cdef class DeviceMemoryResource(_MemPool): def __cinit__(self, *args, **kwargs): self._dev_id = cydriver.CU_DEVICE_INVALID - self._peer_accessible_by = None 
def __init__(self, device_id: Device | int, options=None): _DMR_init(self, device_id, options) @@ -191,7 +188,6 @@ cdef class DeviceMemoryResource(_MemPool): _ipc.MP_from_allocation_handle(cls, alloc_handle)) from .._device import Device mr._dev_id = Device(device_id).device_id - mr._peer_accessible_by = () return mr @property @@ -217,30 +213,23 @@ cdef class DeviceMemoryResource(_MemPool): pool. Access can be modified at any time and affects all allocations from this memory pool. - Returns a tuple of sorted device IDs that currently have peer access to - allocations from this memory pool. - - When setting, accepts a sequence of :obj:`~_device.Device` objects or device IDs. - Setting to an empty sequence revokes all peer access. - - For non-owned pools (the default or current device pool), the state - is always queried from the driver to reflect changes made by other - wrappers or direct driver calls. + Returns a set-like proxy of :obj:`~_device.Device` objects that manages + peer access. Inputs are accepted as either :obj:`~_device.Device` + objects or device-ordinal :class:`int` values. 
Examples -------- >>> dmr = DeviceMemoryResource(0) - >>> dmr.peer_accessible_by = [1] # Grant access to device 1 - >>> assert dmr.peer_accessible_by == (1,) - >>> dmr.peer_accessible_by = [] # Revoke access + >>> dmr.peer_accessible_by = {1} # grant access to device 1 + >>> assert 1 in dmr.peer_accessible_by + >>> dmr.peer_accessible_by.add(2) # update access to include device 2 + >>> dmr.peer_accessible_by = [] # revoke peer access """ - if not self._mempool_owned: - _DMR_query_peer_access(self) - return self._peer_accessible_by + return PeerAccessibleBySetProxy(self) @peer_accessible_by.setter def peer_accessible_by(self, devices): - _DMR_set_peer_accessible_by(self, devices) + replace_peer_accessible_by(self, devices) @property def is_device_accessible(self) -> bool: @@ -253,81 +242,6 @@ cdef class DeviceMemoryResource(_MemPool): return False -cdef inline _DMR_query_peer_access(DeviceMemoryResource self): - """Query the driver for the actual peer access state of this pool.""" - cdef int total - cdef cydriver.CUmemAccess_flags flags - cdef cydriver.CUmemLocation location - cdef list peers = [] - - with nogil: - HANDLE_RETURN(cydriver.cuDeviceGetCount(&total)) - - location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE - for dev_id in range(total): - if dev_id == self._dev_id: - continue - location.id = dev_id - with nogil: - HANDLE_RETURN(cydriver.cuMemPoolGetAccess(&flags, as_cu(self._h_pool), &location)) - if flags == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE: - peers.append(dev_id) - - self._peer_accessible_by = tuple(sorted(peers)) - - -cdef inline _DMR_set_peer_accessible_by(DeviceMemoryResource self, devices): - from .._device import Device - - this_dev = Device(self._dev_id) - cdef object resolve_device_id = lambda dev: Device(dev).device_id - cdef object plan - cdef tuple target_ids - cdef tuple to_add - cdef tuple to_rm - if not self._mempool_owned: - _DMR_query_peer_access(self) - plan = plan_peer_access_update( - 
owner_device_id=self._dev_id, - current_peer_ids=self._peer_accessible_by, - requested_devices=devices, - resolve_device_id=resolve_device_id, - can_access_peer=this_dev.can_access_peer, - ) - target_ids = plan.target_ids - to_add = plan.to_add - to_rm = plan.to_remove - cdef size_t count = len(to_add) + len(to_rm) - cdef cydriver.CUmemAccessDesc* access_desc = NULL - cdef size_t i = 0 - - if count > 0: - access_desc = PyMem_Malloc(count * sizeof(cydriver.CUmemAccessDesc)) - if access_desc == NULL: - raise MemoryError("Failed to allocate memory for access descriptors") - - try: - for dev_id in to_add: - access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE - access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE - access_desc[i].location.id = dev_id - i += 1 - - for dev_id in to_rm: - access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_NONE - access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE - access_desc[i].location.id = dev_id - i += 1 - - with nogil: - HANDLE_RETURN(cydriver.cuMemPoolSetAccess(as_cu(self._h_pool), access_desc, count)) - finally: - if access_desc != NULL: - PyMem_Free(access_desc) - - self._peer_accessible_by = tuple(target_ids) - - cdef inline _DMR_init(DeviceMemoryResource self, device_id, options): from .._device import Device cdef int dev_id = Device(device_id).device_id @@ -351,7 +265,6 @@ cdef inline _DMR_init(DeviceMemoryResource self, device_id, options): self._mempool_owned = False MP_raise_release_threshold(self) else: - self._peer_accessible_by = () MP_init_create_pool( self, cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE, diff --git a/cuda_core/cuda/core/_memory/_ipc.pyx b/cuda_core/cuda/core/_memory/_ipc.pyx index 1c7b25c14fb..59414fc1b2e 100644 --- a/cuda_core/cuda/core/_memory/_ipc.pyx +++ b/cuda_core/cuda/core/_memory/_ipc.pyx @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 diff --git a/cuda_core/cuda/core/_memory/_peer_access_utils.py b/cuda_core/cuda/core/_memory/_peer_access_utils.py deleted file mode 100644 index e08de69f2c7..00000000000 --- a/cuda_core/cuda/core/_memory/_peer_access_utils.py +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections.abc import Callable, Iterable -from dataclasses import dataclass - - -@dataclass(frozen=True) -class PeerAccessPlan: - """Normalized peer-access target state and the driver updates it requires.""" - - target_ids: tuple[int, ...] - to_add: tuple[int, ...] - to_remove: tuple[int, ...] - - -def normalize_peer_access_targets( - owner_device_id: int, - requested_devices: Iterable[object], - *, - resolve_device_id: Callable[[object], int], -) -> tuple[int, ...]: - """Return sorted, unique peer device IDs, excluding the owner device.""" - - target_ids = {resolve_device_id(device) for device in requested_devices} - target_ids.discard(owner_device_id) - return tuple(sorted(target_ids)) - - -def plan_peer_access_update( - owner_device_id: int, - current_peer_ids: Iterable[int], - requested_devices: Iterable[object], - *, - resolve_device_id: Callable[[object], int], - can_access_peer: Callable[[int], bool], -) -> PeerAccessPlan: - """Compute the peer-access target state and add/remove deltas.""" - - target_ids = normalize_peer_access_targets( - owner_device_id, - requested_devices, - resolve_device_id=resolve_device_id, - ) - bad = tuple(dev_id for dev_id in target_ids if not can_access_peer(dev_id)) - if bad: - bad_ids = ", ".join(str(dev_id) for dev_id in bad) - raise ValueError(f"Device {owner_device_id} cannot access peer(s): 
{bad_ids}") - - current_ids = set(current_peer_ids) - target_id_set = set(target_ids) - return PeerAccessPlan( - target_ids=target_ids, - to_add=tuple(sorted(target_id_set - current_ids)), - to_remove=tuple(sorted(current_ids - target_id_set)), - ) diff --git a/cuda_core/cuda/core/_memory/_peer_access_utils.pyx b/cuda_core/cuda/core/_memory/_peer_access_utils.pyx new file mode 100644 index 00000000000..8086aaff170 --- /dev/null +++ b/cuda_core/cuda/core/_memory/_peer_access_utils.pyx @@ -0,0 +1,406 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Callable, Iterable, MutableSet +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from cuda.bindings cimport cydriver +from cuda.core._memory._device_memory_resource cimport DeviceMemoryResource +from cuda.core._resource_handles cimport as_cu +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN +from cpython.mem cimport PyMem_Malloc, PyMem_Free +from libcpp.vector cimport vector + +if TYPE_CHECKING: + from cuda.core._device import Device + + +@dataclass(frozen=True) +class PeerAccessPlan: + """Normalized peer-access target state and the driver updates it requires.""" + + target_ids: tuple[int, ...] + to_add: tuple[int, ...] + to_remove: tuple[int, ...] 
+ + +def normalize_peer_access_targets( + owner_device_id: int, + requested_devices: Iterable[object], + *, + resolve_device_id: Callable[[object], int], +) -> tuple[int, ...]: + """Return sorted, unique peer device IDs, excluding the owner device.""" + + target_ids = {resolve_device_id(device) for device in requested_devices} + target_ids.discard(owner_device_id) + return tuple(sorted(target_ids)) + + +def plan_peer_access_update( + owner_device_id: int, + current_peer_ids: Iterable[int], + requested_devices: Iterable[object], + *, + resolve_device_id: Callable[[object], int], + can_access_peer: Callable[[int], bool], +) -> PeerAccessPlan: + """Compute the peer-access target state and add/remove deltas.""" + + target_ids = normalize_peer_access_targets( + owner_device_id, + requested_devices, + resolve_device_id=resolve_device_id, + ) + bad = tuple(dev_id for dev_id in target_ids if not can_access_peer(dev_id)) + if bad: + bad_ids = ", ".join(str(dev_id) for dev_id in bad) + raise ValueError(f"Device {owner_device_id} cannot access peer(s): {bad_ids}") + + current_ids = set(current_peer_ids) + target_id_set = set(target_ids) + return PeerAccessPlan( + target_ids=target_ids, + to_add=tuple(sorted(target_id_set - current_ids)), + to_remove=tuple(sorted(current_ids - target_id_set)), + ) + + +def _resolve_peer_device_id(value): + """Coerce ``Device | int`` into a device-ordinal int.""" + from cuda.core._device import Device + + return Device(value).device_id + + +# ---- driver-touching helpers (cdef inline, called from .pyx code) ----------- + +cdef inline tuple _query_peer_access_ids(DeviceMemoryResource mr): + """Return the current peer device IDs as a sorted tuple of ints. + + The full driver loop runs inside a single ``nogil`` block. Because + ``range(total)`` ascends, the result is already sorted. 
+ """ + cdef int total + cdef int dev_id + cdef int owner_id = mr._dev_id + cdef cydriver.CUmemAccess_flags flags + cdef cydriver.CUmemLocation location + cdef cydriver.CUmemoryPool h_pool = as_cu(mr._h_pool) + cdef vector[int] peers + cdef size_t i, n + + location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + + with nogil: + HANDLE_RETURN(cydriver.cuDeviceGetCount(&total)) + for dev_id in range(total): + if dev_id == owner_id: + continue + location.id = dev_id + HANDLE_RETURN(cydriver.cuMemPoolGetAccess(&flags, h_pool, &location)) + if flags == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE: + peers.push_back(dev_id) + + n = peers.size() + return tuple(peers[i] for i in range(n)) + + +cdef inline bint _peer_access_includes(DeviceMemoryResource mr, int dev_id): + """Return True if peer access from ``dev_id`` is currently granted.""" + cdef cydriver.CUmemAccess_flags flags + cdef cydriver.CUmemLocation location + + location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + location.id = dev_id + with nogil: + HANDLE_RETURN(cydriver.cuMemPoolGetAccess(&flags, as_cu(mr._h_pool), &location)) + return flags == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE + + +def _set_pool_access(mr, tuple to_add, tuple to_remove): + """Issue one ``cuMemPoolSetAccess`` for the given add/remove deltas. + + The thin Python-callable layer that wraps the actual driver call: building + the ``CUmemAccessDesc`` array and invoking ``cuMemPoolSetAccess`` happens + in here. Tests monkeypatch this on the module to spy on real driver work + without intercepting earlier no-op paths. + + Preconditions: ``len(to_add) + len(to_remove) > 0`` (the caller is + responsible for skipping empty diffs). 
+ """ + cdef DeviceMemoryResource mr_typed = mr + cdef size_t count = len(to_add) + len(to_remove) + cdef cydriver.CUmemAccessDesc* access_desc = NULL + cdef size_t i = 0 + + access_desc = PyMem_Malloc(count * sizeof(cydriver.CUmemAccessDesc)) + if access_desc == NULL: + raise MemoryError("Failed to allocate memory for access descriptors") + + try: + for dev_id in to_add: + access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE + access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + access_desc[i].location.id = dev_id + i += 1 + for dev_id in to_remove: + access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_NONE + access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + access_desc[i].location.id = dev_id + i += 1 + + with nogil: + HANDLE_RETURN(cydriver.cuMemPoolSetAccess(as_cu(mr_typed._h_pool), access_desc, count)) + finally: + if access_desc != NULL: + PyMem_Free(access_desc) + + +def _apply_peer_access_diff(mr, to_add, to_remove): + """Apply a peer-access diff in at most one driver call. + + Every write path on :class:`PeerAccessibleBySetProxy` and the + ``peer_accessible_by`` setter routes through this function. Empty diffs + short-circuit here so the driver-level helper :func:`_set_pool_access` is + only invoked when there is actual work for ``cuMemPoolSetAccess`` to do. + """ + add_tuple = tuple(to_add) + remove_tuple = tuple(to_remove) + if not add_tuple and not remove_tuple: + return + _set_pool_access(mr, add_tuple, remove_tuple) + + +cpdef replace_peer_accessible_by(DeviceMemoryResource mr, devices): + """Replace the full peer-access set in a single batched driver call. + + Backs the ``mr.peer_accessible_by = [...]`` setter. Uses the same planner + as the proxy's bulk ops; the only difference is that adds and removes are + derived from the symmetric difference between current driver state and the + requested target set. 
+ """ + from cuda.core._device import Device + + this_dev = Device(mr._dev_id) + plan = plan_peer_access_update( + owner_device_id=mr._dev_id, + current_peer_ids=_query_peer_access_ids(mr), + requested_devices=devices, + resolve_device_id=_resolve_peer_device_id, + can_access_peer=this_dev.can_access_peer, + ) + _apply_peer_access_diff(mr, plan.to_add, plan.to_remove) + + +# ---- Python MutableSet proxy ------------------------------------------------ + +class PeerAccessibleBySetProxy(MutableSet): + """Live driver-backed view of the peer devices granted access to a memory pool. + + Reads (``__contains__``, ``__iter__``, ``len(...)``) call ``cuMemPoolGetAccess``; + writes (``add``, ``discard``, and bulk ops) call ``cuMemPoolSetAccess``. There + is no in-memory mirror, so the view always reflects the current driver state + and stays consistent across multiple wrappers around the same pool. + + Iteration yields :class:`~cuda.core.Device` objects. ``add``, ``discard``, and + ``__contains__`` accept either a :class:`~cuda.core.Device` or a device-ordinal + ``int``; the owner device is silently ignored when supplied. + + All bulk operations (``update``, ``|=``, ``&=``, ``-=``, ``^=``, ``clear``) + issue exactly one ``cuMemPoolSetAccess`` call. This matters: peer-access + transitions can take seconds per pool because every existing memory mapping + is updated, so coalescing into a single driver call lets the toolkit handle + the mappings in parallel. + """ + + __slots__ = ("_mr",) + + def __init__(self, mr): + self._mr = mr + + @classmethod + def _from_iterable(cls, it): + # Binary set operators (&, |, -, ^) collect their result through + # _from_iterable. Returning a plain set lets the user reason about + # the result independently of any pool's driver state. 
+ return set(it) + + # --- abstract MutableSet methods --- + + def __contains__(self, value) -> bool: + try: + dev_id = _resolve_peer_device_id(value) + except (TypeError, ValueError): + return False + cdef DeviceMemoryResource mr = self._mr + if dev_id == mr._dev_id: + return False + return _peer_access_includes(mr, dev_id) + + def __iter__(self): + from cuda.core._device import Device + + return iter(Device(dev_id) for dev_id in _query_peer_access_ids(self._mr)) + + def __len__(self) -> int: + return len(_query_peer_access_ids(self._mr)) + + def add(self, value) -> None: + """Grant peer access from ``value`` to allocations in this pool.""" + dev_id = _resolve_peer_device_id(value) + cdef DeviceMemoryResource mr = self._mr + if dev_id == mr._dev_id: + return + if _peer_access_includes(mr, dev_id): + return + from cuda.core._device import Device + if not Device(mr._dev_id).can_access_peer(dev_id): + raise ValueError(f"Device {mr._dev_id} cannot access peer: {dev_id}") + _apply_peer_access_diff(mr, (dev_id,), ()) + + def discard(self, value) -> None: + """Revoke peer access from ``value`` to allocations in this pool.""" + try: + dev_id = _resolve_peer_device_id(value) + except (TypeError, ValueError): + return + cdef DeviceMemoryResource mr = self._mr + if dev_id == mr._dev_id: + return + if not _peer_access_includes(mr, dev_id): + return + _apply_peer_access_diff(mr, (), (dev_id,)) + + # --- bulk overrides: one driver call per op --- + + def clear(self) -> None: + """Revoke all peer access in a single driver call.""" + self._apply((), _query_peer_access_ids(self._mr)) + + def update(self, *others) -> None: + """Grant peer access to every device in ``others`` in one driver call.""" + to_add = [] + for other in others: + to_add.extend(other) + if to_add: + self._apply(to_add, ()) + + def difference_update(self, *others) -> None: + """Revoke peer access for every device in ``others`` in one driver call.""" + revoke_ids = set() + for other in others: + for value in 
other: + try: + revoke_ids.add(_resolve_peer_device_id(value)) + except (TypeError, ValueError): + continue + current = set(_query_peer_access_ids(self._mr)) + to_remove = revoke_ids & current + if to_remove: + self._apply((), to_remove) + + def intersection_update(self, *others) -> None: + """Restrict peer access to the intersection in a single driver call.""" + keep_ids = None + for other in others: + ids = set() + for value in other: + try: + ids.add(_resolve_peer_device_id(value)) + except (TypeError, ValueError): + continue + keep_ids = ids if keep_ids is None else keep_ids & ids + if keep_ids is None: + return # ``set.intersection_update()`` with no args is a no-op + current = set(_query_peer_access_ids(self._mr)) + to_remove = current - keep_ids + if to_remove: + self._apply((), to_remove) + + def symmetric_difference_update(self, other) -> None: + """Toggle peer access for every device in ``other`` in one driver call.""" + toggle_ids = set() + for value in other: + try: + toggle_ids.add(_resolve_peer_device_id(value)) + except (TypeError, ValueError): + continue + current = set(_query_peer_access_ids(self._mr)) + to_add = toggle_ids - current + to_remove = toggle_ids & current + if to_add or to_remove: + self._apply(to_add, to_remove) + + def __ior__(self, other): + self.update(other) + return self + + def __iand__(self, other): + self.intersection_update(other) + return self + + def __isub__(self, other): + if other is self: + self.clear() + else: + self.difference_update(other) + return self + + def __ixor__(self, other): + self.symmetric_difference_update(other) + return self + + def __repr__(self) -> str: + return f"PeerAccessibleBySetProxy({set(self)!r})" + + # --- internal: route every write through one batched driver call --- + + def _apply(self, additions, removals) -> None: + """Compute the diff and issue a single ``cuMemPoolSetAccess``. 
+ + ``additions`` and ``removals`` are user-supplied (``Device | int``); + only the owner device is filtered out. Adds are validated through + :meth:`Device.can_access_peer` via :func:`plan_peer_access_update`; + removals bypass that check (revoking is always permitted). + """ + from cuda.core._device import Device + + cdef DeviceMemoryResource mr = self._mr + owner_id = mr._dev_id + owner = Device(owner_id) + current = _query_peer_access_ids(mr) + + # Plan additions through the existing helper (validates can_access_peer). + plan = plan_peer_access_update( + owner_device_id=owner_id, + current_peer_ids=current, + # union of (current set + requested adds) so the planner emits + # exactly the to_add deltas for these additions, no removals. + requested_devices=[*current, *additions], + resolve_device_id=_resolve_peer_device_id, + can_access_peer=owner.can_access_peer, + ) + to_add = plan.to_add + + # Removals: resolve, drop owner and unknowns, intersect with current. + current_set = set(current) + revoke_ids = set() + for value in removals: + try: + dev_id = _resolve_peer_device_id(value) + except (TypeError, ValueError): + continue + if dev_id == owner_id: + continue + if dev_id in current_set: + revoke_ids.add(dev_id) + to_remove = tuple(sorted(revoke_ids)) + + if not to_add and not to_remove: + return + _apply_peer_access_diff(mr, to_add, to_remove) diff --git a/cuda_core/cuda/core/_stream.pxd b/cuda_core/cuda/core/_stream.pxd index c9ffb4c80a7..de16b84bde2 100644 --- a/cuda_core/cuda/core/_stream.pxd +++ b/cuda_core/cuda/core/_stream.pxd @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # SPDX-License-Identifier: Apache-2.0 diff --git a/cuda_core/docs/source/api_private.rst b/cuda_core/docs/source/api_private.rst index 846ed928bf2..4df4f70ce87 100644 --- a/cuda_core/docs/source/api_private.rst +++ b/cuda_core/docs/source/api_private.rst @@ -16,6 +16,7 @@ CUDA runtime .. autosummary:: :toctree: generated/ + _memory._peer_access_utils.PeerAccessibleBySetProxy _module.KernelAttributes _module.KernelOccupancy _module.MaxPotentialBlockSizeOccupancyResult diff --git a/cuda_core/docs/source/release/1.0.0-notes.rst b/cuda_core/docs/source/release/1.0.0-notes.rst index 58553fe59ea..747b033c117 100644 --- a/cuda_core/docs/source/release/1.0.0-notes.rst +++ b/cuda_core/docs/source/release/1.0.0-notes.rst @@ -120,6 +120,11 @@ Breaking changes ``CUgraphConditionalHandle`` value. Previously, ``.handle`` had to be extracted explicitly. +- :attr:`DeviceMemoryResource.peer_accessible_by` now returns a + :class:`collections.abc.MutableSet` of :obj:`~_device.Device` objects instead + of a sorted ``tuple[int, ...]``. The property setter is unchanged. + (`#2018 `__) + - ``stream`` is now a required keyword-only argument on APIs that schedule work on a stream (`#2001 `__). diff --git a/cuda_core/tests/helpers/buffers.py b/cuda_core/tests/helpers/buffers.py index 44f84693089..bbe54c3a000 100644 --- a/cuda_core/tests/helpers/buffers.py +++ b/cuda_core/tests/helpers/buffers.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 import ctypes diff --git a/cuda_core/tests/helpers/collection_interface_testers.py b/cuda_core/tests/helpers/collection_interface_testers.py index 5197e475c18..63fbed8b381 100644 --- a/cuda_core/tests/helpers/collection_interface_testers.py +++ b/cuda_core/tests/helpers/collection_interface_testers.py @@ -1,13 +1,47 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Reusable helpers to verify collections.abc protocol conformance.""" +"""Reusable helpers to verify collections.abc protocol conformance. + +Two helpers are provided for ``MutableSet``-like subjects, picked by the +capacity of the backing store: + +- :func:`assert_mutable_set_interface` is the standard pass; it requires at + least five distinct insertable items so every method (including the + multi-element bulk operators) can be exercised. +- :func:`assert_single_member_mutable_set_interface` is a focused pass for + proxies whose backing store admits at most one insertable element at a time + (for example, a peer-access view on a system with one valid peer device). + It runs every ``MutableSet`` method at least once using a single member and + one non-member sentinel. + +The two helpers are intentionally separate rather than one helper with a +mode flag: a single-member proxy is a substantially different contract +("capacity one, by hardware") and naming it explicitly in the API keeps each +helper's signature small and its assertions linear. 
+""" from collections.abc import MutableSet, Set import pytest +def _assert_empty(subject): + """Assertions that hold for any empty MutableSet-like subject.""" + assert isinstance(subject, Set) + assert isinstance(subject, MutableSet) + assert len(subject) == 0 + assert subject == set() + assert list(subject) == [] + + +def _assert_repr_nonempty(subject): + """``__repr__`` produces a non-empty string.""" + r = repr(subject) + assert isinstance(r, str) + assert len(r) > 0 + + def assert_mutable_set_interface(subject, items): """Exercise every MutableSet method on *subject* against a reference set. @@ -23,15 +57,7 @@ def assert_mutable_set_interface(subject, items): a, b, c, d, e = items[:5] ref = set() - # -- ABC conformance -- - assert isinstance(subject, Set) - assert isinstance(subject, MutableSet) - - # -- empty state -- - assert len(subject) == 0 - assert subject == ref - assert subject == set() - assert list(subject) == [] + _assert_empty(subject) # -- add -- subject.add(a) @@ -136,7 +162,157 @@ def assert_mutable_set_interface(subject, items): # -- __iter__ -- assert set(subject) == ref - # -- __repr__ -- - r = repr(subject) - assert isinstance(r, str) - assert len(r) > 0 + _assert_repr_nonempty(subject) + + +def assert_single_member_mutable_set_interface(subject, member, non_member): + """Exercise every MutableSet method on a subject with capacity one. + + Use this for proxies whose backing store admits at most one insertable + element at a time (typically because the underlying resource is bounded + by hardware, e.g. a peer-access view on a system with a single valid + peer device). The subject only ever holds ``set()`` or ``{member}``; + *non_member* supplies the right-hand side of comparisons, ``isdisjoint``, + subset/superset, and binary/in-place operators so every ``MutableSet`` + method is exercised at least once. + + Parameters + ---------- + subject : MutableSet + An **empty** mutable-set-like object to test. 
+ member : hashable + A distinct, hashable object valid for insertion into *subject*. + non_member : hashable + A distinct, hashable object that compares correctly under set + semantics but is guaranteed never to be inserted into *subject* + (typically because the backing store rejects it). + """ + assert member != non_member, "member and non_member must be distinct" + a = member + x = non_member + ref = set() + + _assert_empty(subject) + + # -- add -- + subject.add(a) + ref.add(a) + assert subject == ref + assert a in subject + assert x not in subject + assert len(subject) == 1 + + # add duplicate is a no-op + subject.add(a) + assert subject == ref + assert len(subject) == 1 + + # -- comparison with plain set -- + assert subject == {a} + assert subject != {a, x} + assert subject != set() + + # -- isdisjoint -- + assert subject.isdisjoint({x}) + assert not subject.isdisjoint({a, x}) + + # -- subset / superset -- + assert subject <= {a} + assert subject <= {a, x} + assert not (subject <= set()) + assert subject < {a, x} + assert not (subject < {a}) + assert {a, x} >= subject + assert {a, x} > subject + + # -- binary operators (results are plain sets, never insert into subject) -- + assert subject & {a, x} == {a} + assert subject & {x} == set() + assert subject | {x} == {a, x} + assert subject - {a} == set() + assert subject - {x} == {a} + assert subject ^ {x} == {a, x} + assert subject ^ {a} == set() + + # -- discard non-member is a no-op -- + subject.discard(x) + assert subject == ref + + # -- discard member -- + subject.discard(a) + ref.discard(a) + assert subject == ref + + # -- remove member -- + subject.add(a) + ref.add(a) + subject.remove(a) + ref.remove(a) + assert subject == ref + + # -- remove non-member raises -- + with pytest.raises(KeyError): + subject.remove(x) + + # -- pop empty raises -- + with pytest.raises(KeyError): + subject.pop() + + # -- pop populated -- + subject.add(a) + ref.add(a) + popped = subject.pop() + ref.discard(popped) + assert popped 
== a + assert popped not in subject + assert subject == ref + + # -- in-place union (|=) covers single insert via bulk path -- + subject |= {a} + ref |= {a} + assert subject == ref + + # -- in-place intersection (&=) keeps the lone member -- + subject &= {a, x} + ref &= {a, x} + assert subject == ref + + # -- in-place intersection (&=) drops the lone member -- + subject &= {x} + ref &= {x} + assert subject == ref + + # -- in-place difference (-=) on non-member is a no-op -- + subject |= {a} + ref |= {a} + subject -= {x} + ref -= {x} + assert subject == ref + + # -- in-place difference (-=) on member empties the subject -- + subject -= {a} + ref -= {a} + assert subject == ref + + # -- in-place symmetric difference (^=): toggle in then out -- + subject ^= {a} + ref ^= {a} + assert subject == ref + subject ^= {a} + ref ^= {a} + assert subject == ref + + # -- clear -- + subject.add(a) + ref.add(a) + subject.clear() + ref.clear() + assert subject == ref + assert len(subject) == 0 + + # -- __iter__ on populated subject -- + subject.add(a) + ref.add(a) + assert set(subject) == ref + + _assert_repr_nonempty(subject) diff --git a/cuda_core/tests/memory_ipc/test_peer_access.py b/cuda_core/tests/memory_ipc/test_peer_access.py index de0baef9b8b..993aa7344ec 100644 --- a/cuda_core/tests/memory_ipc/test_peer_access.py +++ b/cuda_core/tests/memory_ipc/test_peer_access.py @@ -29,7 +29,7 @@ def test_main(self, mempool_device_x2): options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True) mr = DeviceMemoryResource(dev1, options=options) mr.peer_accessible_by = [dev0] - assert mr.peer_accessible_by == (0,) + assert mr.peer_accessible_by == {dev0} # Spawn child process process = mp.Process(target=self.child_main, args=(mr,)) @@ -38,18 +38,18 @@ def test_main(self, mempool_device_x2): assert process.exitcode == 0 # Verify parent's MR still has peer access set (independent state) - assert mr.peer_accessible_by == (0,) + assert mr.peer_accessible_by == {dev0} mr.close() 
def child_main(self, mr): Device(1).set_current() assert mr.is_mapped is True assert mr.device_id == 1 - assert mr.peer_accessible_by == () + assert mr.peer_accessible_by == set() mr.peer_accessible_by = [0] - assert mr.peer_accessible_by == (0,) + assert mr.peer_accessible_by == {Device(0)} mr.peer_accessible_by = [] - assert mr.peer_accessible_by == () + assert mr.peer_accessible_by == set() mr.close() @@ -70,9 +70,9 @@ def test_main(self, mempool_device_x2, grant_access_in_parent): mr = DeviceMemoryResource(dev1, options=options) if grant_access_in_parent: mr.peer_accessible_by = [dev0] - assert mr.peer_accessible_by == (0,) + assert mr.peer_accessible_by == {dev0} else: - assert mr.peer_accessible_by == () + assert mr.peer_accessible_by == set() buffer = mr.allocate(NBYTES, stream=dev1.default_stream) pgen = PatternGen(dev1, NBYTES) pgen.fill_buffer(buffer, seed=False) @@ -108,14 +108,14 @@ def child_main(self, mr, buffer): # Test 3: Set peer access and verify buffer becomes accessible dev1.set_current() mr.peer_accessible_by = [0] - assert mr.peer_accessible_by == (0,) + assert mr.peer_accessible_by == {dev0} dev0.set_current() PatternGen(dev0, NBYTES).verify_buffer(buffer, seed=False) # Test 4: Revoke peer access and verify buffer becomes inaccessible dev1.set_current() mr.peer_accessible_by = [] - assert mr.peer_accessible_by == () + assert mr.peer_accessible_by == set() dev0.set_current() with pytest.raises(CUDAError, match="CUDA_ERROR_INVALID_VALUE"): PatternGen(dev0, NBYTES).verify_buffer(buffer, seed=False) diff --git a/cuda_core/tests/test_memory_peer_access.py b/cuda_core/tests/test_memory_peer_access.py index 04324ceec81..71beb459143 100644 --- a/cuda_core/tests/test_memory_peer_access.py +++ b/cuda_core/tests/test_memory_peer_access.py @@ -1,10 +1,13 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # SPDX-License-Identifier: Apache-2.0 import pytest from helpers.buffers import PatternGen, compare_buffer_to_constant, make_scratch_buffer +from helpers.collection_interface_testers import assert_single_member_mutable_set_interface -from cuda.core import DeviceMemoryResource, DeviceMemoryResourceOptions +from cuda.core import Device, DeviceMemoryResource, DeviceMemoryResourceOptions, system +from cuda.core._memory import _peer_access_utils +from cuda.core._memory._peer_access_utils import PeerAccessibleBySetProxy from cuda.core._utils.cuda_utils import CUDAError NBYTES = 1024 @@ -100,18 +103,18 @@ def verify_state(state, pattern_seed): transitions = [(s0, s1) for s0 in states for s1 in states if s0 != s1] for init_state, final_state in transitions: dmrs[0].peer_accessible_by = init_state - assert dmrs[0].peer_accessible_by == init_state + assert dmrs[0].peer_accessible_by == {Device(i) for i in init_state} verify_state(init_state, pattern_seed) pattern_seed += 1 dmrs[0].peer_accessible_by = final_state - assert dmrs[0].peer_accessible_by == final_state + assert dmrs[0].peer_accessible_by == {Device(i) for i in final_state} verify_state(final_state, pattern_seed) pattern_seed += 1 def test_peer_access_shared_pool_queries_driver(mempool_device_x2): - """Non-owned pools always query the driver for peer access state.""" + """All pools always query the driver, so wrappers see consistent state.""" dev0, dev1 = mempool_device_x2 # Grant peer access via one wrapper; a second wrapper must see it. @@ -122,18 +125,289 @@ def test_peer_access_shared_pool_queries_driver(mempool_device_x2): # Revoke via dmr2; dmr1 must reflect the change immediately. dmr2.peer_accessible_by = [] - assert dmr1.peer_accessible_by == () + assert dmr1.peer_accessible_by == set() # Re-grant via dmr1. A fresh wrapper that has never read the # property must still query the driver before computing diffs # in the setter, so setting [] must discover and revoke the access. 
dmr1.peer_accessible_by = [dev1] dmr3 = DeviceMemoryResource(dev0) - assert dmr1.peer_accessible_by == (dev1.device_id,) - assert dmr2.peer_accessible_by == (dev1.device_id,) - assert dmr3.peer_accessible_by == (dev1.device_id,) + assert dmr1.peer_accessible_by == {dev1} + assert dmr2.peer_accessible_by == {dev1} + assert dmr3.peer_accessible_by == {dev1} dmr3.peer_accessible_by = [] - assert DeviceMemoryResource(dev0).peer_accessible_by == () - assert dmr1.peer_accessible_by == () - assert dmr2.peer_accessible_by == () - assert dmr3.peer_accessible_by == () + assert DeviceMemoryResource(dev0).peer_accessible_by == set() + assert dmr1.peer_accessible_by == set() + assert dmr2.peer_accessible_by == set() + assert dmr3.peer_accessible_by == set() + + +# --------------------------------------------------------------------------- +# Set-proxy interface coverage +# +# These tests exercise the ``PeerAccessibleBySetProxy`` surface added in +# v1.0.0. They run against ``mempool_device_x2`` because every CI machine has +# at most 2 GPUs, which means at most one valid peer device. The +# ``assert_single_member_mutable_set_interface`` helper threads that single +# insertable element through the full ``MutableSet`` protocol. +# --------------------------------------------------------------------------- + + +@pytest.fixture +def isolated_dmr_x2(mempool_device_x2): + """Owned-pool DMR on dev0 + the lone peer device (dev1). + + The owned pool guarantees a clean, empty initial peer-access state so the + proxy tests are not polluted by other tests sharing a default pool. + """ + dev0, dev1 = mempool_device_x2 + dmr = DeviceMemoryResource(dev0, DeviceMemoryResourceOptions()) + dmr.peer_accessible_by = [] + return dmr, dev0, dev1 + + +def test_peer_accessible_by_mutable_set_interface(isolated_dmr_x2): + """Run the MutableSet protocol against a single-peer driver-backed view. + + On a 2-GPU box the proxy can only ever hold ``{dev1}`` because there is a + single valid peer. 
The capacity-one helper exercises every ``MutableSet`` + method using ``dev1`` as the lone insertable element and ``dev0`` (the + owner, which the proxy refuses to insert) as the non-member sentinel. + """ + dmr, dev0, dev1 = isolated_dmr_x2 + assert_single_member_mutable_set_interface( + dmr.peer_accessible_by, + member=dev1, + non_member=dev0, + ) + + +def test_peer_accessible_by_accepts_int_and_device(isolated_dmr_x2): + """``add``/``discard``/``__contains__`` accept ``Device`` and ``int`` interchangeably.""" + dmr, dev0, dev1 = isolated_dmr_x2 + + dmr.peer_accessible_by.add(dev1.device_id) + assert dmr.peer_accessible_by == {dev1} + assert dev1 in dmr.peer_accessible_by + assert dev1.device_id in dmr.peer_accessible_by + + dmr.peer_accessible_by.discard(dev1) + assert dmr.peer_accessible_by == set() + + dmr.peer_accessible_by.add(dev1) + assert dmr.peer_accessible_by == {dev1} + dmr.peer_accessible_by.discard(dev1.device_id) + assert dmr.peer_accessible_by == set() + + +def test_peer_accessible_by_silently_ignores_owner(isolated_dmr_x2): + """The owner device is silently filtered on every write; ``__contains__`` returns False.""" + dmr, dev0, dev1 = isolated_dmr_x2 + + # add/discard on owner is a no-op (no error, no state change) + dmr.peer_accessible_by.add(dev0) + dmr.peer_accessible_by.add(dev0.device_id) + assert dmr.peer_accessible_by == set() + dmr.peer_accessible_by.discard(dev0) + dmr.peer_accessible_by.discard(dev0.device_id) + assert dmr.peer_accessible_by == set() + + # __contains__ on owner is False (matches set semantics, never raises) + assert dev0 not in dmr.peer_accessible_by + assert dev0.device_id not in dmr.peer_accessible_by + + # Owner mixed into bulk ops is filtered, the peer is still added/removed + dmr.peer_accessible_by |= {dev0, dev1} + assert dmr.peer_accessible_by == {dev1} + dmr.peer_accessible_by -= {dev0, dev1} + assert dmr.peer_accessible_by == set() + + +def test_peer_accessible_by_rejects_invalid_inputs(isolated_dmr_x2): + 
"""``add`` raises on out-of-range/unsupported inputs; lenient methods do not.""" + dmr, dev0, dev1 = isolated_dmr_x2 + bad_id = system.get_num_devices() # one past the last valid device ordinal + + # add: validates strictly, propagates errors from Device(bad_id) + with pytest.raises((ValueError, CUDAError)): + dmr.peer_accessible_by.add(bad_id) + # Non-coercible inputs surface whatever Device(value) raises (TypeError or + # ValueError depending on Cython's int coercion path). + with pytest.raises((TypeError, ValueError)): + dmr.peer_accessible_by.add("not-a-device") + + # discard: silently ignores non-coercible values (matches set.discard) + dmr.peer_accessible_by.discard("not-a-device") + assert dmr.peer_accessible_by == set() + + # __contains__: returns False on non-coercible values, never raises + assert "not-a-device" not in dmr.peer_accessible_by + + # __contains__: out-of-range int returns False, never raises + assert bad_id not in dmr.peer_accessible_by + + # remove on a non-member raises KeyError (inherited from MutableSet) + with pytest.raises(KeyError): + dmr.peer_accessible_by.remove(dev1) + + +def test_peer_accessible_by_no_cache_across_proxies(mempool_device_x2): + """Updates via one wrapper are immediately visible through any other proxy.""" + dev0, dev1 = mempool_device_x2 + dmr_a = DeviceMemoryResource(dev0) + dmr_b = DeviceMemoryResource(dev0) + dmr_a.peer_accessible_by = [] + + proxy = dmr_a.peer_accessible_by # acquired before the change below + dmr_b.peer_accessible_by.add(dev1) + # The proxy must reflect the new driver state, not a snapshot. 
+    assert dev1 in proxy
+    assert proxy == {dev1}
+
+    dmr_b.peer_accessible_by.clear()
+    assert proxy == set()
+
+
+def test_peer_accessible_by_iteration_order_is_sorted(mempool_device_x2):
+    """``__iter__`` yields peers in ascending device-ordinal order."""
+    dev0, dev1 = mempool_device_x2
+    dmr = DeviceMemoryResource(dev0, DeviceMemoryResourceOptions())
+    dmr.peer_accessible_by = [dev1]
+    devices = list(dmr.peer_accessible_by)
+    ids = [d.device_id for d in devices]
+    assert ids == sorted(ids)
+    assert all(isinstance(d, Device) for d in devices)
+
+
+def test_peer_accessible_by_repr(isolated_dmr_x2):
+    """``repr`` includes the class name and reflects the live contents."""
+    dmr, dev0, dev1 = isolated_dmr_x2
+    empty_repr = repr(dmr.peer_accessible_by)
+    assert "PeerAccessibleBySetProxy" in empty_repr
+    assert "set()" in empty_repr
+
+    dmr.peer_accessible_by.add(dev1)
+    populated_repr = repr(dmr.peer_accessible_by)
+    assert "PeerAccessibleBySetProxy" in populated_repr
+    # Don't pin the exact device repr; just confirm content changed.
+    assert populated_repr != empty_repr
+
+
+def test_peer_accessible_by_returns_proxy_type(isolated_dmr_x2):
+    """The getter returns the documented proxy type (anchors the public contract)."""
+    dmr, dev0, dev1 = isolated_dmr_x2
+    assert isinstance(dmr.peer_accessible_by, PeerAccessibleBySetProxy)
+
+
+# ---------------------------------------------------------------------------
+# Batching contract: every bulk op must issue at most one cuMemPoolSetAccess
+#
+# Spying via ``monkeypatch.setattr`` on the module-level
+# ``_set_pool_access`` works because its callers resolve it by bare name,
+# which Cython looks up through the module's globals at
+# runtime (the wrapper is a plain ``def``, not a ``cdef inline``).
+# ---------------------------------------------------------------------------
+
+
+class _DriverCallSpy:
+    """Records every actual ``cuMemPoolSetAccess`` invocation. 
+ + Spies on :func:`cuda.core._memory._peer_access_utils._set_pool_access` — + the thin Python-visible wrapper that builds the descriptor array and + issues the single driver call. Earlier no-op layers (e.g. the + augmented-assignment-on-property quirk that reassigns an already-mutated + proxy back through the setter) short-circuit before reaching here, so the + recorded count is exactly the number of real ``cuMemPoolSetAccess`` calls. + """ + + def __init__(self, real): + self._real = real + self.calls = [] + + def __call__(self, mr, to_add, to_remove): + self.calls.append((tuple(to_add), tuple(to_remove))) + self._real(mr, to_add, to_remove) + + +@pytest.fixture +def driver_spy(monkeypatch): + spy = _DriverCallSpy(_peer_access_utils._set_pool_access) + monkeypatch.setattr(_peer_access_utils, "_set_pool_access", spy) + return spy + + +def test_peer_accessible_by_setter_batches_one_call(driver_spy, isolated_dmr_x2): + """``mr.peer_accessible_by = [...]`` issues exactly one driver call (or zero on no-op).""" + dmr, dev0, dev1 = isolated_dmr_x2 + dmr.peer_accessible_by = [dev1] + assert len(driver_spy.calls) == 1 + assert dev1.device_id in driver_spy.calls[-1][0] + + # Reassigning the same set is a no-op (zero driver calls). + driver_spy.calls.clear() + dmr.peer_accessible_by = [dev1] + assert driver_spy.calls == [] + + # Revoking everything is a single call (one removal). + dmr.peer_accessible_by = [] + assert len(driver_spy.calls) == 1 + assert dev1.device_id in driver_spy.calls[-1][1] + + +def test_peer_accessible_by_bulk_ops_batch_one_call(driver_spy, isolated_dmr_x2): + """Every bulk op issues exactly one ``cuMemPoolSetAccess`` (or zero on no-op). + + Covers ``|=``, ``&=``, ``^=``, ``update``, ``difference_update``, and + ``clear``. Both operand styles are exercised: a locally bound proxy + (``proxy |= {...}``) and augmented assignment directly on the property + (``dmr.peer_accessible_by |= {...}``). 
The latter trips Python's + augmented-assignment-on-property pattern (fetch proxy, mutate, write + back through setter) but the trailing setter call discovers an empty + diff and short-circuits before reaching the driver, so the count is + still one. + """ + dmr, dev0, dev1 = isolated_dmr_x2 + proxy = dmr.peer_accessible_by + + proxy |= {dev1} + assert len(driver_spy.calls) == 1 + driver_spy.calls.clear() + + # &= keeping the lone member: no driver call (no diff). + proxy &= {dev1} + assert driver_spy.calls == [] + + # &= dropping the lone member: one removal. + proxy &= {dev0} + assert len(driver_spy.calls) == 1 + driver_spy.calls.clear() + + # ^= toggling the lone peer in then out: two ops, one call each. + proxy ^= {dev1} + assert len(driver_spy.calls) == 1 + proxy ^= {dev1} + assert len(driver_spy.calls) == 2 + driver_spy.calls.clear() + + # update() with the peer already absent: one add. + proxy.update([dev1]) + assert len(driver_spy.calls) == 1 + driver_spy.calls.clear() + + # clear() with one member: one removal. + proxy.clear() + assert len(driver_spy.calls) == 1 + driver_spy.calls.clear() + + # Already-empty bulk ops are no-ops (nothing to add or remove). + proxy.clear() + proxy.difference_update([dev1]) + proxy -= {dev1} + assert driver_spy.calls == [] + + # Augmented assignment directly on the property is also one driver call: + # the proxy mutates the pool via __ior__, the setter writes back an + # already-up-to-date proxy, and the empty-diff short-circuit prevents a + # second driver call. + dmr.peer_accessible_by |= {dev1} + assert len(driver_spy.calls) == 1 diff --git a/cuda_core/tests/test_stream.py b/cuda_core/tests/test_stream.py index 4e5813ee226..49e372c9d53 100644 --- a/cuda_core/tests/test_stream.py +++ b/cuda_core/tests/test_stream.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # SPDX-License-Identifier: Apache-2.0 import pytest