@@ -76,18 +76,22 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
7676 const void *x,
7777 void *stream_) const {
7878 cudaStream_t stream = (cudaStream_t)stream_;
79- if (_opaque->internal ->maxThreadsPerBlock () == CUDA_BLOCK_SIZE_1024) {
79+ if (_opaque->internal ->maxThreadsPerBlock () == CUDA_BLOCK_SIZE_4096) {
80+ CHECK_STATUS (launchKernel<CUDA_BLOCK_SIZE_4096>(
81+ y, x, _info.dtype , _info.batch_size , _info.seq_len , _info.total_seq_len ,
82+ _info.y_stride_b , _info.y_stride_i , _info.x_stride_b , _info.x_stride_i , stream));
83+ } else if (_opaque->internal ->maxThreadsPerBlock () == CUDA_BLOCK_SIZE_2048) {
84+ CHECK_STATUS (launchKernel<CUDA_BLOCK_SIZE_2048>(
85+ y, x, _info.dtype , _info.batch_size , _info.seq_len , _info.total_seq_len ,
86+ _info.y_stride_b , _info.y_stride_i , _info.x_stride_b , _info.x_stride_i , stream));
87+ } else if (_opaque->internal ->maxThreadsPerBlock () == CUDA_BLOCK_SIZE_1024) {
8088 CHECK_STATUS (launchKernel<CUDA_BLOCK_SIZE_1024>(
8189 y, x, _info.dtype , _info.batch_size , _info.seq_len , _info.total_seq_len ,
8290 _info.y_stride_b , _info.y_stride_i , _info.x_stride_b , _info.x_stride_i , stream));
8391 } else if (_opaque->internal ->maxThreadsPerBlock () == CUDA_BLOCK_SIZE_512) {
8492 CHECK_STATUS (launchKernel<CUDA_BLOCK_SIZE_512>(
8593 y, x, _info.dtype , _info.batch_size , _info.seq_len , _info.total_seq_len ,
8694 _info.y_stride_b , _info.y_stride_i , _info.x_stride_b , _info.x_stride_i , stream));
87- } else if (_opaque->internal ->maxThreadsPerBlock () == CUDA_BLOCK_SIZE_4096) {
88- CHECK_STATUS (launchKernel<CUDA_BLOCK_SIZE_4096>(
89- y, x, _info.dtype , _info.batch_size , _info.seq_len , _info.total_seq_len ,
90- _info.y_stride_b , _info.y_stride_i , _info.x_stride_b , _info.x_stride_i , stream));
9195 } else {
9296 return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
9397 }
0 commit comments