Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/test_shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class TestElement(xo.HybridClass):
unsigned int gid = blockIdx.x*blockDim.x + threadIdx.x; // global thread ID: 0,1,2,3

// init shared memory with chunk of input array
extern __shared__ double sdata[2];
extern __shared__ double sdata[];
sdata[tid] = input_arr[gid];
__syncthreads();

Expand Down
11 changes: 9 additions & 2 deletions xobjects/context_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,15 +352,22 @@ def __invert__(self):

cudaheader: List[SourceType] = [
"""\
typedef signed long long int64_t; //only_for_context cuda
typedef signed int int32_t; //only_for_context cuda
typedef signed short int16_t; //only_for_context cuda
typedef signed char int8_t; //only_for_context cuda
typedef unsigned long long uint64_t; //only_for_context cuda
typedef unsigned int uint32_t; //only_for_context cuda
typedef unsigned short uint16_t; //only_for_context cuda
typedef unsigned char uint8_t; //only_for_context cuda

#if defined(__CUDACC__) && !defined(__HIPCC__)
typedef signed long long int64_t;
typedef unsigned long long uint64_t;
#endif

#ifndef NULL
#define NULL nullptr
#endif

"""
]

Expand Down
10 changes: 9 additions & 1 deletion xobjects/headers/atomicadd.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ DEF_ATOMIC_ADD(double , f64)
// -------------------------------------------
#if defined(XO_CONTEXT_CUDA)
// CUDA compiler may not have <stdint.h>, so define the types if needed.
#ifdef __CUDACC_RTC__
#if defined(__CUDACC_RTC__) && !defined(__HIPCC__)
// NVRTC (CuPy RawModule default) can’t see <stdint.h>, so detect it via __CUDACC_RTC__
typedef signed char int8_t;
typedef short int16_t;
Expand All @@ -111,6 +111,14 @@ DEF_ATOMIC_ADD(double , f64)
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
typedef unsigned long long uint64_t;
#elif defined(__HIPCC__) && !defined(__CUDACC_RTC__)
// ROCm appears to have definitions for 64-bit int types
typedef signed char int8_t;
typedef short int16_t;
typedef int int32_t;
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
#else
// Alternatively, NVCC path is fine with host headers
#include <stdint.h>
Expand Down