Skip to content

Commit

Permalink
Update mpp_scatter_mpi.fh
Browse files Browse the repository at this point in the history
  • Loading branch information
ganganoaa committed Jul 13, 2023
1 parent f9e0a37 commit cdf6e60
Showing 1 changed file with 69 additions and 14 deletions.
83 changes: 69 additions & 14 deletions mpp/include/mpp_scatter_mpi.fh
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,29 @@ subroutine MPP_SCATTER_PELIST_3D_(is, ie, js, je, nk, pelist, array_seg, data, i
integer :: total_msgsize
integer :: stdout_unit
integer :: gindx1D(4*size(pelist)) !< Packed version of gind
integer :: ii, jj, k, m

if (.not.ANY(mpp_pe().eq.pelist(:))) return

! Get peset number
n = get_peset(pelist)
if( peset(n)%count.EQ.1 ) return

if (any(peset(n)%list .ne. pelist)) call mpp_error(FATAL, "mpp_scatter_mpi: two pelists don't match")
do i = 1, mpp_npes()
if(peset(n)%list(i) == pelist(1)) then
root_pe = i - 1
exit
endif
enddo

if (any(peset(n)%list(:) .ne. pelist(:))) call mpp_error(FATAL, "mpp_scatter_mpi: two pelists don't match")

if( debug )then
call SYSTEM_CLOCK(tick)
write( stdout_unit,'(a,i18,a,i6,a,i6)' )&
'T=',tick, ' PE=', pe, ' MPP_SCATTER begin: from_pe, length=', mpp_pe()
end if

root_pe = pelist(1)

if (is_root_pe) then
if (.not.ANY(pelist(:).eq.root_pe)) call mpp_error(FATAL, "mpp_scatter_mpi: root_pe not a member of pelist")
if (root_pe .ne. pelist(1)) call mpp_error(FATAL, "mpp_scatter_mpi: root_pe is not the first pe of pelist")
Expand All @@ -107,28 +113,36 @@ subroutine MPP_SCATTER_PELIST_3D_(is, ie, js, je, nk, pelist, array_seg, data, i
if (mpp_npes() .gt. 1) call MPI_GATHER(my_ind, 4, MPI_INTEGER4, gindx1D, 4, MPI_INTEGER4, root_pe, peset(n)%id, ierr)
if (ierr /= MPI_SUCCESS) call mpp_error(FATAL, "mpp_scatter_mpi::MPI_GATHER something is wrong")

! Unpack gindx1D(:) to gind(:,:)
if (is_root_pe) then
do i = 1, size(pelist)
gind(1, i) = gindx1D((i-1)*4 + 1)
gind(2, i) = gindx1D((i-1)*4 + 2)
gind(3, i) = gindx1D((i-1)*4 + 3)
gind(4, i) = gindx1D((i-1)*4 + 4)
enddo
if (any(mpp_pe() .eq. pelist(:))) then
print "('mpp_scatter_mpi:my_ind ', 'PE ', i4, 4i4)", mpp_pe(), my_ind(1:4)
end if
if (mpp_pe() .eq. pelist(1)) then
!print *, 'mpp_scatter_mpi:pelist', size(pelist), pelist(:)
!print *, 'mpp_scatter_mpi:list', size(peset(n)%list), peset(n)%list(:)
!print "('mpp_scatter_mpi:gindx1D ', 'root PE=', i4, 4i4)", root_pe, gindx1D((i-1)*4 + 1:(i-1)*4 + 4)
print *, 'mpp_scatter_mpi:gindx1D root PE', root_pe, gindx1D(:)
end if

! Compute my message size
msgsize = (my_ind(2)-my_ind(1)+1) * (my_ind(4)-my_ind(3)+1) * nk
allocate(recv_buf(msgsize))

if (is_root_pe) then
! Unpack gindx1D(:) to gind(:,:)
do i = 1, size(pelist)
gind(1, i) = gindx1D((i-1)*4 + 1)
gind(2, i) = gindx1D((i-1)*4 + 2)
gind(3, i) = gindx1D((i-1)*4 + 3)
gind(4, i) = gindx1D((i-1)*4 + 4)
end do
! Update group indices
gind(1,:)=gind(1,:)+ioff
gind(2,:)=gind(2,:)+ioff
gind(3,:)=gind(3,:)+joff
gind(4,:)=gind(4,:)+joff
! check indices to make sure they are within the range of "data"
if ((minval(gind).lt.1) .OR. (maxval(gind(1:2,:)).gt.size(data,1)) .OR. (maxval(gind(3:4,:)).gt.size(data,2))) then
print "('mpp_scatter_mpi:min-max ', 3i6)", minval(gind), maxval(gind(1:2,:)), maxval(gind(3:4,:))
call mpp_error(FATAL,"mpp_scatter_mpi:: specified indices (with shift) are outside &
of the range of the receiving array")
end if
Expand All @@ -143,6 +157,7 @@ subroutine MPP_SCATTER_PELIST_3D_(is, ie, js, je, nk, pelist, array_seg, data, i
total_msgsize = total_msgsize + send_count(i)
! Compute data displacements
displ(i) = total_msgsize - send_count(i)
!print "('mpp_scatter_mpi:', 2i6)", displ(i), send_count(i)
enddo

! Allocate send buffer
Expand All @@ -157,17 +172,57 @@ subroutine MPP_SCATTER_PELIST_3D_(is, ie, js, je, nk, pelist, array_seg, data, i
j2 = gind(4,i)
total_msgsize = total_msgsize + send_count(i)
! Pack data segments
send_buf(displ(i)+1:total_msgsize) = reshape(data(i1:i2,j1:j2,1:nk), (/size(data(i1:i2,j1:j2,1:nk))/))
m = displ(i) + 1
do k = 1, nk
do jj = j1, j2
do ii = i1, i2
send_buf(m) = data(ii, jj, k)
m = m + 1
end do
end do
end do
!send_buf(displ(i)+1:total_msgsize) = reshape(data(i1:i2,j1:j2,1:nk), (/size(data(i1:i2,j1:j2,1:nk))/), &
! data(i1:i2,j1:j2,1:nk))
!print *, 'mpp_scatter_mpi:send_buf', i, send_buf(displ(i)+1:displ(i)+4)
!print *, 'mpp_scatter_mpi:data', i, data(i1:i2,j1:j2,1:nk)
enddo
end if

! Scatter data chunks to respective PEs
if (mpp_npes() .gt. 1) call MPI_SCATTERV(send_buf, send_count, displ, MPI_TYPE_, recv_buf, &
msgsize, MPI_TYPE_, root_pe, peset(n)%id, ierr)
msgsize, MPI_TYPE_, root_pe, peset(n)%id, ierr)
if (ierr /= MPI_SUCCESS) call mpp_error(FATAL, "mpp_scatter_mpi::MPI_SCATTERV something is wrong")

! Unpack received data
array_seg(is:ie,js:je,1:nk) = reshape(recv_buf, (/shape(array_seg(is:ie,js:je,1:nk))/))
!if (is_root_pe) then
!array_seg(is:ie,js:je,1:nk) = reshape(send_buf(1:send_count(1)), (/shape(array_seg(is:ie,js:je,1:nk))/), &
!send_buf(1:send_count(1)))
!else
m = 1
do k = 1, nk
do jj = js, je
do ii = is, ie
array_seg(ii,jj,k) = recv_buf(m)
m = m + 1
end do
end do
end do
!array_seg(is:ie,js:je,1:nk) = reshape(recv_buf, (/shape(array_seg(is:ie,js:je,1:nk))/), recv_buf)
!end if

i = 1
if (i .le. size(pelist)) then
if (mpp_pe() .eq. pelist(i)) then
!print *, 'mpp_scatter_mpi:array_seg', array_seg(ie-3:ie,js:js,1:1)
end if
if (is_root_pe) then
!if (any(array_seg(is:ie,js:je,1:nk) .ne. data(gind(1,1):gind(2,1),gind(3,1):gind(4,1),1:nk))) then
!print *, 'mpp_scatter_mpi: data did not match!'
!end if
!print *, 'mpp_scatter_mpi:data', data(gind(2,i)-3:gind(2,i),gind(3,i):gind(3,i),1:1)
!print *, 'mpp_scatter_mpi:array_seg', array_seg(ie:ie,je:je,1:1)
end if
end if

if( debug .and. (current_clock.NE.0) ) then
call increment_current_clock( EVENT_SCATTER, msgsize*MPP_TYPE_BYTELEN_ )
Expand Down

0 comments on commit cdf6e60

Please sign in to comment.