@@ -369,21 +369,14 @@ SNSolverHPCCUDA::SNSolverHPCCUDA( Config* settings ) {
         ErrorMessages::Error( "The number of processors must be less than or equal to the number of quadrature points.", CURRENT_FUNCTION );
     }
 
-    if( _numProcs == 1 ) {
-        _localNSys   = _nSys;
-        _startSysIdx = 0;
-        _endSysIdx   = _nSys;
-    }
-    else {
-        _localNSys   = _nSys / ( _numProcs - 1 );
-        _startSysIdx = _rank * _localNSys;
-        _endSysIdx   = _rank * _localNSys + _localNSys;
-
-        if( _rank == _numProcs - 1 ) {
-            _localNSys = _nSys - _startSysIdx;
-            _endSysIdx = _nSys;
-        }
-    }
+    const unsigned long numRanks  = static_cast<unsigned long>( _numProcs );
+    const unsigned long rankIndex = static_cast<unsigned long>( _rank );
+    const unsigned long baseChunk = _nSys / numRanks;
+    const unsigned long remainder = _nSys % numRanks;
+
+    _localNSys   = baseChunk + ( rankIndex < remainder ? 1UL : 0UL );
+    _startSysIdx = rankIndex * baseChunk + std::min( rankIndex, remainder );
+    _endSysIdx   = _startSysIdx + _localNSys;
 
     // std::cout << "Rank: " << _rank << " startSysIdx: " << _startSysIdx << " endSysIdx: " << _endSysIdx << " localNSys: " << _localNSys <<
     // std::endl;
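
The new indexing is a balanced block partition: with p ranks, rank r gets `_nSys / p` systems plus one extra when `r < _nSys % p`, so local chunk sizes differ by at most one and the per-rank ranges tile `[0, _nSys)` without gaps or overlap. The old branch divided by `_numProcs - 1` and pushed the entire remainder onto the last rank. A minimal standalone sketch of the same arithmetic (values and names illustrative, not part of the solver) that checks the tiling property:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdio>

int main() {
    const unsigned long n = 10;    // e.g. 10 quadrature points / systems
    const unsigned long p = 4;     // e.g. 4 MPI ranks
    unsigned long expectedStart = 0;
    for( unsigned long r = 0; r < p; ++r ) {
        const unsigned long base  = n / p;
        const unsigned long rem   = n % p;
        const unsigned long local = base + ( r < rem ? 1UL : 0UL );    // chunk size
        const unsigned long start = r * base + std::min( r, rem );    // first owned index
        assert( start == expectedStart );    // contiguous: no gap, no overlap
        std::printf( "rank %lu owns [%lu, %lu)\n", r, start, start + local );
        expectedStart = start + local;
    }
    assert( expectedStart == n );    // every system is owned exactly once
    return 0;
}
```
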
@@ -613,9 +606,29 @@ void SNSolverHPCCUDA::InitCUDA() {
         ErrorMessages::Error( "No CUDA-capable GPU detected, but SNSolverHPCCUDA was requested.", CURRENT_FUNCTION );
     }
 
-    _cudaDeviceId = 0;    // first version: pin to one GPU
+    int localRank = 0;
+    int localSize = 1;
+#ifdef IMPORT_MPI
+    MPI_Comm localComm = MPI_COMM_NULL;
+    MPI_Comm_split_type( MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, _rank, MPI_INFO_NULL, &localComm );
+    MPI_Comm_rank( localComm, &localRank );
+    MPI_Comm_size( localComm, &localSize );
+    MPI_Comm_free( &localComm );
+#endif
+
+    _cudaDeviceId = localRank % nDevices;
     CheckCuda( cudaSetDevice( _cudaDeviceId ), "cudaSetDevice" );
 
+    if( _rank == 0 ) {
+        auto log = spdlog::get( "event" );
+        if( log ) {
+            log->info( "| CUDA backend: {} local MPI rank(s), {} visible CUDA device(s).", localSize, nDevices );
+            if( localSize > nDevices ) {
+                log->warn( "| CUDA backend: {} local MPI rank(s) exceed {} visible device(s); GPUs will be shared.", localSize, nDevices );
+            }
+        }
+    }
+
     _device = new DeviceBuffers();
 
     const std::size_t nCells = static_cast<std::size_t>( _nCells );
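
Splitting `MPI_COMM_WORLD` with `MPI_COMM_TYPE_SHARED` groups the ranks that share a node, and the modulo then maps each node-local rank onto a visible GPU round-robin, so oversubscription degrades to GPU sharing rather than failure, which is why the code warns instead of erroring. A minimal standalone sketch of this binding pattern, assuming an MPI + CUDA build (names illustrative, not the solver's code):

```cpp
#include <cstdio>
#include <cuda_runtime.h>
#include <mpi.h>

int main( int argc, char** argv ) {
    MPI_Init( &argc, &argv );
    int worldRank = 0;
    MPI_Comm_rank( MPI_COMM_WORLD, &worldRank );

    // Group ranks by shared-memory domain (i.e. by node) to get a node-local rank.
    MPI_Comm nodeComm = MPI_COMM_NULL;
    MPI_Comm_split_type( MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, worldRank, MPI_INFO_NULL, &nodeComm );
    int localRank = 0;
    MPI_Comm_rank( nodeComm, &localRank );
    MPI_Comm_free( &nodeComm );

    // Round-robin the node-local ranks over the visible devices.
    int nDevices = 0;
    cudaGetDeviceCount( &nDevices );
    const int deviceId = ( nDevices > 0 ) ? localRank % nDevices : 0;
    cudaSetDevice( deviceId );
    std::printf( "world rank %d -> local rank %d -> device %d of %d\n", worldRank, localRank, deviceId, nDevices );

    MPI_Finalize();
    return 0;
}
```
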
@@ -805,6 +818,20 @@ void SNSolverHPCCUDA::Solve() {
             RK2AverageAndScalarFluxKernel<<<gridCells, threads>>>(
                 _nCells, _localNSys, _device->quadWeights, _device->solRK0, _device->sol, _device->scalarFlux );
             CheckCuda( cudaGetLastError(), "RK2AverageAndScalarFluxKernel launch" );
+#ifdef IMPORT_MPI
+            CheckCuda( cudaMemcpy( _scalarFlux.data(),
+                                   _device->scalarFlux,
+                                   static_cast<std::size_t>( _nCells ) * sizeof( double ),
+                                   cudaMemcpyDeviceToHost ),
+                       "download scalar flux after RK2 average" );
+            std::vector<double> tempScalarFlux( _scalarFlux );
+            MPI_Barrier( MPI_COMM_WORLD );
+            MPI_Allreduce( tempScalarFlux.data(), _scalarFlux.data(), _nCells, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD );
+            MPI_Barrier( MPI_COMM_WORLD );
+            CheckCuda(
+                cudaMemcpy( _device->scalarFlux, _scalarFlux.data(), static_cast<std::size_t>( _nCells ) * sizeof( double ), cudaMemcpyHostToDevice ),
+                "sync allreduced scalar flux after RK2 average" );
+#endif
         }
         else {
             ( _spatialOrder == 2 ) ? FluxOrder2() : FluxOrder1();
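
Each rank integrates only its `[_startSysIdx, _endSysIdx)` slice of the quadrature, so after the RK2 average `_device->scalarFlux` holds a partial sum; the added block stages it through the host and sums it over all ranks. Two possible simplifications, neither required for correctness of the diff as written: the `MPI_Barrier` calls are redundant around a blocking collective like `MPI_Allreduce`, and `MPI_IN_PLACE` removes the temporary host copy. A minimal sketch of that leaner variant (hypothetical helper name, same buffer layout assumed; a CUDA-aware MPI could pass the device pointer directly and skip both copies):

```cpp
#include <cstddef>
#include <vector>
#include <cuda_runtime.h>
#include <mpi.h>

// Hypothetical helper: reduce a per-rank partial scalar flux to the global sum.
void allreduceScalarFlux( double* d_scalarFlux, std::vector<double>& h_scalarFlux ) {
    const std::size_t bytes = h_scalarFlux.size() * sizeof( double );
    // Stage the partial result on the host.
    cudaMemcpy( h_scalarFlux.data(), d_scalarFlux, bytes, cudaMemcpyDeviceToHost );
    // In-place sum across ranks: every rank ends up holding the global flux.
    MPI_Allreduce( MPI_IN_PLACE, h_scalarFlux.data(), static_cast<int>( h_scalarFlux.size() ), MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD );
    // Push the reduced flux back to the device.
    cudaMemcpy( d_scalarFlux, h_scalarFlux.data(), bytes, cudaMemcpyHostToDevice );
}
```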