-
Notifications
You must be signed in to change notification settings - Fork 253
Add TMA TensorMapDescriptor support #1687
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e3e1899
77a3c8e
19c4a0f
35a04b9
bb19e4f
23a8900
0a1b720
44fbdcf
bdf39a2
96a3e84
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -68,3 +68,12 @@ | |
| Stream, | ||
| StreamOptions, | ||
| ) | ||
| from cuda.core._tensor_map import ( | ||
| TensorMapDataType, | ||
| TensorMapDescriptor, | ||
| TensorMapIm2ColWideMode, | ||
| TensorMapInterleave, | ||
| TensorMapL2Promotion, | ||
| TensorMapOOBFill, | ||
| TensorMapSwizzle, | ||
|
Comment on lines
+72
to
+78
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Critical:
|
||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,149 @@ | ||
| // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| // | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| #include "tensor_map_cccl.h" | ||
|
|
||
| #include <string.h> | ||
|
|
||
| #include <algorithm> | ||
| #include <exception> | ||
|
|
||
// ---------------------------------------------------------------------------
// Build-time feature detection: the CCCL TMA builder needs both <cuda/tma>
// and DLPack's header. When either is missing (or the compiler lacks
// __has_include), the entry point below degrades to a stub that reports an
// error instead of failing the build.
// ---------------------------------------------------------------------------
#if defined(__has_include)
#  if __has_include(<cuda/tma>)
#    include <cuda/tma>
#    define CUDA_CORE_HAS_CUDA_TMA 1
#  else
#    define CUDA_CORE_HAS_CUDA_TMA 0
#  endif
#  if __has_include(<dlpack/dlpack.h>)
#    include <dlpack/dlpack.h>
#    define CUDA_CORE_HAS_DLPACK_H 1
#  else
#    define CUDA_CORE_HAS_DLPACK_H 0
#  endif
#else
// No __has_include support: we cannot probe for the headers, so assume
// neither is available.
#  define CUDA_CORE_HAS_CUDA_TMA 0
#  define CUDA_CORE_HAS_DLPACK_H 0
#endif

// Copy `msg` into the caller-provided buffer (err, cap), truncating to fit
// and always NUL-terminating when cap > 0. A null `msg` clears the buffer;
// a null/zero-capacity buffer is a no-op. Never throws, so it is safe to
// call from catch handlers.
static inline void cuda_core_write_err(char* err, size_t cap, const char* msg) noexcept
{
    if (err == nullptr || cap == 0)
        return;
    if (msg == nullptr)
    {
        err[0] = '\0';
        return;
    }
    size_t copy_len = ::strlen(msg);
    if (copy_len > cap - 1)
        copy_len = cap - 1;  // leave room for the terminator
    ::memcpy(err, msg, copy_len);
    err[copy_len] = '\0';
}

// Build a tiled CUtensorMap with CCCL's cuda::make_tma_descriptor.
//
// The tensor is described in DLPack terms (device, dtype, shape, strides);
// `out_tensor_map` receives the encoded descriptor on success (the struct
// copied is sizeof(CUtensorMap) bytes). Returns 0 on success; on any
// failure returns 1 and writes a best-effort message into (err, err_cap).
int cuda_core_cccl_make_tma_descriptor_tiled(
    void* out_tensor_map,
    void* data,
    int device_type,
    int device_id,
    int ndim,
    const int64_t* shape,
    const int64_t* strides,
    uint8_t dtype_code,
    uint8_t dtype_bits,
    uint16_t dtype_lanes,
    const int* box_sizes,
    const int* elem_strides,
    int interleave_layout,
    int swizzle,
    int l2_fetch_size,
    int oob_fill,
    char* err,
    size_t err_cap) noexcept
{
#if !(CUDA_CORE_HAS_CUDA_TMA && CUDA_CORE_HAS_DLPACK_H)
    // Stub path: headers were absent at build time. Consume every parameter
    // to stay warning-clean and report the condition to the caller.
    (void)out_tensor_map;
    (void)data;
    (void)device_type;
    (void)device_id;
    (void)ndim;
    (void)shape;
    (void)strides;
    (void)dtype_code;
    (void)dtype_bits;
    (void)dtype_lanes;
    (void)box_sizes;
    (void)elem_strides;
    (void)interleave_layout;
    (void)swizzle;
    (void)l2_fetch_size;
    (void)oob_fill;
    cuda_core_write_err(err, err_cap, "CCCL <cuda/tma> and/or <dlpack/dlpack.h> not available at build time");
    return 1;
#else
    try
    {
        // Validate everything we dereference before handing it to CCCL.
        if (out_tensor_map == nullptr)
        {
            cuda_core_write_err(err, err_cap, "out_tensor_map is NULL");
            return 1;
        }
        if (data == nullptr)
        {
            cuda_core_write_err(err, err_cap, "tensor data pointer is NULL");
            return 1;
        }
        if (shape == nullptr || box_sizes == nullptr || ndim <= 0)
        {
            cuda_core_write_err(err, err_cap, "invalid rank/shape/box_sizes");
            return 1;
        }

        // Describe the source tensor in DLPack form.
        DLTensor dl{};
        dl.data = data;
        dl.device = {static_cast<DLDeviceType>(device_type), device_id};
        dl.ndim = ndim;
        dl.dtype.code = dtype_code;
        dl.dtype.bits = dtype_bits;
        dl.dtype.lanes = dtype_lanes;
        // DLPack's struct uses non-const pointers even though CCCL promises
        // not to mutate the arrays.
        dl.shape = const_cast<int64_t*>(shape);
        dl.strides = const_cast<int64_t*>(strides);  // NULL => contiguous
        dl.byte_offset = 0;

        // Map the plain-int ABI parameters onto CCCL's enum types.
        const auto il = static_cast<cuda::tma_interleave_layout>(interleave_layout);
        const auto sw = static_cast<cuda::tma_swizzle>(swizzle);
        const auto l2 = static_cast<cuda::tma_l2_fetch_size>(l2_fetch_size);
        const auto fill = static_cast<cuda::tma_oob_fill>(oob_fill);
        const auto box = cuda::std::span<const int>(box_sizes, static_cast<size_t>(ndim));

        CUtensorMap tmap{};
        if (elem_strides != nullptr)
        {
            const auto es = cuda::std::span<const int>(elem_strides, static_cast<size_t>(ndim));
            tmap = cuda::make_tma_descriptor(dl, box, es, il, sw, l2, fill);
        }
        else
        {
            // No element strides supplied: use the all-ones overload.
            tmap = cuda::make_tma_descriptor(dl, box, il, sw, l2, fill);
        }

        // Publish the opaque descriptor and clear any stale error text.
        ::memcpy(out_tensor_map, &tmap, sizeof(CUtensorMap));
        cuda_core_write_err(err, err_cap, nullptr);
        return 0;
    }
    catch (const std::exception& e)
    {
        cuda_core_write_err(err, err_cap, e.what());
        return 1;
    }
    catch (...)
    {
        cuda_core_write_err(err, err_cap, "unknown error while building TMA descriptor");
        return 1;
    }
#endif
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// SPDX-License-Identifier: Apache-2.0

#ifndef CUDA_CORE_TENSOR_MAP_CCCL_H_
#define CUDA_CORE_TENSOR_MAP_CCCL_H_

#include <stddef.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

// Build a tiled CUtensorMap using CCCL's cuda::make_tma_descriptor (from <cuda/tma>).
//
// out_tensor_map must point to caller-owned storage large enough to hold a
// CUtensorMap; it is written only on success. The tensor is described in
// DLPack terms: device_type/device_id form a DLDevice, and
// dtype_code/dtype_bits/dtype_lanes form a DLDataType. The enum-valued
// parameters (interleave_layout, swizzle, l2_fetch_size, oob_fill) carry the
// integer values of the corresponding cuda::tma_* enumerations.
//
// Returns 0 on success; on failure returns non-zero and writes a best-effort
// human-readable message into (err, err_cap) if provided.
int cuda_core_cccl_make_tma_descriptor_tiled(
    void* out_tensor_map,     // receives the encoded CUtensorMap on success
    void* data,               // pointer to the tensor's first element
    int device_type,          // DLDeviceType value
    int device_id,
    int ndim,                 // tensor rank; must be > 0
    const int64_t* shape,     // length ndim
    const int64_t* strides,   // length ndim, or NULL for contiguous
    uint8_t dtype_code,       // DLDataType code
    uint8_t dtype_bits,
    uint16_t dtype_lanes,
    const int* box_sizes,     // length ndim
    const int* elem_strides,  // length ndim, or NULL for all-ones overload
    int interleave_layout,    // cuda::tma_interleave_layout value
    int swizzle,              // cuda::tma_swizzle value
    int l2_fetch_size,        // cuda::tma_l2_fetch_size value
    int oob_fill,             // cuda::tma_oob_fill value
    char* err,                // optional error buffer; may be NULL
    size_t err_cap) noexcept; // capacity of err in bytes

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // CUDA_CORE_TENSOR_MAP_CCCL_H_
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ from cpython.mem cimport PyMem_Malloc, PyMem_Free | |
| from libc.stdint cimport (intptr_t, | ||
| int8_t, int16_t, int32_t, int64_t, | ||
| uint8_t, uint16_t, uint32_t, uint64_t,) | ||
| from libc.string cimport memcpy | ||
| from libcpp cimport bool as cpp_bool | ||
| from libcpp.complex cimport complex as cpp_complex | ||
| from libcpp cimport nullptr | ||
|
|
@@ -16,6 +17,8 @@ import ctypes | |
| import numpy | ||
|
|
||
| from cuda.core._memory import Buffer | ||
| from cuda.core._tensor_map import TensorMapDescriptor as _TensorMapDescriptor_py | ||
| from cuda.core._tensor_map cimport TensorMapDescriptor | ||
| from cuda.core._utils.cuda_utils import driver | ||
| from cuda.bindings cimport cydriver | ||
|
|
||
|
|
@@ -97,6 +100,9 @@ cdef object numpy_complex64 = numpy.complex64 | |
| cdef object numpy_complex128 = numpy.complex128 | ||
|
|
||
|
|
||
| cdef object tensor_map_descriptor_type = _TensorMapDescriptor_py | ||
|
|
||
|
|
||
| # limitation due to cython/cython#534 | ||
| ctypedef void* voidptr | ||
|
|
||
|
|
@@ -124,6 +130,25 @@ cdef inline int prepare_arg( | |
| return 0 | ||
|
|
||
|
|
||
cdef inline int prepare_tensor_map_arg(
    vector.vector[void*]& data,
    vector.vector[void*]& data_addresses,
    TensorMapDescriptor arg,
    const size_t idx) except -1:
    # Stage a TensorMapDescriptor kernel argument: the 128-byte CUtensorMap
    # struct is passed to the kernel by value, so a private heap copy is made
    # rather than pointing at arg._tensor_map directly — ParamHolder owns and
    # frees its argument buffers independently of the descriptor's lifetime.
    cdef void* copy_buf = PyMem_Malloc(sizeof(cydriver.CUtensorMap))
    if copy_buf is NULL:
        raise MemoryError("Failed to allocate memory for CUtensorMap")
    memcpy(copy_buf, arg._get_data_ptr(), sizeof(cydriver.CUtensorMap))
    # data[idx] retains ownership so ParamHolder.__dealloc__ frees the copy;
    # data_addresses[idx] is the pointer handed to cuLaunchKernel.
    data[idx] = copy_buf
    data_addresses[idx] = copy_buf
    return 0
|
|
||
|
|
||
| cdef inline int prepare_ctypes_arg( | ||
| vector.vector[void*]& data, | ||
| vector.vector[void*]& data_addresses, | ||
|
|
@@ -273,6 +298,9 @@ cdef class ParamHolder: | |
| # it's a CUdeviceptr: | ||
| self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr()) | ||
| continue | ||
| elif arg_type is tensor_map_descriptor_type: | ||
| prepare_tensor_map_arg(self.data, self.data_addresses, <TensorMapDescriptor>arg, i) | ||
| continue | ||
| elif arg_type is bool: | ||
| prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i) | ||
| continue | ||
|
|
@@ -322,6 +350,9 @@ cdef class ParamHolder: | |
| elif isinstance(arg, driver.CUgraphConditionalHandle): | ||
| prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, arg, i) | ||
| continue | ||
| elif isinstance(arg, tensor_map_descriptor_type): | ||
| prepare_tensor_map_arg(self.data, self.data_addresses, <TensorMapDescriptor>arg, i) | ||
| continue | ||
|
Comment on lines
+353
to
+355
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't remember off the top of my head why we had to repeat the checks in this block, I believe it has to do with backward compatibility. Since TMA support is new, we don't need to add it here. |
||
| # TODO: support ctypes/numpy struct | ||
| raise TypeError("the argument is of unsupported type: " + str(type(arg))) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from cuda.bindings cimport cydriver | ||
|
|
||
|
|
||
cdef class TensorMapDescriptor:
    # Opaque driver-level TMA descriptor (CUtensorMap) stored inline by value.
    cdef cydriver.CUtensorMap _tensor_map
    # Python object references held by the descriptor.
    # NOTE(review): presumably these keep the described memory/view alive for
    # the descriptor's lifetime — the exact semantics live in the .pyx
    # implementation, which is not visible here; confirm there.
    cdef object _source_ref
    cdef object _view_ref
    cdef object _repr_info

    # Returns a pointer to the raw CUtensorMap bytes (used by the kernel
    # argument packer, which memcpy's sizeof(cydriver.CUtensorMap) from it).
    cdef void* _get_data_ptr(self)
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
why do we ignore them...?