diff --git a/cuda_core/cuda/core/__init__.py b/cuda_core/cuda/core/__init__.py index dfba887144..c32ff79ecc 100644 --- a/cuda_core/cuda/core/__init__.py +++ b/cuda_core/cuda/core/__init__.py @@ -68,3 +68,4 @@ Stream, StreamOptions, ) +from cuda.core._tensor_map import TensorMapDescriptor diff --git a/cuda_core/cuda/core/_cpp/tensor_map.cpp b/cuda_core/cuda/core/_cpp/tensor_map.cpp new file mode 100644 index 0000000000..af09aa1b2f --- /dev/null +++ b/cuda_core/cuda/core/_cpp/tensor_map.cpp @@ -0,0 +1,149 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tensor_map_cccl.h" + +#include + +#include +#include + +#if defined(__has_include) +# if __has_include() +# include +# define CUDA_CORE_HAS_CUDA_TMA 1 +# else +# define CUDA_CORE_HAS_CUDA_TMA 0 +# endif +# if __has_include() +# include +# define CUDA_CORE_HAS_DLPACK_H 1 +# else +# define CUDA_CORE_HAS_DLPACK_H 0 +# endif +#else +# define CUDA_CORE_HAS_CUDA_TMA 0 +# define CUDA_CORE_HAS_DLPACK_H 0 +#endif + +static inline void cuda_core_write_err(char* err, size_t cap, const char* msg) noexcept +{ + if (!err || cap == 0) + return; + if (!msg) + { + err[0] = '\0'; + return; + } + size_t n = ::strlen(msg); + if (n >= cap) + n = cap - 1; + ::memcpy(err, msg, n); + err[n] = '\0'; +} + +int cuda_core_cccl_make_tma_descriptor_tiled( + void* out_tensor_map, + void* data, + int device_type, + int device_id, + int ndim, + const int64_t* shape, + const int64_t* strides, + uint8_t dtype_code, + uint8_t dtype_bits, + uint16_t dtype_lanes, + const int* box_sizes, + const int* elem_strides, + int interleave_layout, + int swizzle, + int l2_fetch_size, + int oob_fill, + char* err, + size_t err_cap) noexcept +{ +#if !(CUDA_CORE_HAS_CUDA_TMA && CUDA_CORE_HAS_DLPACK_H) + (void)out_tensor_map; + (void)data; + (void)device_type; + (void)device_id; + (void)ndim; + (void)shape; + (void)strides; + (void)dtype_code; + (void)dtype_bits; + (void)dtype_lanes; + (void)box_sizes; + (void)elem_strides; + (void)interleave_layout; + (void)swizzle; + (void)l2_fetch_size; + (void)oob_fill; + cuda_core_write_err(err, err_cap, "CCCL and/or not available at build time"); + return 1; +#else + try + { + if (!out_tensor_map) + { + cuda_core_write_err(err, err_cap, "out_tensor_map is NULL"); + return 1; + } + if (!data) + { + cuda_core_write_err(err, err_cap, "tensor data pointer is NULL"); + return 1; + } + if (!shape || !box_sizes || ndim <= 0) + { + cuda_core_write_err(err, err_cap, "invalid rank/shape/box_sizes"); + return 1; + } + + DLTensor t{}; + t.data = data; + t.device = {static_cast(device_type), device_id}; + t.ndim = ndim; + t.dtype.code = dtype_code; + t.dtype.bits = dtype_bits; + t.dtype.lanes = dtype_lanes; + // CCCL promises not to mutate the arrays, but DLPack uses non-const pointers. + t.shape = const_cast(shape); + t.strides = const_cast(strides); + t.byte_offset = 0; + + const auto layout = static_cast(interleave_layout); + const auto swz = static_cast(swizzle); + const auto l2 = static_cast(l2_fetch_size); + const auto oob = static_cast(oob_fill); + + auto box = cuda::std::span(box_sizes, static_cast(ndim)); + + CUtensorMap desc{}; + if (elem_strides) + { + auto es = cuda::std::span(elem_strides, static_cast(ndim)); + desc = cuda::make_tma_descriptor(t, box, es, layout, swz, l2, oob); + } + else + { + desc = cuda::make_tma_descriptor(t, box, layout, swz, l2, oob); + } + + ::memcpy(out_tensor_map, &desc, sizeof(CUtensorMap)); + cuda_core_write_err(err, err_cap, nullptr); + return 0; + } + catch (const std::exception& e) + { + cuda_core_write_err(err, err_cap, e.what()); + return 1; + } + catch (...) + { + cuda_core_write_err(err, err_cap, "unknown error while building TMA descriptor"); + return 1; + } +#endif +} diff --git a/cuda_core/cuda/core/_cpp/tensor_map_cccl.h b/cuda_core/cuda/core/_cpp/tensor_map_cccl.h new file mode 100644 index 0000000000..71be425182 --- /dev/null +++ b/cuda_core/cuda/core/_cpp/tensor_map_cccl.h @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef CUDA_CORE_TENSOR_MAP_CCCL_H_ +#define CUDA_CORE_TENSOR_MAP_CCCL_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// Build a tiled CUtensorMap using CCCL's cuda::make_tma_descriptor (from ). +// +// Returns 0 on success; on failure returns non-zero and writes a best-effort +// human-readable message into (err, err_cap) if provided. +int cuda_core_cccl_make_tma_descriptor_tiled( + void* out_tensor_map, + void* data, + int device_type, + int device_id, + int ndim, + const int64_t* shape, // length ndim + const int64_t* strides, // length ndim, or NULL for contiguous + uint8_t dtype_code, + uint8_t dtype_bits, + uint16_t dtype_lanes, + const int* box_sizes, // length ndim + const int* elem_strides, // length ndim, or NULL for all-ones overload + int interleave_layout, + int swizzle, + int l2_fetch_size, + int oob_fill, + char* err, + size_t err_cap) noexcept; + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // CUDA_CORE_TENSOR_MAP_CCCL_H_ diff --git a/cuda_core/cuda/core/_kernel_arg_handler.pyx b/cuda_core/cuda/core/_kernel_arg_handler.pyx index 882ca5eaab..28a981fed2 100644 --- a/cuda_core/cuda/core/_kernel_arg_handler.pyx +++ b/cuda_core/cuda/core/_kernel_arg_handler.pyx @@ -6,6 +6,7 @@ from cpython.mem cimport PyMem_Malloc, PyMem_Free from libc.stdint cimport (intptr_t, int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t,) +from libc.string cimport memcpy from libcpp cimport bool as cpp_bool from libcpp.complex cimport complex as cpp_complex from libcpp cimport nullptr @@ -16,6 +17,8 @@ import ctypes import numpy from cuda.core._memory import Buffer +from cuda.core._tensor_map import TensorMapDescriptor as _TensorMapDescriptor_py +from cuda.core._tensor_map cimport TensorMapDescriptor from cuda.core._utils.cuda_utils import driver from cuda.bindings cimport cydriver @@ -97,6 +100,9 @@ cdef object numpy_complex64 = numpy.complex64 cdef object numpy_complex128 = numpy.complex128 +cdef object tensor_map_descriptor_type = _TensorMapDescriptor_py + + # limitation due to cython/cython#534 ctypedef void* voidptr @@ -124,6 +130,26 @@ cdef inline int prepare_arg( return 0 +cdef inline int prepare_tensor_map_arg( + vector.vector[void*]& data, + vector.vector[void*]& data_addresses, + TensorMapDescriptor arg, + const size_t idx) except -1: + arg._check_context_compat() + # Allocate a temporary buffer for the 128-byte CUtensorMap struct. + # We copy rather than pointing directly at arg._tensor_map for lifetime + # safety: ParamHolder owns and frees its argument buffers independently. + cdef void* ptr = PyMem_Malloc(sizeof(cydriver.CUtensorMap)) + if ptr is NULL: + raise MemoryError("Failed to allocate memory for CUtensorMap") + memcpy(ptr, arg._get_data_ptr(), sizeof(cydriver.CUtensorMap)) + # data[idx] is tracked so the allocation is freed in ParamHolder.__dealloc__, + # data_addresses[idx] is the pointer passed to cuLaunchKernel. + data_addresses[idx] = ptr + data[idx] = ptr + return 0 + + cdef inline int prepare_ctypes_arg( vector.vector[void*]& data, vector.vector[void*]& data_addresses, @@ -273,6 +299,9 @@ cdef class ParamHolder: # it's a CUdeviceptr: self.data_addresses[i] = (arg.handle.getPtr()) continue + elif arg_type is tensor_map_descriptor_type: + prepare_tensor_map_arg(self.data, self.data_addresses, arg, i) + continue elif arg_type is bool: prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i) continue diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx index 0e1df726c0..78002b3ff1 100644 --- a/cuda_core/cuda/core/_memoryview.pyx +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -316,6 +316,39 @@ cdef class StridedMemoryView: view_buffer_strided(view, self.get_buffer(), layout, dtype, self.readonly) return view + def as_tensor_map( + self, + box_dim, + *, + element_strides=None, + data_type=None, + interleave=None, + swizzle=None, + l2_promotion=None, + oob_fill=None, + ): + """Create a tiled :obj:`TensorMapDescriptor` from this view. + + This is a convenience wrapper around + :meth:`cuda.core._tensor_map.TensorMapDescriptor.from_tiled`. + """ + from cuda.core._tensor_map import TensorMapDescriptor + + kwargs = {} + if element_strides is not None: + kwargs["element_strides"] = element_strides + if data_type is not None: + kwargs["data_type"] = data_type + if interleave is not None: + kwargs["interleave"] = interleave + if swizzle is not None: + kwargs["swizzle"] = swizzle + if l2_promotion is not None: + kwargs["l2_promotion"] = l2_promotion + if oob_fill is not None: + kwargs["oob_fill"] = oob_fill + return TensorMapDescriptor.from_tiled(self, box_dim, **kwargs) + def copy_from( self, other : StridedMemoryView, stream : Stream, allocator = None, diff --git a/cuda_core/cuda/core/_tensor_map.pxd b/cuda_core/cuda/core/_tensor_map.pxd new file mode 100644 index 0000000000..4c60b7fc70 --- /dev/null +++ b/cuda_core/cuda/core/_tensor_map.pxd @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from cuda.bindings cimport cydriver +from libc.stdint cimport intptr_t + + +cdef class TensorMapDescriptor: + cdef cydriver.CUtensorMap _tensor_map + cdef int _device_id + cdef intptr_t _context + cdef object _source_ref + cdef object _view_ref + cdef object _repr_info + + cdef int _check_context_compat(self) except -1 + cdef void* _get_data_ptr(self) diff --git a/cuda_core/cuda/core/_tensor_map.pyx b/cuda_core/cuda/core/_tensor_map.pyx new file mode 100644 index 0000000000..9d83a709cc --- /dev/null +++ b/cuda_core/cuda/core/_tensor_map.pyx @@ -0,0 +1,919 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from libc.stdint cimport intptr_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t +from libc.stddef cimport size_t +from cuda.bindings cimport cydriver +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN +from cuda.core._dlpack cimport kDLInt, kDLUInt, kDLFloat, kDLBfloat, _kDLCUDA + +import enum + +import numpy + +from cuda.core._memoryview import StridedMemoryView + +cdef extern from "_cpp/tensor_map_cccl.h": + int cuda_core_cccl_make_tma_descriptor_tiled( + void* out_tensor_map, + void* data, + int device_type, + int device_id, + int ndim, + const int64_t* shape, + const int64_t* strides, + uint8_t dtype_code, + uint8_t dtype_bits, + uint16_t dtype_lanes, + const int* box_sizes, + const int* elem_strides, + int interleave_layout, + int swizzle, + int l2_fetch_size, + int oob_fill, + char* err, + size_t err_cap) nogil + + +try: + from ml_dtypes import bfloat16 as ml_bfloat16 +except ImportError: + ml_bfloat16 = None + + +class TensorMapDataType(enum.IntEnum): + """Data types for tensor map descriptors. + + These correspond to the ``CUtensorMapDataType`` driver enum values. + """ + UINT8 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT8 + UINT16 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT16 + UINT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT32 + INT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_INT32 + UINT64 = cydriver.CU_TENSOR_MAP_DATA_TYPE_UINT64 + INT64 = cydriver.CU_TENSOR_MAP_DATA_TYPE_INT64 + FLOAT16 = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT16 + FLOAT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32 + FLOAT64 = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT64 + BFLOAT16 = cydriver.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16 + FLOAT32_FTZ = cydriver.CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ + TFLOAT32 = cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32 + TFLOAT32_FTZ = cydriver.CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ + + +class TensorMapInterleave(enum.IntEnum): + """Interleave layout for tensor map descriptors. + + These correspond to the ``CUtensorMapInterleave`` driver enum values. + """ + NONE = cydriver.CU_TENSOR_MAP_INTERLEAVE_NONE + INTERLEAVE_16B = cydriver.CU_TENSOR_MAP_INTERLEAVE_16B + INTERLEAVE_32B = cydriver.CU_TENSOR_MAP_INTERLEAVE_32B + + +class TensorMapSwizzle(enum.IntEnum): + """Swizzle mode for tensor map descriptors. + + These correspond to the ``CUtensorMapSwizzle`` driver enum values. + """ + NONE = cydriver.CU_TENSOR_MAP_SWIZZLE_NONE + SWIZZLE_32B = cydriver.CU_TENSOR_MAP_SWIZZLE_32B + SWIZZLE_64B = cydriver.CU_TENSOR_MAP_SWIZZLE_64B + SWIZZLE_128B = cydriver.CU_TENSOR_MAP_SWIZZLE_128B + + +class TensorMapL2Promotion(enum.IntEnum): + """L2 promotion mode for tensor map descriptors. + + These correspond to the ``CUtensorMapL2promotion`` driver enum values. + """ + NONE = cydriver.CU_TENSOR_MAP_L2_PROMOTION_NONE + L2_64B = cydriver.CU_TENSOR_MAP_L2_PROMOTION_L2_64B + L2_128B = cydriver.CU_TENSOR_MAP_L2_PROMOTION_L2_128B + L2_256B = cydriver.CU_TENSOR_MAP_L2_PROMOTION_L2_256B + + +class TensorMapOOBFill(enum.IntEnum): + """Out-of-bounds fill mode for tensor map descriptors. + + These correspond to the ``CUtensorMapFloatOOBfill`` driver enum values. + """ + NONE = cydriver.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + NAN_REQUEST_ZERO_FMA = cydriver.CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA + + +IF CUDA_CORE_BUILD_MAJOR >= 13: + class TensorMapIm2ColWideMode(enum.IntEnum): + """Im2col wide mode for tensor map descriptors. + + These correspond to the ``CUtensorMapIm2ColWideMode`` driver enum values. + Supported on compute capability 10.0+. + """ + W = cydriver.CU_TENSOR_MAP_IM2COL_WIDE_MODE_W + W128 = cydriver.CU_TENSOR_MAP_IM2COL_WIDE_MODE_W128 +ELSE: + class TensorMapIm2ColWideMode(enum.IntEnum): + """Im2col wide mode for tensor map descriptors. + + This enum is always defined for API stability, but the + :meth:`TensorMapDescriptor.from_im2col_wide` factory requires a CUDA 13+ + build and will raise otherwise. + """ + W = 0 + W128 = 1 + + +# Mapping from numpy dtype to TMA data type +_NUMPY_DTYPE_TO_TMA = { + numpy.dtype(numpy.uint8): TensorMapDataType.UINT8, + numpy.dtype(numpy.uint16): TensorMapDataType.UINT16, + numpy.dtype(numpy.uint32): TensorMapDataType.UINT32, + numpy.dtype(numpy.int32): TensorMapDataType.INT32, + numpy.dtype(numpy.uint64): TensorMapDataType.UINT64, + numpy.dtype(numpy.int64): TensorMapDataType.INT64, + numpy.dtype(numpy.float16): TensorMapDataType.FLOAT16, + numpy.dtype(numpy.float32): TensorMapDataType.FLOAT32, + numpy.dtype(numpy.float64): TensorMapDataType.FLOAT64, +} + +if ml_bfloat16 is not None: + _NUMPY_DTYPE_TO_TMA[numpy.dtype(ml_bfloat16)] = TensorMapDataType.BFLOAT16 + + +# Mapping from TMA data type to element size in bytes +_TMA_DATA_TYPE_SIZE = { + TensorMapDataType.UINT8: 1, + TensorMapDataType.UINT16: 2, + TensorMapDataType.UINT32: 4, + TensorMapDataType.INT32: 4, + TensorMapDataType.UINT64: 8, + TensorMapDataType.INT64: 8, + TensorMapDataType.FLOAT16: 2, + TensorMapDataType.FLOAT32: 4, + TensorMapDataType.FLOAT64: 8, + TensorMapDataType.BFLOAT16: 2, + TensorMapDataType.FLOAT32_FTZ: 4, + TensorMapDataType.TFLOAT32: 4, + TensorMapDataType.TFLOAT32_FTZ: 4, +} + + +def _resolve_data_type(view, data_type): + """Resolve the TMA data type from an explicit value or the view's dtype.""" + + if data_type is not None: + if not isinstance(data_type, TensorMapDataType): + raise TypeError( + f"data_type must be a TensorMapDataType, got {type(data_type)}") + return data_type + + dt = view.dtype + if dt is None: + raise ValueError( + "Cannot infer TMA data type from the tensor; " + "please specify data_type explicitly") + + tma_dt = _NUMPY_DTYPE_TO_TMA.get(dt) + if tma_dt is None: + raise ValueError( + f"Unsupported dtype {dt} for TMA; " + f"supported dtypes: {list(_NUMPY_DTYPE_TO_TMA.keys())}. " + "You may also specify data_type explicitly.") + + return tma_dt + + +cdef inline bint _tma_dtype_to_dlpack( + object tma_dt, + uint8_t* out_code, + uint8_t* out_bits, + uint16_t* out_lanes, +) noexcept: + if tma_dt == TensorMapDataType.UINT8: + out_code[0] = kDLUInt + out_bits[0] = 8 + out_lanes[0] = 1 + return True + if tma_dt == TensorMapDataType.UINT16: + out_code[0] = kDLUInt + out_bits[0] = 16 + out_lanes[0] = 1 + return True + if tma_dt == TensorMapDataType.UINT32: + out_code[0] = kDLUInt + out_bits[0] = 32 + out_lanes[0] = 1 + return True + if tma_dt == TensorMapDataType.UINT64: + out_code[0] = kDLUInt + out_bits[0] = 64 + out_lanes[0] = 1 + return True + if tma_dt == TensorMapDataType.INT32: + out_code[0] = kDLInt + out_bits[0] = 32 + out_lanes[0] = 1 + return True + if tma_dt == TensorMapDataType.INT64: + out_code[0] = kDLInt + out_bits[0] = 64 + out_lanes[0] = 1 + return True + if tma_dt == TensorMapDataType.FLOAT16: + out_code[0] = kDLFloat + out_bits[0] = 16 + out_lanes[0] = 1 + return True + if tma_dt == TensorMapDataType.FLOAT32: + out_code[0] = kDLFloat + out_bits[0] = 32 + out_lanes[0] = 1 + return True + if tma_dt == TensorMapDataType.FLOAT64: + out_code[0] = kDLFloat + out_bits[0] = 64 + out_lanes[0] = 1 + return True + if tma_dt == TensorMapDataType.BFLOAT16: + out_code[0] = kDLBfloat + out_bits[0] = 16 + out_lanes[0] = 1 + return True + return False + + +def _get_validated_view(tensor): + """Obtain a device-accessible StridedMemoryView with a 16-byte-aligned pointer.""" + if isinstance(tensor, StridedMemoryView): + view = tensor + else: + # stream_ptr=-1: no stream synchronization needed because descriptor + # creation only reads tensor metadata, it does not move data. + view = StridedMemoryView.from_any_interface(tensor, stream_ptr=-1) + + if not view.is_device_accessible: + raise ValueError("The tensor must be device-accessible") + + if view.ptr % 16 != 0: + raise ValueError( + f"Global memory address must be 16-byte aligned, " + f"got address 0x{view.ptr:x}") + + return view + + +cdef inline intptr_t _get_current_context_ptr() except? 0: + cdef cydriver.CUcontext ctx + with nogil: + HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx)) + if ctx == NULL: + raise RuntimeError("TensorMapDescriptor requires an active CUDA context") + return ctx + + +cdef inline int _get_current_device_id() except -1: + cdef cydriver.CUdevice dev + with nogil: + HANDLE_RETURN(cydriver.cuCtxGetDevice(&dev)) + return dev + + +def _compute_byte_strides(shape, strides, elem_size): + """Compute byte strides from element strides or C-contiguous fallback. + + Returns a tuple of byte strides in row-major order. + """ + if strides is not None: + return tuple(s * elem_size for s in strides) + + # C-contiguous: compute byte strides from shape, innermost first + rank = len(shape) + byte_strides = [] + stride = elem_size + for i in range(rank - 1, -1, -1): + byte_strides.append(stride) + stride *= shape[i] + byte_strides.reverse() + return tuple(byte_strides) + + +def _validate_element_strides(element_strides, rank): + """Validate or default element_strides to all-ones.""" + if element_strides is not None: + if len(element_strides) != rank: + raise ValueError( + f"element_strides must have {rank} elements, got {len(element_strides)}") + return element_strides + return (1,) * rank + + +cdef class TensorMapDescriptor: + """Describes a TMA (Tensor Memory Accelerator) tensor map for Hopper+ GPUs. + + A ``TensorMapDescriptor`` wraps the opaque 128-byte ``CUtensorMap`` struct + used by the hardware TMA unit for efficient bulk data movement between + global and shared memory. + + Instances are created via the class methods :meth:`from_tiled` and + :meth:`from_im2col`, and can be passed directly to + :func:`~cuda.core.launch` as a kernel argument. + """ + + def __init__(self): + raise RuntimeError( + "TensorMapDescriptor cannot be instantiated directly. " + "Use TensorMapDescriptor.from_tiled() or " + "TensorMapDescriptor.from_im2col().") + + cdef void* _get_data_ptr(self): + return &self._tensor_map + + cdef int _check_context_compat(self) except -1: + cdef cydriver.CUcontext current_ctx + cdef cydriver.CUdevice current_dev + cdef int current_dev_id + if self._context == 0 and self._device_id < 0: + return 0 + with nogil: + HANDLE_RETURN(cydriver.cuCtxGetCurrent(¤t_ctx)) + if current_ctx == NULL: + raise RuntimeError("TensorMapDescriptor requires an active CUDA context") + if self._context != 0 and current_ctx != self._context: + raise RuntimeError( + "TensorMapDescriptor was created in a different CUDA context") + with nogil: + HANDLE_RETURN(cydriver.cuCtxGetDevice(¤t_dev)) + current_dev_id = current_dev + if self._device_id >= 0 and current_dev_id != self._device_id: + raise RuntimeError( + f"TensorMapDescriptor belongs to device {self._device_id}, " + f"but current device is {current_dev_id}") + return 0 + + @classmethod + def from_tiled(cls, tensor, box_dim, *, + element_strides=None, + data_type=None, + interleave=TensorMapInterleave.NONE, + swizzle=TensorMapSwizzle.NONE, + l2_promotion=TensorMapL2Promotion.NONE, + oob_fill=TensorMapOOBFill.NONE): + """Create a tiled TMA descriptor from a tensor object. + + Parameters + ---------- + tensor : object + Any object supporting DLPack or ``__cuda_array_interface__``, + or a :obj:`~cuda.core.StridedMemoryView`. Must refer to + device-accessible memory with a 16-byte-aligned pointer. + box_dim : tuple of int + The size of each tile dimension (in elements). Must have the + same rank as the tensor and each value must be in [1, 256]. + Specified in the same (row-major) order as the tensor shape. + element_strides : tuple of int, optional + Per-dimension element traversal strides. Default is all 1s. + Specified in the same (row-major) order as the tensor shape. + data_type : TensorMapDataType, optional + Explicit data type override. If ``None``, inferred from the + tensor's dtype. + interleave : TensorMapInterleave + Interleave layout. Default ``NONE``. + swizzle : TensorMapSwizzle + Swizzle mode. Default ``NONE``. + l2_promotion : TensorMapL2Promotion + L2 promotion mode. Default ``NONE``. + oob_fill : TensorMapOOBFill + Out-of-bounds fill mode. Default ``NONE``. + + Returns + ------- + TensorMapDescriptor + + Raises + ------ + ValueError + If the tensor rank is outside [1, 5], the pointer is not + 16-byte aligned, or dimension/stride constraints are violated. + """ + cdef TensorMapDescriptor desc = cls.__new__(cls) + + view = _get_validated_view(tensor) + # Keep both the original tensor object and the validated view alive. + # For DLPack exporters, the view may hold the owning capsule whose + # deleter can free the backing allocation when released. + desc._source_ref = tensor + desc._view_ref = view + desc._context = _get_current_context_ptr() + desc._device_id = _get_current_device_id() + + tma_dt = _resolve_data_type(view, data_type) + cdef int c_data_type_int = int(tma_dt) + cdef cydriver.CUtensorMapDataType c_data_type = c_data_type_int + + cdef intptr_t global_address = view.ptr + shape = view.shape + + cdef int rank = len(shape) + if rank < 1 or rank > 5: + raise ValueError( + f"Tensor rank must be between 1 and 5, got {rank}") + + if len(box_dim) != rank: + raise ValueError( + f"box_dim must have {rank} elements (same as tensor rank), " + f"got {len(box_dim)}") + + for i, bd in enumerate(box_dim): + if bd < 1 or bd > 256: + raise ValueError( + f"box_dim[{i}] must be in [1, 256], got {bd}") + + cdef bint elem_strides_provided = element_strides is not None + element_strides = _validate_element_strides(element_strides, rank) + + # Reuse CCCL/libcu++'s DLPack -> CUtensorMap conversion when possible. + # This avoids maintaining a second, independent validation/encoding implementation. + cdef uint8_t dl_code + cdef uint8_t dl_bits + cdef uint16_t dl_lanes + cdef int64_t c_shape[5] + cdef int64_t c_strides[5] + cdef int c_box_sizes[5] + cdef int c_elem_strides[5] + cdef const int64_t* c_strides_ptr + cdef const int* c_elem_strides_ptr + cdef char errbuf[512] + cdef int i_cccl + cdef int device_type + cdef int c_device_id + cdef int c_cccl_interleave_int + cdef int c_cccl_swizzle_int + cdef int c_cccl_l2_promotion_int + cdef int c_cccl_oob_fill_int + cdef int rc + if _tma_dtype_to_dlpack(tma_dt, &dl_code, &dl_bits, &dl_lanes): + c_strides_ptr = NULL + c_elem_strides_ptr = NULL + errbuf[0] = 0 + + for i_cccl in range(rank): + c_shape[i_cccl] = shape[i_cccl] + c_box_sizes[i_cccl] = box_dim[i_cccl] + if elem_strides_provided: + c_elem_strides[i_cccl] = element_strides[i_cccl] + + if view.strides is not None: + for i_cccl in range(rank): + c_strides[i_cccl] = view.strides[i_cccl] + c_strides_ptr = &c_strides[0] + + if elem_strides_provided: + c_elem_strides_ptr = &c_elem_strides[0] + + device_type = _kDLCUDA + c_device_id = view.device_id + c_cccl_interleave_int = int(interleave) + c_cccl_swizzle_int = int(swizzle) + c_cccl_l2_promotion_int = int(l2_promotion) + c_cccl_oob_fill_int = int(oob_fill) + + with nogil: + rc = cuda_core_cccl_make_tma_descriptor_tiled( + &desc._tensor_map, + global_address, + device_type, + c_device_id, + rank, + &c_shape[0], + c_strides_ptr, + dl_code, + dl_bits, + dl_lanes, + &c_box_sizes[0], + c_elem_strides_ptr, + c_cccl_interleave_int, + c_cccl_swizzle_int, + c_cccl_l2_promotion_int, + c_cccl_oob_fill_int, + &errbuf[0], + sizeof(errbuf), + ) + + if rc == 0: + desc._repr_info = { + "method": "tiled", + "rank": rank, + "data_type": tma_dt, + "swizzle": swizzle, + } + return desc + + msg = errbuf[:].split(b"\0", 1)[0].decode("utf-8", errors="replace") + # If CCCL isn't available at build time, fall back to the direct + # driver API path to preserve functionality on older toolchains. + if "not available at build time" not in msg: + raise ValueError(f"Failed to build TMA descriptor via CCCL: {msg}") + + cdef int elem_size = _TMA_DATA_TYPE_SIZE[tma_dt] + byte_strides = _compute_byte_strides(shape, view.strides, elem_size) + + # Reverse dimensions for column-major cuTensorMap convention + # Python/DLPack: row-major (dim 0 = outermost) + # cuTensorMap: column-major (dim 0 = innermost) + cdef uint64_t[5] c_global_dim + cdef uint64_t[4] c_global_strides # rank - 1 elements + cdef uint32_t[5] c_box_dim + cdef uint32_t[5] c_element_strides + cdef int i_c + + for i_c in range(rank): + # Reverse: Python dim i -> cuTensorMap dim (rank - 1 - i) + c_global_dim[i_c] = shape[rank - 1 - i_c] + c_box_dim[i_c] = box_dim[rank - 1 - i_c] + c_element_strides[i_c] = element_strides[rank - 1 - i_c] + + # globalStrides: rank-1 elements (byte strides for dims 1..N-1 in col-major order) + # The innermost stride (dim 0) is implicit = element size + for i_c in range(rank - 1): + c_global_strides[i_c] = byte_strides[rank - 2 - i_c] + + cdef uint32_t c_rank = rank + cdef int c_interleave_int = int(interleave) + cdef int c_swizzle_int = int(swizzle) + cdef int c_l2_promotion_int = int(l2_promotion) + cdef int c_oob_fill_int = int(oob_fill) + cdef cydriver.CUtensorMapInterleave c_interleave = c_interleave_int + cdef cydriver.CUtensorMapSwizzle c_swizzle = c_swizzle_int + cdef cydriver.CUtensorMapL2promotion c_l2_promotion = c_l2_promotion_int + cdef cydriver.CUtensorMapFloatOOBfill c_oob_fill = c_oob_fill_int + + with nogil: + HANDLE_RETURN(cydriver.cuTensorMapEncodeTiled( + &desc._tensor_map, + c_data_type, + c_rank, + global_address, + c_global_dim, + c_global_strides, + c_box_dim, + c_element_strides, + c_interleave, + c_swizzle, + c_l2_promotion, + c_oob_fill, + )) + + desc._repr_info = { + "method": "tiled", + "rank": rank, + "data_type": tma_dt, + "swizzle": swizzle, + } + + return desc + + @classmethod + def from_im2col(cls, tensor, pixel_box_lower_corner, pixel_box_upper_corner, + channels_per_pixel, pixels_per_column, *, + element_strides=None, + data_type=None, + interleave=TensorMapInterleave.NONE, + swizzle=TensorMapSwizzle.NONE, + l2_promotion=TensorMapL2Promotion.NONE, + oob_fill=TensorMapOOBFill.NONE): + """Create an im2col TMA descriptor from a tensor object. + + Im2col layout is used for convolution-style data access patterns. + + Parameters + ---------- + tensor : object + Any object supporting DLPack or ``__cuda_array_interface__``, + or a :obj:`~cuda.core.StridedMemoryView`. Must refer to + device-accessible memory with a 16-byte-aligned pointer. + pixel_box_lower_corner : tuple of int + Lower corner of the pixel bounding box for each spatial + dimension (rank - 2 elements). Specified in row-major order + matching the tensor's spatial dimensions. + pixel_box_upper_corner : tuple of int + Upper corner of the pixel bounding box for each spatial + dimension (rank - 2 elements). Specified in row-major order + matching the tensor's spatial dimensions. + channels_per_pixel : int + Number of channels per pixel. + pixels_per_column : int + Number of pixels per column. + element_strides : tuple of int, optional + Per-dimension element traversal strides. Default is all 1s. + data_type : TensorMapDataType, optional + Explicit data type override. If ``None``, inferred from the + tensor's dtype. + interleave : TensorMapInterleave + Interleave layout. Default ``NONE``. + swizzle : TensorMapSwizzle + Swizzle mode. Default ``NONE``. + l2_promotion : TensorMapL2Promotion + L2 promotion mode. Default ``NONE``. + oob_fill : TensorMapOOBFill + Out-of-bounds fill mode. Default ``NONE``. + + Returns + ------- + TensorMapDescriptor + + Raises + ------ + ValueError + If the tensor rank is outside [3, 5], the pointer is not + 16-byte aligned, or other constraints are violated. + """ + cdef TensorMapDescriptor desc = cls.__new__(cls) + + view = _get_validated_view(tensor) + desc._source_ref = tensor + desc._view_ref = view + desc._context = _get_current_context_ptr() + desc._device_id = _get_current_device_id() + + tma_dt = _resolve_data_type(view, data_type) + cdef int c_data_type_int = int(tma_dt) + cdef cydriver.CUtensorMapDataType c_data_type = c_data_type_int + + cdef intptr_t global_address = view.ptr + shape = view.shape + + cdef int rank = len(shape) + if rank < 3 or rank > 5: + raise ValueError( + f"Im2col tensor rank must be between 3 and 5, got {rank}") + + cdef int n_spatial = rank - 2 + if len(pixel_box_lower_corner) != n_spatial: + raise ValueError( + f"pixel_box_lower_corner must have {n_spatial} elements " + f"(rank - 2), got {len(pixel_box_lower_corner)}") + if len(pixel_box_upper_corner) != n_spatial: + raise ValueError( + f"pixel_box_upper_corner must have {n_spatial} elements " + f"(rank - 2), got {len(pixel_box_upper_corner)}") + + element_strides = _validate_element_strides(element_strides, rank) + + cdef int elem_size = _TMA_DATA_TYPE_SIZE[tma_dt] + byte_strides = _compute_byte_strides(shape, view.strides, elem_size) + + # Reverse all dimension arrays for column-major convention + cdef uint64_t[5] c_global_dim + cdef uint64_t[4] c_global_strides + cdef uint32_t[5] c_element_strides + cdef int[3] c_pixel_box_lower # max 3 spatial dims (rank 5 - 2) + cdef int[3] c_pixel_box_upper + cdef int i_c + + for i_c in range(3): + c_pixel_box_lower[i_c] = 0 + c_pixel_box_upper[i_c] = 0 + + for i_c in range(rank): + c_global_dim[i_c] = shape[rank - 1 - i_c] + c_element_strides[i_c] = element_strides[rank - 1 - i_c] + + for i_c in range(rank - 1): + c_global_strides[i_c] = byte_strides[rank - 2 - i_c] + + # Reverse spatial dimensions for lower/upper corners + for i_c in range(n_spatial): + c_pixel_box_lower[i_c] = pixel_box_lower_corner[n_spatial - 1 - i_c] + c_pixel_box_upper[i_c] = pixel_box_upper_corner[n_spatial - 1 - i_c] + + cdef uint32_t c_rank = rank + cdef uint32_t c_channels = channels_per_pixel + cdef uint32_t c_pixels = pixels_per_column + cdef int c_interleave_int = int(interleave) + cdef int c_swizzle_int = int(swizzle) + cdef int c_l2_promotion_int = int(l2_promotion) + cdef int c_oob_fill_int = int(oob_fill) + cdef cydriver.CUtensorMapInterleave c_interleave = c_interleave_int + cdef cydriver.CUtensorMapSwizzle c_swizzle = c_swizzle_int + cdef cydriver.CUtensorMapL2promotion c_l2_promotion = c_l2_promotion_int + cdef cydriver.CUtensorMapFloatOOBfill c_oob_fill = c_oob_fill_int + + with nogil: + HANDLE_RETURN(cydriver.cuTensorMapEncodeIm2col( + &desc._tensor_map, + c_data_type, + c_rank, + global_address, + c_global_dim, + c_global_strides, + c_pixel_box_lower, + c_pixel_box_upper, + c_channels, + c_pixels, + c_element_strides, + c_interleave, + c_swizzle, + c_l2_promotion, + c_oob_fill, + )) + + desc._repr_info = { + "method": "im2col", + "rank": rank, + "data_type": tma_dt, + "swizzle": swizzle, + } + + return desc + + @classmethod + def from_im2col_wide(cls, tensor, pixel_box_lower_corner_width, pixel_box_upper_corner_width, + channels_per_pixel, pixels_per_column, *, + element_strides=None, + data_type=None, + interleave=TensorMapInterleave.NONE, + mode=TensorMapIm2ColWideMode.W, + swizzle=TensorMapSwizzle.SWIZZLE_128B, + l2_promotion=TensorMapL2Promotion.NONE, + oob_fill=TensorMapOOBFill.NONE): + """Create an im2col-wide TMA descriptor from a tensor object. + + Im2col-wide layout loads elements exclusively along the W (width) + dimension. This variant is supported on compute capability 10.0+ + (Blackwell and later). + + Parameters + ---------- + tensor : object + Any object supporting DLPack or ``__cuda_array_interface__``, + or a :obj:`~cuda.core.StridedMemoryView`. Must refer to + device-accessible memory with a 16-byte-aligned pointer. + pixel_box_lower_corner_width : int + Lower corner of the pixel bounding box along the W dimension. + pixel_box_upper_corner_width : int + Upper corner of the pixel bounding box along the W dimension. + channels_per_pixel : int + Number of channels per pixel. + pixels_per_column : int + Number of pixels per column. + element_strides : tuple of int, optional + Per-dimension element traversal strides. Default is all 1s. + data_type : TensorMapDataType, optional + Explicit data type override. If ``None``, inferred from the + tensor's dtype. + interleave : TensorMapInterleave + Interleave layout. Default ``NONE``. + mode : TensorMapIm2ColWideMode + Im2col wide mode. Default ``W``. + swizzle : TensorMapSwizzle + Swizzle mode. Default ``SWIZZLE_128B``. + l2_promotion : TensorMapL2Promotion + L2 promotion mode. Default ``NONE``. + oob_fill : TensorMapOOBFill + Out-of-bounds fill mode. Default ``NONE``. + + Returns + ------- + TensorMapDescriptor + + Raises + ------ + ValueError + If the tensor rank is outside [3, 5], the pointer is not + 16-byte aligned, or other constraints are violated. + """ + IF CUDA_CORE_BUILD_MAJOR < 13: + raise RuntimeError( + "TensorMapDescriptor.from_im2col_wide requires a CUDA 13+ build") + ELSE: + cdef TensorMapDescriptor desc = cls.__new__(cls) + + view = _get_validated_view(tensor) + desc._source_ref = tensor + desc._view_ref = view + desc._context = _get_current_context_ptr() + desc._device_id = _get_current_device_id() + + tma_dt = _resolve_data_type(view, data_type) + cdef int c_data_type_int = int(tma_dt) + cdef cydriver.CUtensorMapDataType c_data_type = c_data_type_int + + cdef intptr_t global_address = view.ptr + shape = view.shape + + cdef int rank = len(shape) + if rank < 3 or rank > 5: + raise ValueError( + f"Im2col-wide tensor rank must be between 3 and 5, got {rank}") + + element_strides = _validate_element_strides(element_strides, rank) + + cdef int elem_size = _TMA_DATA_TYPE_SIZE[tma_dt] + byte_strides = _compute_byte_strides(shape, view.strides, elem_size) + + # Reverse all dimension arrays for column-major convention + cdef uint64_t[5] c_global_dim + cdef uint64_t[4] c_global_strides + cdef uint32_t[5] c_element_strides + cdef int i_c + + for i_c in range(rank): + c_global_dim[i_c] = shape[rank - 1 - i_c] + c_element_strides[i_c] = element_strides[rank - 1 - i_c] + + for i_c in range(rank - 1): + c_global_strides[i_c] = byte_strides[rank - 2 - i_c] + + cdef uint32_t c_rank = rank + cdef int c_lower_w = pixel_box_lower_corner_width + cdef int c_upper_w = pixel_box_upper_corner_width + cdef uint32_t c_channels = channels_per_pixel + cdef uint32_t c_pixels = pixels_per_column + cdef int c_interleave_int = int(interleave) + cdef int c_mode_int = int(mode) + cdef int c_swizzle_int = int(swizzle) + cdef int c_l2_promotion_int = int(l2_promotion) + cdef int c_oob_fill_int = int(oob_fill) + cdef cydriver.CUtensorMapInterleave c_interleave = c_interleave_int + cdef cydriver.CUtensorMapIm2ColWideMode c_mode = c_mode_int + cdef cydriver.CUtensorMapSwizzle c_swizzle = c_swizzle_int + cdef cydriver.CUtensorMapL2promotion c_l2_promotion = c_l2_promotion_int + cdef cydriver.CUtensorMapFloatOOBfill c_oob_fill = c_oob_fill_int + + with nogil: + HANDLE_RETURN(cydriver.cuTensorMapEncodeIm2colWide( + &desc._tensor_map, + c_data_type, + c_rank, + global_address, + c_global_dim, + c_global_strides, + c_lower_w, + c_upper_w, + c_channels, + c_pixels, + c_element_strides, + c_interleave, + c_mode, + c_swizzle, + c_l2_promotion, + c_oob_fill, + )) + + desc._repr_info = { + "method": "im2col_wide", + "rank": rank, + "data_type": tma_dt, + "swizzle": swizzle, + } + + return desc + + def replace_address(self, tensor): + """Replace the global memory address in this tensor map descriptor. + + This is useful when the tensor data has been reallocated but the + shape, strides, and other parameters remain the same. + + Parameters + ---------- + tensor : object + Any object supporting DLPack or ``__cuda_array_interface__``, + or a :obj:`~cuda.core.StridedMemoryView`. Must refer to + device-accessible memory with a 16-byte-aligned pointer. + """ + self._check_context_compat() + view = _get_validated_view(tensor) + if view.device_id != self._device_id: + raise ValueError( + f"replace_address expects tensor on device {self._device_id}, got {view.device_id}") + + cdef intptr_t global_address = view.ptr + + with nogil: + HANDLE_RETURN(cydriver.cuTensorMapReplaceAddress( + &self._tensor_map, + global_address, + )) + + # Update the source reference only after the driver call succeeds, + # so we don't drop the old tensor (risking a dangling pointer in the + # CUtensorMap struct) if the call fails. + self._source_ref = tensor + self._view_ref = view + + def __repr__(self): + info = self._repr_info + if info is None: + return "TensorMapDescriptor()" + parts = [] + if "method" in info: + parts.append(info["method"]) + if "rank" in info: + parts.append(f"rank={info['rank']}") + if "data_type" in info: + parts.append(f"dtype={info['data_type'].name}") + if "swizzle" in info: + parts.append(f"swizzle={info['swizzle'].name}") + return f"TensorMapDescriptor({', '.join(parts)})" diff --git a/cuda_core/examples/tma_replace_address.py b/cuda_core/examples/tma_replace_address.py new file mode 100644 index 0000000000..a734301fd6 --- /dev/null +++ b/cuda_core/examples/tma_replace_address.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +# ################################################################################ +# +# This example demonstrates how to use replace_address() to repoint a TMA +# (Tensor Memory Accelerator) descriptor at a different tensor without +# rebuilding the descriptor from scratch. +# +# The workflow is: +# +# 1. Create a TMA tiled descriptor and launch a kernel to verify it works +# 2. Allocate a second tensor with different content +# 3. Call replace_address() to repoint the same descriptor at the new tensor +# 4. Re-launch the kernel and verify it reads from the new tensor +# +# This is useful when the tensor layout (shape, dtype, tile size) stays the +# same but the underlying data buffer changes, e.g. double-buffering or +# iterating over a sequence of same-shaped tensors. +# +# Requirements: +# - Hopper or later GPU (compute capability >= 9.0) +# - CuPy +# - CUDA toolkit headers (CUDA_PATH or CUDA_HOME set) +# +# ################################################################################ + +import sys + +import cupy as cp +import numpy as np + +from cuda.core import ( + Device, + LaunchConfig, + Program, + ProgramOptions, + StridedMemoryView, + launch, +) + +# --------------------------------------------------------------------------- +# Check for Hopper+ GPU +# --------------------------------------------------------------------------- +dev = Device() +arch = dev.compute_capability +if arch < (9, 0): + print( + "TMA requires compute capability >= 9.0 (Hopper or later)", + file=sys.stderr, + ) + sys.exit(0) +dev.set_current() + +arch_str = "".join(f"{i}" for i in arch) + +# --------------------------------------------------------------------------- +# CUDA kernel that uses TMA to load a 1-D tile into shared memory, then +# copies the tile to an output buffer so we can verify correctness. +# +# The CUtensorMap struct (128 bytes) is defined inline so the kernel can be +# compiled with NVRTC without pulling in the full driver-API header. +# +# Key points: +# - The tensor map is passed by value with __grid_constant__ so the TMA +# hardware can read it from grid-constant memory. +# - Thread 0 in each block issues the TMA load and manages the mbarrier. +# - All threads wait on the mbarrier, then copy from shared to global. +# --------------------------------------------------------------------------- +TILE_SIZE = 128 # elements per tile (must match the kernel constant) + +code = r""" +// Minimal definition of the 128-byte opaque tensor map struct. +struct __align__(64) TensorMap { unsigned long long opaque[16]; }; + +static constexpr int TILE_SIZE = 128; + +extern "C" +__global__ void tma_copy( + const __grid_constant__ TensorMap tensor_map, + float* output, + int N) +{ + __shared__ __align__(128) float smem[TILE_SIZE]; + __shared__ __align__(8) unsigned long long mbar; + + const int tid = threadIdx.x; + const int tile_start = blockIdx.x * TILE_SIZE; + + // ---- Thread 0: set up mbarrier and issue the TMA load ---- + if (tid == 0) + { + // Initialise a single-phase mbarrier (1 arriving thread). + asm volatile( + "mbarrier.init.shared.b64 [%0], 1;" + :: "r"((unsigned)__cvta_generic_to_shared(&mbar))); + + // Ask TMA to copy TILE_SIZE floats starting at element 'tile_start' + // from the tensor described by 'tensor_map' into shared memory. + asm volatile( + "cp.async.bulk.tensor.1d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes" + " [%0], [%1, {%2}], [%3];" + :: "r"((unsigned)__cvta_generic_to_shared(smem)), + "l"(&tensor_map), + "r"(tile_start), + "r"((unsigned)__cvta_generic_to_shared(&mbar))); + + // Tell the mbarrier how many bytes the TMA will deliver. + asm volatile( + "mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + :: "r"((unsigned)__cvta_generic_to_shared(&mbar)), + "r"((unsigned)(TILE_SIZE * sizeof(float)))); + } + + __syncthreads(); + + // ---- Wait for the TMA load to complete ---- + if (tid == 0) + { + asm volatile( + "{ .reg .pred P; \n" + "WAIT: \n" + " mbarrier.try_wait.parity.shared.b64 P, [%0], 0; \n" + " @!P bra WAIT; \n" + "} \n" + :: "r"((unsigned)__cvta_generic_to_shared(&mbar))); + } + + __syncthreads(); + + // ---- Copy the tile from shared memory to the output buffer ---- + if (tid < TILE_SIZE) + { + const int idx = tile_start + tid; + if (idx < N) + output[idx] = smem[tid]; + } +} +""" + +# --------------------------------------------------------------------------- +# Compile the kernel +# --------------------------------------------------------------------------- +prog = Program( + code, + code_type="c++", + options=ProgramOptions(std="c++17", arch=f"sm_{arch_str}"), +) +mod = prog.compile("cubin") +ker = mod.get_kernel("tma_copy") + +# --------------------------------------------------------------------------- +# 1) Prepare input data and verify the initial TMA copy +# --------------------------------------------------------------------------- +N = 1024 +a = cp.arange(N, dtype=cp.float32) # [0, 1, 2, ..., N-1] +output = cp.zeros(N, dtype=cp.float32) +dev.sync() # cupy uses its own stream + +tensor_map = StridedMemoryView.from_any_interface(a, stream_ptr=-1).as_tensor_map(box_dim=(TILE_SIZE,)) + +n_tiles = N // TILE_SIZE +config = LaunchConfig(grid=n_tiles, block=TILE_SIZE) +launch(dev.default_stream, config, ker, tensor_map, output.data.ptr, np.int32(N)) +dev.sync() + +assert cp.array_equal(output, a), "TMA copy produced incorrect results" +print(f"TMA copy verified: {N} elements across {n_tiles} tiles") + +# --------------------------------------------------------------------------- +# 2) Demonstrate replace_address() +# Create a second tensor with different content, point the *same* +# descriptor at it, and re-launch without rebuilding the descriptor. +# --------------------------------------------------------------------------- +b = cp.full(N, fill_value=42.0, dtype=cp.float32) +dev.sync() + +tensor_map.replace_address(b) + +output2 = cp.zeros(N, dtype=cp.float32) +dev.sync() + +launch(dev.default_stream, config, ker, tensor_map, output2.data.ptr, np.int32(N)) +dev.sync() + +assert cp.array_equal(output2, b), "replace_address produced incorrect results" +print("replace_address verified: descriptor reused with new source tensor") diff --git a/cuda_core/examples/tma_tensor_map.py b/cuda_core/examples/tma_tensor_map.py new file mode 100644 index 0000000000..2a5ce9ad86 --- /dev/null +++ b/cuda_core/examples/tma_tensor_map.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +# ################################################################################ +# +# This example demonstrates how to use TMA (Tensor Memory Accelerator) descriptors +# with cuda.core on Hopper+ GPUs (compute capability >= 9.0). +# +# TMA enables efficient bulk data movement between global and shared memory using +# hardware-managed tensor map descriptors. This example shows: +# +# 1. Creating a TMA tiled descriptor from a CuPy device array +# 2. Passing the descriptor to a kernel via launch() +# 3. Using TMA to load tiles into shared memory (via inline PTX) +# +# Requirements: +# - Hopper or later GPU (compute capability >= 9.0) +# - CuPy +# - CUDA toolkit headers (CUDA_PATH or CUDA_HOME set) +# +# ################################################################################ + +import sys + +import cupy as cp +import numpy as np + +from cuda.core import ( + Device, + LaunchConfig, + Program, + ProgramOptions, + StridedMemoryView, + launch, +) + +# --------------------------------------------------------------------------- +# Check for Hopper+ GPU +# --------------------------------------------------------------------------- +dev = Device() +arch = dev.compute_capability +if arch < (9, 0): + print( + "TMA requires compute capability >= 9.0 (Hopper or later)", + file=sys.stderr, + ) + sys.exit(0) +dev.set_current() + +# --------------------------------------------------------------------------- +# CUDA kernel that uses TMA to load a 1-D tile into shared memory, then +# copies the tile to an output buffer so we can verify correctness. +# +# The CUtensorMap struct (128 bytes) is defined inline so the kernel can be +# compiled with NVRTC without pulling in the full driver-API header. +# +# Key points: +# - The tensor map is passed by value with __grid_constant__ so the TMA +# hardware can read it from grid-constant memory. +# - Thread 0 in each block issues the TMA load and manages the mbarrier. +# - All threads wait on the mbarrier, then copy from shared to global. +# --------------------------------------------------------------------------- +TILE_SIZE = 128 # elements per tile (must match the kernel constant) + +code = r""" +// Minimal definition of the 128-byte opaque tensor map struct. +struct __align__(64) TensorMap { unsigned long long opaque[16]; }; + +static constexpr int TILE_SIZE = 128; + +extern "C" +__global__ void tma_copy( + const __grid_constant__ TensorMap tensor_map, + float* output, + int N) +{ + __shared__ __align__(128) float smem[TILE_SIZE]; + __shared__ __align__(8) unsigned long long mbar; + + const int tid = threadIdx.x; + const int tile_start = blockIdx.x * TILE_SIZE; + + // ---- Thread 0: set up mbarrier and issue the TMA load ---- + if (tid == 0) + { + // Initialise a single-phase mbarrier (1 arriving thread). + asm volatile( + "mbarrier.init.shared.b64 [%0], 1;" + :: "r"((unsigned)__cvta_generic_to_shared(&mbar))); + + // Ask TMA to copy TILE_SIZE floats starting at element 'tile_start' + // from the tensor described by 'tensor_map' into shared memory. + asm volatile( + "cp.async.bulk.tensor.1d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes" + " [%0], [%1, {%2}], [%3];" + :: "r"((unsigned)__cvta_generic_to_shared(smem)), + "l"(&tensor_map), + "r"(tile_start), + "r"((unsigned)__cvta_generic_to_shared(&mbar))); + + // Tell the mbarrier how many bytes the TMA will deliver. + asm volatile( + "mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + :: "r"((unsigned)__cvta_generic_to_shared(&mbar)), + "r"((unsigned)(TILE_SIZE * sizeof(float)))); + } + + __syncthreads(); + + // ---- Wait for the TMA load to complete ---- + if (tid == 0) + { + asm volatile( + "{ .reg .pred P; \n" + "WAIT: \n" + " mbarrier.try_wait.parity.shared.b64 P, [%0], 0; \n" + " @!P bra WAIT; \n" + "} \n" + :: "r"((unsigned)__cvta_generic_to_shared(&mbar))); + } + + __syncthreads(); + + // ---- Copy the tile from shared memory to the output buffer ---- + if (tid < TILE_SIZE) + { + const int idx = tile_start + tid; + if (idx < N) + output[idx] = smem[tid]; + } +} +""" + +# --------------------------------------------------------------------------- +# Compile the kernel +# --------------------------------------------------------------------------- +prog = Program( + code, + code_type="c++", + options=ProgramOptions(std="c++17", arch=f"sm_{dev.arch}"), +) +mod = prog.compile("cubin") +ker = mod.get_kernel("tma_copy") + +# --------------------------------------------------------------------------- +# 1) Prepare input data on the device +# --------------------------------------------------------------------------- +N = 1024 +a = cp.arange(N, dtype=cp.float32) # [0, 1, 2, ..., N-1] +output = cp.zeros(N, dtype=cp.float32) +dev.sync() # cupy uses its own stream + +# --------------------------------------------------------------------------- +# 2) Create a TMA tiled descriptor from a StridedMemoryView. +# The dtype (float32) is inferred automatically from the CuPy array. +# --------------------------------------------------------------------------- +tensor_map = StridedMemoryView.from_any_interface(a, stream_ptr=-1).as_tensor_map(box_dim=(TILE_SIZE,)) + +# --------------------------------------------------------------------------- +# 3) Launch the kernel +# The TensorMapDescriptor is passed directly as a kernel argument — the +# 128-byte struct is copied into kernel parameter space automatically. +# --------------------------------------------------------------------------- +n_tiles = N // TILE_SIZE +config = LaunchConfig(grid=n_tiles, block=TILE_SIZE) +launch(dev.default_stream, config, ker, tensor_map, output.data.ptr, np.int32(N)) +dev.sync() + +assert cp.array_equal(output, a), "TMA copy produced incorrect results" +print(f"TMA copy verified: {N} elements across {n_tiles} tiles") diff --git a/cuda_core/tests/test_tensor_map.py b/cuda_core/tests/test_tensor_map.py new file mode 100644 index 0000000000..dee1f1e2e1 --- /dev/null +++ b/cuda_core/tests/test_tensor_map.py @@ -0,0 +1,490 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest + +from cuda.core import ( + Device, + StridedMemoryView, + TensorMapDescriptor, +) +from cuda.core._tensor_map import ( + TensorMapDataType, + TensorMapIm2ColWideMode, + TensorMapInterleave, + TensorMapL2Promotion, + TensorMapOOBFill, + TensorMapSwizzle, +) + + +@pytest.fixture +def dev(init_cuda): + return Device() + + +@pytest.fixture +def skip_if_no_tma(dev): + if not dev.properties.tensor_map_access_supported: + pytest.skip("Device does not support TMA (requires compute capability 9.0+)") + + +class _DeviceArray: + """Wrap a Buffer with explicit shape via __cuda_array_interface__. + + dev.allocate() returns a 1D byte buffer. For multi-dimensional TMA tests + we need the tensor to report a proper shape/dtype so the TMA encoder sees + the correct rank, dimensions, and strides. + """ + + def __init__(self, buf, shape, dtype=np.float32): + self._buf = buf # prevent GC + self.__cuda_array_interface__ = { + "shape": tuple(shape), + "typestr": np.dtype(dtype).str, + "data": (int(buf.handle), False), + "version": 3, + } + + +class TestTensorMapEnums: + """Test that enum wrappers expose the expected values.""" + + def test_data_type_values(self): + assert TensorMapDataType.UINT8 == 0 + assert TensorMapDataType.FLOAT32 == 7 + assert TensorMapDataType.FLOAT64 == 8 + assert TensorMapDataType.BFLOAT16 == 9 + + def test_interleave_values(self): + assert TensorMapInterleave.NONE == 0 + assert TensorMapInterleave.INTERLEAVE_16B == 1 + assert TensorMapInterleave.INTERLEAVE_32B == 2 + + def test_swizzle_values(self): + assert TensorMapSwizzle.NONE == 0 + assert TensorMapSwizzle.SWIZZLE_32B == 1 + assert TensorMapSwizzle.SWIZZLE_64B == 2 + assert TensorMapSwizzle.SWIZZLE_128B == 3 + + def test_l2_promotion_values(self): + assert TensorMapL2Promotion.NONE == 0 + assert TensorMapL2Promotion.L2_64B == 1 + assert TensorMapL2Promotion.L2_128B == 2 + assert TensorMapL2Promotion.L2_256B == 3 + + def test_oob_fill_values(self): + assert TensorMapOOBFill.NONE == 0 + assert TensorMapOOBFill.NAN_REQUEST_ZERO_FMA == 1 + + def test_im2col_wide_mode_values(self): + assert TensorMapIm2ColWideMode.W == 0 + assert TensorMapIm2ColWideMode.W128 == 1 + + +class TestTensorMapDescriptorCreation: + """Test TensorMapDescriptor factory methods.""" + + def test_cannot_instantiate_directly(self): + with pytest.raises(RuntimeError, match="cannot be instantiated directly"): + TensorMapDescriptor() + + def test_from_tiled_1d(self, dev, skip_if_no_tma): + buf = dev.allocate(1024 * 4) # 1024 float32 elements + desc = TensorMapDescriptor.from_tiled( + buf, + box_dim=(64,), + data_type=TensorMapDataType.FLOAT32, + ) + assert desc is not None + assert repr(desc) == "TensorMapDescriptor(tiled, rank=1, dtype=FLOAT32, swizzle=NONE)" + + def test_from_tiled_2d(self, dev, skip_if_no_tma): + buf = dev.allocate(64 * 64 * 4) # 64x64 float32 + tensor = _DeviceArray(buf, (64, 64)) + desc = TensorMapDescriptor.from_tiled( + tensor, + box_dim=(32, 32), + data_type=TensorMapDataType.FLOAT32, + ) + assert desc is not None + + def test_strided_memory_view_as_tensor_map(self, dev, skip_if_no_tma): + buf = dev.allocate(64 * 64 * 4) + tensor = _DeviceArray(buf, (64, 64)) + view = StridedMemoryView.from_any_interface(tensor, stream_ptr=-1) + desc = view.as_tensor_map( + box_dim=(32, 32), + data_type=TensorMapDataType.FLOAT32, + ) + assert desc is not None + + def test_from_tiled_3d(self, dev, skip_if_no_tma): + buf = dev.allocate(16 * 16 * 16 * 4) # 16x16x16 float32 + tensor = _DeviceArray(buf, (16, 16, 16)) + desc = TensorMapDescriptor.from_tiled( + tensor, + box_dim=(8, 8, 8), + data_type=TensorMapDataType.FLOAT32, + ) + assert desc is not None + + def test_from_tiled_5d(self, dev, skip_if_no_tma): + # 5D: exercises all 5 c_global_dim / 4 c_global_strides slots + shape = (2, 4, 4, 4, 8) + n_bytes = 2 * 4 * 4 * 4 * 8 * 4 # float32 + buf = dev.allocate(n_bytes) + tensor = _DeviceArray(buf, shape) + desc = TensorMapDescriptor.from_tiled( + tensor, + box_dim=(1, 2, 2, 2, 8), + ) + assert desc is not None + + def test_from_tiled_with_element_strides_buffer(self, dev, skip_if_no_tma): + # Use a Buffer input (DLPack path) and explicit element_strides. + buf = dev.allocate(1024 * 4) + desc = TensorMapDescriptor.from_tiled( + buf, + box_dim=(64,), + element_strides=(2,), + data_type=TensorMapDataType.FLOAT32, + ) + assert desc is not None + + def test_from_tiled_with_element_strides_cai(self, dev, skip_if_no_tma): + # Use a CAI-style tensor wrapper and explicit element_strides. + buf = dev.allocate(64 * 64 * 4) + tensor = _DeviceArray(buf, (64, 64)) + desc = TensorMapDescriptor.from_tiled( + tensor, + box_dim=(32, 32), + element_strides=(2, 1), + data_type=TensorMapDataType.FLOAT32, + ) + assert desc is not None + + def test_from_tiled_with_swizzle(self, dev, skip_if_no_tma): + buf = dev.allocate(64 * 64 * 4) + tensor = _DeviceArray(buf, (64, 64)) + desc = TensorMapDescriptor.from_tiled( + tensor, + box_dim=(32, 32), + data_type=TensorMapDataType.FLOAT32, + swizzle=TensorMapSwizzle.SWIZZLE_128B, + ) + assert desc is not None + + def test_from_tiled_with_l2_promotion(self, dev, skip_if_no_tma): + buf = dev.allocate(64 * 64 * 4) + tensor = _DeviceArray(buf, (64, 64)) + desc = TensorMapDescriptor.from_tiled( + tensor, + box_dim=(32, 32), + data_type=TensorMapDataType.FLOAT32, + l2_promotion=TensorMapL2Promotion.L2_128B, + ) + assert desc is not None + + def test_from_tiled_with_oob_fill(self, dev, skip_if_no_tma): + buf = dev.allocate(64 * 64 * 4) + tensor = _DeviceArray(buf, (64, 64)) + desc = TensorMapDescriptor.from_tiled( + tensor, + box_dim=(32, 32), + data_type=TensorMapDataType.FLOAT32, + oob_fill=TensorMapOOBFill.NAN_REQUEST_ZERO_FMA, + ) + assert desc is not None + + +class TestTensorMapDescriptorValidation: + """Test validation in TensorMapDescriptor factory methods.""" + + def test_invalid_rank_zero(self, dev, skip_if_no_tma): + buf = dev.allocate(64) + tensor = _DeviceArray(buf, ()) # 0-dim tensor + with pytest.raises(ValueError, match="rank must be between 1 and 5"): + TensorMapDescriptor.from_tiled( + tensor, + box_dim=(), + data_type=TensorMapDataType.FLOAT32, + ) + + def test_invalid_rank_six(self, dev, skip_if_no_tma): + shape = (2, 2, 2, 2, 2, 2) + n_elements = 1 + for s in shape: + n_elements *= s + buf = dev.allocate(n_elements * 4) + arr = _DeviceArray(buf, shape) + with pytest.raises(ValueError, match="rank must be between 1 and 5"): + TensorMapDescriptor.from_tiled( + arr, + box_dim=(2,) * 6, + ) + + def test_box_dim_rank_mismatch(self, dev, skip_if_no_tma): + buf = dev.allocate(1024 * 4) + with pytest.raises(ValueError, match="box_dim must have 1 elements"): + TensorMapDescriptor.from_tiled( + buf, + box_dim=(32, 32), + data_type=TensorMapDataType.FLOAT32, + ) + + def test_box_dim_out_of_range(self, dev, skip_if_no_tma): + buf = dev.allocate(1024 * 4) + with pytest.raises(ValueError, match=r"box_dim\[0\] must be in \[1, 256\]"): + TensorMapDescriptor.from_tiled( + buf, + box_dim=(512,), + data_type=TensorMapDataType.FLOAT32, + ) + + def test_element_strides_rank_mismatch(self, dev, skip_if_no_tma): + buf = dev.allocate(1024 * 4) + with pytest.raises(ValueError, match="element_strides must have 1 elements"): + TensorMapDescriptor.from_tiled( + buf, + box_dim=(64,), + element_strides=(1, 1), + data_type=TensorMapDataType.FLOAT32, + ) + + def test_invalid_data_type(self, dev, skip_if_no_tma): + buf = dev.allocate(1024 * 4) + with pytest.raises(TypeError, match="data_type must be a TensorMapDataType"): + TensorMapDescriptor.from_tiled( + buf, + box_dim=(64,), + data_type=42, + ) + + +class TestTensorMapDtypeMapping: + """Test automatic dtype inference from numpy dtypes.""" + + @pytest.mark.parametrize( + "np_dtype,expected_tma_dt", + [ + (np.uint8, TensorMapDataType.UINT8), + (np.uint16, TensorMapDataType.UINT16), + (np.uint32, TensorMapDataType.UINT32), + (np.int32, TensorMapDataType.INT32), + (np.uint64, TensorMapDataType.UINT64), + (np.int64, TensorMapDataType.INT64), + (np.float16, TensorMapDataType.FLOAT16), + (np.float32, TensorMapDataType.FLOAT32), + (np.float64, TensorMapDataType.FLOAT64), + ], + ) + def test_dtype_mapping(self, np_dtype, expected_tma_dt, dev, skip_if_no_tma): + from cuda.core._tensor_map import _NUMPY_DTYPE_TO_TMA + + assert _NUMPY_DTYPE_TO_TMA[np.dtype(np_dtype)] == expected_tma_dt + + def test_bfloat16_mapping(self): + try: + from ml_dtypes import bfloat16 + + from cuda.core._tensor_map import _NUMPY_DTYPE_TO_TMA + + assert _NUMPY_DTYPE_TO_TMA[np.dtype(bfloat16)] == TensorMapDataType.BFLOAT16 + except ImportError: + pytest.skip("ml_dtypes not installed") + + +class TestTensorMapReplaceAddress: + """Test replace_address functionality.""" + + def test_replace_address(self, dev, skip_if_no_tma): + buf1 = dev.allocate(1024 * 4) + desc = TensorMapDescriptor.from_tiled( + buf1, + box_dim=(64,), + data_type=TensorMapDataType.FLOAT32, + ) + + buf2 = dev.allocate(1024 * 4) + desc.replace_address(buf2) + # No exception means success + + def test_replace_address_requires_device_accessible(self, dev, skip_if_no_tma): + buf1 = dev.allocate(1024 * 4) + desc = TensorMapDescriptor.from_tiled( + buf1, + box_dim=(64,), + data_type=TensorMapDataType.FLOAT32, + ) + # Create a host-only array (not device-accessible) + host_arr = np.zeros(1024, dtype=np.float32) + with pytest.raises(ValueError, match="device-accessible"): + desc.replace_address(host_arr) + + +class TestTensorMapIm2col: + """Test im2col TMA descriptor creation.""" + + def test_from_im2col_3d(self, dev, skip_if_no_tma): + # 3D tensor: batch=1, height=32, channels=64 + buf = dev.allocate(1 * 32 * 64 * 4) + tensor = _DeviceArray(buf, (1, 32, 64)) + desc = TensorMapDescriptor.from_im2col( + tensor, + pixel_box_lower_corner=(0,), + pixel_box_upper_corner=(4,), + channels_per_pixel=64, + pixels_per_column=4, + data_type=TensorMapDataType.FLOAT32, + ) + assert desc is not None + + def test_from_im2col_rank_validation(self, dev, skip_if_no_tma): + buf = dev.allocate(1024 * 4) + with pytest.raises(ValueError, match="Im2col tensor rank must be between 3 and 5"): + TensorMapDescriptor.from_im2col( + buf, + pixel_box_lower_corner=(), + pixel_box_upper_corner=(), + channels_per_pixel=64, + pixels_per_column=4, + data_type=TensorMapDataType.FLOAT32, + ) + + def test_from_im2col_corner_rank_mismatch(self, dev, skip_if_no_tma): + buf = dev.allocate(1 * 32 * 64 * 4) + tensor = _DeviceArray(buf, (1, 32, 64)) # 3D: n_spatial = 1 + with pytest.raises(ValueError, match="pixel_box_lower_corner must have 1 elements"): + TensorMapDescriptor.from_im2col( + tensor, + pixel_box_lower_corner=(0, 0), + pixel_box_upper_corner=(4,), + channels_per_pixel=64, + pixels_per_column=4, + data_type=TensorMapDataType.FLOAT32, + ) + + def test_from_im2col_4d(self, dev, skip_if_no_tma): + # NHWC layout: N=1, H=8, W=8, C=64 — 2 spatial dims + # Exercises spatial corner reversal with n_spatial=2: + # Python [H_lower, W_lower] -> driver [W_lower, H_lower] + shape = (1, 8, 8, 64) + buf = dev.allocate(1 * 8 * 8 * 64 * 4) + tensor = _DeviceArray(buf, shape) + desc = TensorMapDescriptor.from_im2col( + tensor, + pixel_box_lower_corner=(0, 0), + pixel_box_upper_corner=(4, 4), + channels_per_pixel=64, + pixels_per_column=16, + ) + assert desc is not None + + def test_from_im2col_5d(self, dev, skip_if_no_tma): + # NDHWC layout: N=1, D=4, H=8, W=8, C=64 — 3 spatial dims + # Exercises the full spatial corner reversal: + # Python [D, H, W] -> driver [W, H, D] + shape = (1, 4, 8, 8, 64) + buf = dev.allocate(1 * 4 * 8 * 8 * 64 * 4) + tensor = _DeviceArray(buf, shape) + desc = TensorMapDescriptor.from_im2col( + tensor, + pixel_box_lower_corner=(0, 0, 0), + pixel_box_upper_corner=(2, 4, 4), + channels_per_pixel=64, + pixels_per_column=32, + ) + assert desc is not None + + +class TestTensorMapIm2colWide: + """Test im2col-wide TMA descriptor creation (compute capability 10.0+).""" + + @pytest.fixture + def skip_if_no_im2col_wide(self, dev): + cc = dev.compute_capability + if cc.major < 10: + pytest.skip("Device does not support im2col-wide (requires compute capability 10.0+)") + + # Some environments in CI exercise this test module with a cuda.core + # build that does not include im2col-wide symbols (CUDA < 13 build), + # or with driver/GPU combinations that reject im2col-wide descriptor + # encoding for otherwise valid inputs. Probe once per test invocation + # and skip only for those known unsupported cases. + buf = dev.allocate(1 * 32 * 64 * 4) + tensor = _DeviceArray(buf, (1, 32, 64)) + try: + TensorMapDescriptor.from_im2col_wide( + tensor, + pixel_box_lower_corner_width=0, + pixel_box_upper_corner_width=4, + channels_per_pixel=64, + pixels_per_column=4, + data_type=TensorMapDataType.FLOAT32, + ) + except RuntimeError as e: + if "requires a CUDA 13+ build" in str(e): + pytest.skip("Im2col-wide requires cuda.core built with CUDA 13+") + raise + except Exception as e: + if "CUDA_ERROR_INVALID_VALUE" in str(e): + pytest.skip("Im2col-wide unsupported on this driver/GPU combination") + raise + + def test_from_im2col_wide_3d(self, dev, skip_if_no_im2col_wide): + # 3D tensor: batch=1, width=32, channels=64 + buf = dev.allocate(1 * 32 * 64 * 4) + tensor = _DeviceArray(buf, (1, 32, 64)) + desc = TensorMapDescriptor.from_im2col_wide( + tensor, + pixel_box_lower_corner_width=0, + pixel_box_upper_corner_width=4, + channels_per_pixel=64, + pixels_per_column=4, + data_type=TensorMapDataType.FLOAT32, + ) + assert desc is not None + + def test_from_im2col_wide_4d(self, dev, skip_if_no_im2col_wide): + # NHWC layout: N=1, H=8, W=8, C=64 + # Wide mode only uses scalar W corners, even with higher rank + shape = (1, 8, 8, 64) + buf = dev.allocate(1 * 8 * 8 * 64 * 4) + tensor = _DeviceArray(buf, shape) + desc = TensorMapDescriptor.from_im2col_wide( + tensor, + pixel_box_lower_corner_width=0, + pixel_box_upper_corner_width=4, + channels_per_pixel=64, + pixels_per_column=16, + ) + assert desc is not None + + def test_from_im2col_wide_5d(self, dev, skip_if_no_im2col_wide): + # NDHWC layout: N=1, D=4, H=8, W=8, C=64 + # Max rank boundary — verifies all 5 dim/stride slots are filled + shape = (1, 4, 8, 8, 64) + buf = dev.allocate(1 * 4 * 8 * 8 * 64 * 4) + tensor = _DeviceArray(buf, shape) + desc = TensorMapDescriptor.from_im2col_wide( + tensor, + pixel_box_lower_corner_width=0, + pixel_box_upper_corner_width=4, + channels_per_pixel=64, + pixels_per_column=32, + ) + assert desc is not None + + def test_from_im2col_wide_rank_validation(self, dev, skip_if_no_im2col_wide): + buf = dev.allocate(1024 * 4) + with pytest.raises(ValueError, match="Im2col-wide tensor rank must be between 3 and 5"): + TensorMapDescriptor.from_im2col_wide( + buf, + pixel_box_lower_corner_width=0, + pixel_box_upper_corner_width=4, + channels_per_pixel=64, + pixels_per_column=4, + data_type=TensorMapDataType.FLOAT32, + )