From cdf6e609a52456fac337b1bcc8d186e9abcdc209 Mon Sep 17 00:00:00 2001 From: Ganga P Purja Pun Date: Thu, 13 Jul 2023 14:27:21 -0400 Subject: [PATCH] Update mpp_scatter_mpi.fh --- mpp/include/mpp_scatter_mpi.fh | 83 ++++++++++++++++++++++++++++------ 1 file changed, 69 insertions(+), 14 deletions(-) diff --git a/mpp/include/mpp_scatter_mpi.fh b/mpp/include/mpp_scatter_mpi.fh index a1c7557663..32d351839e 100644 --- a/mpp/include/mpp_scatter_mpi.fh +++ b/mpp/include/mpp_scatter_mpi.fh @@ -68,6 +68,7 @@ subroutine MPP_SCATTER_PELIST_3D_(is, ie, js, je, nk, pelist, array_seg, data, i integer :: total_msgsize integer :: stdout_unit integer :: gindx1D(4*size(pelist)) !< Packed version of gind + integer :: ii, jj, k, m if (.not.ANY(mpp_pe().eq.pelist(:))) return @@ -75,7 +76,14 @@ subroutine MPP_SCATTER_PELIST_3D_(is, ie, js, je, nk, pelist, array_seg, data, i n = get_peset(pelist) if( peset(n)%count.EQ.1 ) return - if (any(peset(n)%list .ne. pelist)) call mpp_error(FATAL, "mpp_scatter_mpi: two pelists don't match") + do i = 1, mpp_npes() + if(peset(n)%list(i) == pelist(1)) then + root_pe = i - 1 + exit + endif + enddo + + if (any(peset(n)%list(:) .ne. pelist(:))) call mpp_error(FATAL, "mpp_scatter_mpi: two pelists don't match") if( debug )then call SYSTEM_CLOCK(tick) @@ -83,8 +91,6 @@ subroutine MPP_SCATTER_PELIST_3D_(is, ie, js, je, nk, pelist, array_seg, data, i 'T=',tick, ' PE=', pe, ' MPP_SCATTER begin: from_pe, length=', mpp_pe() end if - root_pe = pelist(1) - if (is_root_pe) then if (.not.ANY(pelist(:).eq.root_pe)) call mpp_error(FATAL, "mpp_scatter_mpi: root_pe not a member of pelist") if (root_pe .ne. pelist(1)) call mpp_error(FATAL, "mpp_scatter_mpi: root_pe is not the first pe of pelist") @@ -107,14 +113,14 @@ subroutine MPP_SCATTER_PELIST_3D_(is, ie, js, je, nk, pelist, array_seg, data, i if (mpp_npes() .gt. 1) call MPI_GATHER(my_ind, 4, MPI_INTEGER4, gindx1D, 4, MPI_INTEGER4, root_pe, peset(n)%id, ierr) if (ierr /= MPI_SUCCESS) call mpp_error(FATAL, "mpp_scatter_mpi::MPI_GATHER something is wrong") - ! Unpack gindx1D(:) to gind(:,:) - if (is_root_pe) then - do i = 1, size(pelist) - gind(1, i) = gindx1D((i-1)*4 + 1) - gind(2, i) = gindx1D((i-1)*4 + 2) - gind(3, i) = gindx1D((i-1)*4 + 3) - gind(4, i) = gindx1D((i-1)*4 + 4) - enddo + if (any(mpp_pe() .eq. pelist(:))) then + print "('mpp_scatter_mpi:my_ind ', 'PE ', i4, 4i4)", mpp_pe(), my_ind(1:4) + end if + if (mpp_pe() .eq. pelist(1)) then + !print *, 'mpp_scatter_mpi:pelist', size(pelist), pelist(:) + !print *, 'mpp_scatter_mpi:list', size(peset(n)%list), peset(n)%list(:) + !print "('mpp_scatter_mpi:gindx1D ', 'root PE=', i4, 4i4)", root_pe, gindx1D((i-1)*4 + 1:(i-1)*4 + 4) + print *, 'mpp_scatter_mpi:gindx1D root PE', root_pe, gindx1D(:) end if ! Compute my message size @@ -122,6 +128,13 @@ subroutine MPP_SCATTER_PELIST_3D_(is, ie, js, je, nk, pelist, array_seg, data, i allocate(recv_buf(msgsize)) if (is_root_pe) then + ! Unpack gindx1D(:) to gind(:,:) + do i = 1, size(pelist) + gind(1, i) = gindx1D((i-1)*4 + 1) + gind(2, i) = gindx1D((i-1)*4 + 2) + gind(3, i) = gindx1D((i-1)*4 + 3) + gind(4, i) = gindx1D((i-1)*4 + 4) + end do ! Update group indices gind(1,:)=gind(1,:)+ioff gind(2,:)=gind(2,:)+ioff @@ -129,6 +142,7 @@ subroutine MPP_SCATTER_PELIST_3D_(is, ie, js, je, nk, pelist, array_seg, data, i gind(4,:)=gind(4,:)+joff ! check indices to make sure they are within the range of "data" if ((minval(gind).lt.1) .OR. (maxval(gind(1:2,:)).gt.size(data,1)) .OR. (maxval(gind(3:4,:)).gt.size(data,2))) then + print "('mpp_scatter_mpi:min-max ', 3i6)", minval(gind), maxval(gind(1:2,:)), maxval(gind(3:4,:)) call mpp_error(FATAL,"mpp_scatter_mpi:: specified indices (with shift) are outside & of the range of the receiving array") end if @@ -143,6 +157,7 @@ subroutine MPP_SCATTER_PELIST_3D_(is, ie, js, je, nk, pelist, array_seg, data, i total_msgsize = total_msgsize + send_count(i) ! Compute data displacements displ(i) = total_msgsize - send_count(i) + !print "('mpp_scatter_mpi:', 2i6)", displ(i), send_count(i) enddo ! Allocate send buffer @@ -157,17 +172,57 @@ subroutine MPP_SCATTER_PELIST_3D_(is, ie, js, je, nk, pelist, array_seg, data, i j2 = gind(4,i) total_msgsize = total_msgsize + send_count(i) ! Pack data segments - send_buf(displ(i)+1:total_msgsize) = reshape(data(i1:i2,j1:j2,1:nk), (/size(data(i1:i2,j1:j2,1:nk))/)) + m = displ(i) + 1 + do k = 1, nk + do jj = j1, j2 + do ii = i1, i2 + send_buf(m) = data(ii, jj, k) + m = m + 1 + end do + end do + end do + !send_buf(displ(i)+1:total_msgsize) = reshape(data(i1:i2,j1:j2,1:nk), (/size(data(i1:i2,j1:j2,1:nk))/), & + ! data(i1:i2,j1:j2,1:nk)) + !print *, 'mpp_scatter_mpi:send_buf', i, send_buf(displ(i)+1:displ(i)+4) + !print *, 'mpp_scatter_mpi:data', i, data(i1:i2,j1:j2,1:nk) enddo end if ! Scatter data chunks to respective PEs if (mpp_npes() .gt. 1) call MPI_SCATTERV(send_buf, send_count, displ, MPI_TYPE_, recv_buf, & - msgsize, MPI_TYPE_, root_pe, peset(n)%id, ierr) + msgsize, MPI_TYPE_, root_pe, peset(n)%id, ierr) if (ierr /= MPI_SUCCESS) call mpp_error(FATAL, "mpp_scatter_mpi::MPI_SCATTERV something is wrong") ! Unpack received data - array_seg(is:ie,js:je,1:nk) = reshape(recv_buf, (/shape(array_seg(is:ie,js:je,1:nk))/)) + !if (is_root_pe) then + !array_seg(is:ie,js:je,1:nk) = reshape(send_buf(1:send_count(1)), (/shape(array_seg(is:ie,js:je,1:nk))/), & + !send_buf(1:send_count(1))) + !else + m = 1 + do k = 1, nk + do jj = js, je + do ii = is, ie + array_seg(ii,jj,k) = recv_buf(m) + m = m + 1 + end do + end do + end do + !array_seg(is:ie,js:je,1:nk) = reshape(recv_buf, (/shape(array_seg(is:ie,js:je,1:nk))/), recv_buf) + !end if + + i = 1 + if (i .le. size(pelist)) then + if (mpp_pe() .eq. pelist(i)) then + !print *, 'mpp_scatter_mpi:array_seg', array_seg(ie-3:ie,js:js,1:1) + end if + if (is_root_pe) then + !if (any(array_seg(is:ie,js:je,1:nk) .ne. data(gind(1,1):gind(2,1),gind(3,1):gind(4,1),1:nk))) then + !print *, 'mpp_scatter_mpi: data did not match!' + !end if + !print *, 'mpp_scatter_mpi:data', data(gind(2,i)-3:gind(2,i),gind(3,i):gind(3,i),1:1) + !print *, 'mpp_scatter_mpi:array_seg', array_seg(ie:ie,je:je,1:1) + end if + end if if( debug .and. (current_clock.NE.0) ) then call increment_current_clock( EVENT_SCATTER, msgsize*MPP_TYPE_BYTELEN_ )