@@ -501,8 +501,6 @@ __global__ void eval_vvar_grad_kern( size_t ntasks,
   double* den_y_eval_device = nullptr;
   double* den_z_eval_device = nullptr;
 
-  constexpr auto warp_size = hip::warp_size;
-
   if constexpr (den_select == DEN_S) {
     den_eval_device   = task.den_s;
     den_x_eval_device = task.dden_sx;
@@ -534,62 +532,61 @@ __global__ void eval_vvar_grad_kern( size_t ntasks,
   const auto* dbasis_z_eval_device = task.dbfz;
 
   const auto* den_basis_prod_device = task.zmat;
-
-  __shared__ double den_shared[4][warp_size][VVAR_KERNEL_SM_BLOCK+1];
 
-  for( int bid_x = blockIdx.x * blockDim.x;
-       bid_x < nbf;
-       bid_x += blockDim.x * gridDim.x ) {
-
-    for( int bid_y = blockIdx.y * VVAR_KERNEL_SM_BLOCK;
-         bid_y < npts;
-         bid_y += VVAR_KERNEL_SM_BLOCK * gridDim.y ) {
-
-      for( int sm_y = threadIdx.y; sm_y < VVAR_KERNEL_SM_BLOCK; sm_y += blockDim.y ) {
-        den_shared[0][threadIdx.x][sm_y] = 0.;
-        den_shared[1][threadIdx.x][sm_y] = 0.;
-        den_shared[2][threadIdx.x][sm_y] = 0.;
-        den_shared[3][threadIdx.x][sm_y] = 0.;
+  // We always launch enough blocks to cover npts, so blocks aren't doing multiple results
+  double den_reg = 0.;
+  double dx_reg  = 0.;
+  double dy_reg  = 0.;
+  double dz_reg  = 0.;
+
+  // Have each thread accumulate its own reduction result into a register.
+  // There's no real _need_ for LDS because the reductions are small and
+  // therefore can be done without sharing.
+  for( int ibf = 0; ibf < nbf; ibf++ ) {
+
+    for( int ipt = blockIdx.x * blockDim.x + threadIdx.x; ipt < npts; ipt += blockDim.x * gridDim.x ) {
+
+      const double* bf_col   = basis_eval_device     + ibf*npts;
+      const double* bf_x_col = dbasis_x_eval_device  + ibf*npts;
+      const double* bf_y_col = dbasis_y_eval_device  + ibf*npts;
+      const double* bf_z_col = dbasis_z_eval_device  + ibf*npts;
+      const double* db_col   = den_basis_prod_device + ibf*npts;
+
+      den_reg += bf_col[ ipt ] * db_col[ ipt ];
+      dx_reg  += 2 * bf_x_col[ ipt ] * db_col[ ipt ];
+      dy_reg  += 2 * bf_y_col[ ipt ] * db_col[ ipt ];
+      dz_reg  += 2 * bf_z_col[ ipt ] * db_col[ ipt ];
+    }
+  }
 
-        if (bid_y + threadIdx.x < npts and bid_x + sm_y < nbf) {
-          const double* db_col   = den_basis_prod_device + (bid_x + sm_y)*npts;
-          const double* bf_col   = basis_eval_device     + (bid_x + sm_y)*npts;
-          const double* bf_x_col = dbasis_x_eval_device  + (bid_x + sm_y)*npts;
-          const double* bf_y_col = dbasis_y_eval_device  + (bid_x + sm_y)*npts;
-          const double* bf_z_col = dbasis_z_eval_device  + (bid_x + sm_y)*npts;
 
-          den_shared[0][threadIdx.x][sm_y] = bf_col  [ bid_y + threadIdx.x ] * db_col[ bid_y + threadIdx.x ];
-          den_shared[1][threadIdx.x][sm_y] = bf_x_col[ bid_y + threadIdx.x ] * db_col[ bid_y + threadIdx.x ];
-          den_shared[2][threadIdx.x][sm_y] = bf_y_col[ bid_y + threadIdx.x ] * db_col[ bid_y + threadIdx.x ];
-          den_shared[3][threadIdx.x][sm_y] = bf_z_col[ bid_y + threadIdx.x ] * db_col[ bid_y + threadIdx.x ];
-        }
-      }
-      __syncthreads();
+  for( int ipt = blockIdx.x * blockDim.x + threadIdx.x; ipt < npts; ipt += blockDim.x * gridDim.x ) {
+    den_eval_device  [ipt] = den_reg;
+    den_x_eval_device[ipt] = dx_reg;
+    den_y_eval_device[ipt] = dy_reg;
+    den_z_eval_device[ipt] = dz_reg;
+  }
 
+}
 
-      for( int sm_y = threadIdx.y; sm_y < VVAR_KERNEL_SM_BLOCK; sm_y += blockDim.y ) {
-        const int tid_y = bid_y + sm_y;
-        double den_reg = den_shared[0][sm_y][threadIdx.x];
-        double dx_reg  = den_shared[1][sm_y][threadIdx.x];
-        double dy_reg  = den_shared[2][sm_y][threadIdx.x];
-        double dz_reg  = den_shared[3][sm_y][threadIdx.x];
 
-        // Warp blocks are stored col major
-        den_reg = hip::warp_reduce_sum<warp_size>( den_reg );
-        dx_reg  = 2. * hip::warp_reduce_sum<warp_size>( dx_reg );
-        dy_reg  = 2. * hip::warp_reduce_sum<warp_size>( dy_reg );
-        dz_reg  = 2. * hip::warp_reduce_sum<warp_size>( dz_reg );
+__global__ void eval_vvars_gga_kernel(
+  size_t        npts,
+  const double* den_x_eval_device,
+  const double* den_y_eval_device,
+  const double* den_z_eval_device,
+  double*       gamma_eval_device
+) {
+
+  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if( tid < npts ) {
 
+    const double dx = den_x_eval_device[ tid ];
+    const double dy = den_y_eval_device[ tid ];
+    const double dz = den_z_eval_device[ tid ];
+
+    gamma_eval_device[tid] = dx*dx + dy*dy + dz*dz;
 
-        if ( threadIdx.x == 0 and tid_y < npts ) {
-          atomicAdd( den_eval_device   + tid_y, den_reg );
-          atomicAdd( den_x_eval_device + tid_y, dx_reg );
-          atomicAdd( den_y_eval_device + tid_y, dy_reg );
-          atomicAdd( den_z_eval_device + tid_y, dz_reg );
-        }
-      }
-      __syncthreads();
-    }
   }
 
 }
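The rewritten gradient kernel is the standard "one thread per output point" formulation of a column-wise dot product: each thread owns a grid point, walks the nbf basis functions serially, and keeps its four partial sums in registers, so neither LDS nor atomics are needed. Below is a minimal self-contained sketch of that pattern; the kernel and buffer names are illustrative only, not GauXC code, and unlike the hunk above it re-zeroes the accumulator per point, so it stays correct even if the grid-stride loop visits more than one point per thread.

// Sketch: out[ipt] = sum over ibf of a[ibf*npts + ipt] * b[ibf*npts + ipt].
// The column-major layout mirrors the bf/zmat buffers above, and the accumulator
// lives in a register that is private to the point currently being processed.
__global__ void colwise_dot_sketch( int npts, int nbf, const double* a,
                                    const double* b, double* out ) {
  for( int ipt = blockIdx.x * blockDim.x + threadIdx.x; ipt < npts;
       ipt += blockDim.x * gridDim.x ) {
    double acc = 0.;
    for( int ibf = 0; ibf < nbf; ibf++ )
      acc += a[ ibf*npts + ipt ] * b[ ibf*npts + ipt ];
    out[ ipt ] = acc;
  }
}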
@@ -656,10 +653,9 @@ void eval_vvar( size_t ntasks, int32_t nbf_max, int32_t npts_max, bool do_grad,
   dim3 threads;
   dim3 blocks;
   if( do_grad ) {
-    threads = dim3( hip::warp_size, hip::max_warps_per_thread_block / 2, 1 );
-    blocks  = dim3( std::min(uint64_t(4),  util::div_ceil( nbf_max, 4 )),
-                    std::min(uint64_t(16), util::div_ceil( nbf_max, 16 )),
-                    ntasks );
+    threads = dim3( hip::max_warps_per_thread_block, 1, 1 );
+    blocks  = dim3( util::div_ceil( npts_max, threads.x ),
+                    1, 1 );
   } else {
     threads = dim3( hip::warp_size, hip::max_warps_per_thread_block, 1 );
     blocks  = dim3( util::div_ceil( nbf_max, threads.x ),
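The new gradient-path launch geometry exists to uphold the kernel's "enough blocks to cover npts" comment: with a 1-D block of hip::max_warps_per_thread_block threads, rounding the block count up with util::div_ceil guarantees blocks.x * threads.x >= npts_max, so every grid point is covered by a thread. A hedged sketch of the assumed ceiling-division semantics follows; div_ceil_sketch is a stand-in, and the real util::div_ceil may differ in signature.

// Assumed behaviour of util::div_ceil -- integer ceiling division.
#include <cstdint>

constexpr uint64_t div_ceil_sketch( uint64_t num, uint64_t den ) {
  return ( num + den - 1 ) / den;
}

// e.g. 1000 points with 64-thread blocks -> 16 blocks -> 1024 threads >= 1000,
// which is the precondition the register-accumulation kernel relies on.
static_assert( div_ceil_sketch( 1000, 64 ) * 64 >= 1000, "grid covers every point" );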