! --- Tangled code
module m_redist_spmatrix
#ifdef SIESTA__PEXSI
 implicit none
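 ! A comm_t record describes one aggregated transfer: src and dst are
 ! ranks in the source and destination groups (g1 and g2 below), i1 and
 ! i2 are the starting local indexes in the source and destination
 ! arrays, and nitems is the number of items to transfer.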
 type, public :: comm_t
    integer :: src, dst, i1, i2, nitems
 end type comm_t

 integer, parameter, private :: dp = selected_real_kind(10,100)

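 ! Fortran does not allow arrays of pointers directly, so each 1D data
 ! pointer is "boxed" in a derived type; aux_matrix%vals(:) can then
 ! hold an arbitrary number of value arrays.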
 type, public :: dp_pointer
    ! boxed array pointer type
    real(dp), pointer :: data(:) => null()
 end type dp_pointer

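 ! aux_matrix holds the locally stored part of a sparse matrix:
 !   norbs      : global number of rows (orbitals)
 !   no_l       : number of rows handled locally
 !   nnzl       : number of locally stored non-zero elements
 !   numcols(:) : number of non-zeros in each local row
 !   cols(:)    : column indexes of the local non-zeros
 !   vals(:)    : boxed 1D value arrays, one per data set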
 type, public :: aux_matrix
    integer :: norbs = -1
    integer :: no_l  = -1
    integer :: nnzl  = -1
    integer, pointer :: numcols(:) => null()
    integer, pointer :: cols(:)    => null()
    ! array of 1D pointers
    type(dp_pointer), pointer :: vals(:) => null()
 end type aux_matrix
 public :: redistribute_spmatrix
CONTAINS
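 !-----------------------------------------------------------------------
 ! Redistributes the sparse matrix m1, laid out according to dist1 over
 ! a first set of processes, into m2, laid out according to dist2 over a
 ! (possibly different) second set of processes. Both sets must belong
 ! to bridge_comm, which carries the actual point-to-point transfers.
 ! The routine (i) checks that the reference communicators and
 ! bridge_comm have compatible rank order, (ii) builds the MPI groups g1
 ! and g2 for the two sets, (iii) works out the needed communications
 ! (analyze_comms), and (iv) transfers numcols, cols, and the values.
 !-----------------------------------------------------------------------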
 subroutine redistribute_spmatrix(norbs,m1,dist1,m2,dist2,bridge_comm)

   use mpi
   use class_Distribution
   use alloc,       only: re_alloc, de_alloc

   implicit none

   integer, intent(in)       :: norbs   ! Overall number of rows
   type(aux_matrix) :: m1               ! Source matrix
   type(aux_matrix) :: m2               ! Destination matrix -- it is allocated
   type(Distribution) :: dist1, dist2           ! Distributions
   integer, intent(in)   :: bridge_comm    ! Umbrella Communicator

   type(comm_t), dimension(:), allocatable, target :: comms
   type(comm_t), dimension(:), allocatable, target :: commsnnz
   type(comm_t), pointer :: c, cnnz

   integer ::  myrank1, myrank2, myid, gg
   logical ::  proc_in_set1, proc_in_set2
   integer ::  ierr

   integer ::  i, io, g1, g2, j, nvals
   integer ::  comparison, n1, n2, c1, c2
   integer, parameter :: dp = selected_real_kind(10,100)
   real(dp), dimension(:), pointer  :: data1 => null(), data2 => null()

   integer, allocatable :: ranks1(:), ranks2(:)

   ! The distributions' reference communicators are used as a sanity
   ! check on the rank ordering

   call mpi_comm_rank(bridge_comm,myid,ierr)
   c1 = ref_comm(dist1)
   c2 = ref_comm(dist2)
   call MPI_Comm_Compare(c1,c2,comparison,ierr)

   select case (comparison)
   case (MPI_IDENT)
      ! Communicators are identical
   case (MPI_CONGRUENT)
      ! Communicators have the same group and rank order, but they differ in context
   case (MPI_SIMILAR)
      ! Rank order is different
      call die("Different rank order in communicators")
   case (MPI_UNEQUAL)
      ! Big mess
      call die("Incompatible distribution reference communicators")
   end select

   ! Now check congruence with the provided bridge_comm

   call MPI_Comm_Compare(c1,bridge_comm,comparison,ierr)
   select case (comparison)
   case (MPI_IDENT)
      ! Communicators are identical
   case (MPI_CONGRUENT)
      ! Communicators have the same group and rank order, but they differ in context
      ! We will use bridge_comm
   case (MPI_SIMILAR)
      ! Rank order is different
      call die("Different rank order in dist communicators and bridge comm")
   case (MPI_UNEQUAL)
      ! Big mess
      call die("Incompatible bridge and dist communicators")
   end select

   ! Now create groups g1 and g2.
   ! (DO NOT trust the internal handles)
   call MPI_Comm_Group(bridge_comm,gg,ierr)
   call get_ranks_in_ref_comm(dist1, ranks1)
   call get_ranks_in_ref_comm(dist2, ranks2)
   n1 = size(ranks1)
   n2 = size(ranks2)
   call MPI_Group_Incl(gg,n1,ranks1,g1,ierr)
   call MPI_Group_Incl(gg,n2,ranks2,g2,ierr)

   ! The rest is the same as before

   call mpi_group_rank(g1,myrank1,ierr)
   call mpi_group_rank(g2,myrank2,ierr)
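   ! MPI_Group_rank returns MPI_UNDEFINED for a process that is not a
   ! member of the group; this is how set membership is detected below.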

   proc_in_set1 = (myrank1 /= MPI_UNDEFINED)
   proc_in_set2 = (myrank2 /= MPI_UNDEFINED)

   if (proc_in_set1 .or. proc_in_set2) then
     print "(a,3i6,2l2)", "world_rank, rank1, rank2, ing1?, ing2?", myid,  &
        myrank1, myrank2, proc_in_set1, proc_in_set2
   endif

   ! Figure out the communication needs
   call analyze_comms()

   ! In preparation for the transfer, we allocate storage on the
   ! second group of processors. Note that m2%numcols (and, in
   ! general, any of the second-set arrays) will not be allocated
   ! on processors that are not in the second set.

   if (proc_in_set2) then
      m2%norbs = norbs
      m2%no_l = num_local_elements(dist2,norbs,myrank2)
      call re_alloc(m2%numcols,1,m2%no_l,"m2%numcols","redistribute_spmatrix")
   endif

   if (myid == 0) print *, "About to transfer numcols..."
   call do_transfers_int(comms,m1%numcols,m2%numcols, &
        g1,g2,bridge_comm)
   if (myid == 0) print *, "Transferred numcols."

   ! We need to tell the processes in set 2 how many
   ! "vals" to expect.
   if (proc_in_set1) then
      if (associated(m1%vals)) then
         nvals = size(m1%vals)
      else
         nvals = 0
      endif
   endif
   ! Now do a broadcast within bridge_comm, using as root one
   ! process in the first set. Let's say the one with rank 0
   ! in g1, the first in the set, which will have rank=ranks1(1)
   ! in bridge_comm
   call MPI_Bcast(nvals,1,MPI_Integer,ranks1(1),bridge_comm,ierr)

   ! Now we can figure out how many non-zeros there are
   if (proc_in_set2) then
      m2%nnzl = sum(m2%numcols(1:m2%no_l))
      call re_alloc(m2%cols,1,m2%nnzl,"m2%cols","redistribute_spmatrix")

      if (nvals > 0) then
         allocate(m2%vals(nvals))
         do j=1,nvals
            call re_alloc(m2%vals(j)%data,1,m2%nnzl, &
                 "m2%vals(j)%data","redistribute_spmatrix")
         enddo
      endif

   endif

   ! Generate a new comms-structure with new start/count indexes

   allocate(commsnnz(size(comms)))
   do i = 1, size(comms)
      c => comms(i)
      cnnz => commsnnz(i)

      cnnz%src = c%src
      cnnz%dst = c%dst
      if (myrank1 == c%src) then
         ! Starting position at source: previous cols plus 1
         cnnz%i1 = sum(m1%numcols(1:(c%i1-1))) + 1
         ! Number of items transmitted: total number of cols
         cnnz%nitems = sum(m1%numcols(c%i1 : c%i1 + c%nitems -1))
      endif
      if (myrank2 == c%dst) then
         ! Starting position at destination: previous cols plus 1
         cnnz%i2 = sum(m2%numcols(1 : (c%i2-1))) + 1
         ! Number of items transmitted: total number of cols
         cnnz%nitems = sum(m2%numcols(c%i2 : c%i2 + c%nitems -1))
      endif
   end do
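   ! For example, if a comm covers local rows 3..5 on the source and
   ! m1%numcols(1:5) = (/2,4,1,3,2/), then cnnz%i1 = (2+4) + 1 = 7 and
   ! cnnz%nitems = 1+3+2 = 6.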

 !!$         do i = 1, size(comms)
 !!$            c => commsnnz(i)
 !!$            if (myrank1 == c%src) then
 !!$               print "(a,i5,a,2i5,2i7,i5)", &
 !!$                 "commnnz(src): ", i, " src, dst, i1, (), n:", &
 !!$                 c%src, c%dst, c%i1, -1, c%nitems
 !!$            endif
 !!$            if (myrank2 == c%dst) then
 !!$               print "(a,i5,a,2i5,2i7,i5)", &
 !!$                 "commnnz(dst): ", i, " src, dst, (), i2, n:", &
 !!$                 c%src, c%dst, -1, c%i2, c%nitems
 !!$            endif
 !!$         enddo

   if (myid == 0) print *, "About to transfer cols..."
   ! Transfer the cols arrays
   call do_transfers_int(commsnnz,m1%cols,m2%cols, &
        g1, g2, bridge_comm)

   if (myid == 0) print *, "About to transfer values..."
   ! Transfer the values arrays
   do j=1, nvals
      if (proc_in_set1) data1 => m1%vals(j)%data
      if (proc_in_set2) data2 => m2%vals(j)%data
      call do_transfers_dp(commsnnz,data1,data2, &
           g1,g2,bridge_comm)
   enddo
   nullify(data1,data2)
   if (myid == 0) print *, "Done transfers."

   deallocate(commsnnz)
   deallocate(comms)
   deallocate(ranks1, ranks2)

   call MPI_group_free(gg,ierr)
   call MPI_group_free(g1,ierr)
   call MPI_group_free(g2,ierr)

 CONTAINS


   !-----------------------------------------------------
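      ! Builds the comms(:) array: one entry per maximal run of
      ! consecutive global orbitals that share both the same handling
      ! process in dist1 (source) and in dist2 (destination).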
      subroutine analyze_comms()

         integer, allocatable, dimension(:) :: p1, p2, isrc, idst
         integer :: ncomms

         ! To turn on debug printing, set this to .true.
         logical, save :: comms_not_printed = .false.

         ! Find the communication needs for each orbital
         ! This information is replicated in every processor
         ! (Note that the indexing functions are able to find
         !  out the information for any processor. For the
         ! block-cyclic and "pexsi" distributions, this is quite
         ! easy. For others, the underlying indexing arrays might
         ! be large...)

         ! It might not be necessary to have this in memory. It
         ! can be done on the fly
         allocate(p1(norbs),p2(norbs),isrc(norbs),idst(norbs))

   !      if (myid == 0) then
   !         write(6,"(5a10)") "Orb", "p1", "i1", "p2", "i2"
   !      endif
         do io = 1, norbs
            p1(io) = node_handling_element(dist1,io)
            p2(io) = node_handling_element(dist2,io)
            isrc(io) = index_global_to_local(dist1,io,p1(io))
            idst(io) = index_global_to_local(dist2,io,p2(io))
   !         if (myid == 0) then
   !            if ((norbs < 1000) .or. (mod(io,12) == 0)) then
   !               write(6,"(5i10)") io, p1(io), isrc(io), p2(io), idst(io)
   !            endif
   !        endif
         enddo

         ! Aggregate communications
         ! First pass: find out how many there are, on the basis
         ! of groups of orbitals that share the same source and
         ! destination. Due to the form of the distributions, the
         ! local indexes are also consecutive within such a group,
         ! so we only need to check p1 and p2. (Check whether this
         ! applies to every possible distribution...)

         ncomms = 1
         do io = 2, norbs
            if ((p1(io) /= p1(io-1)) .or. (p2(io) /= p2(io-1))) then
               ncomms = ncomms + 1
            else
               !
            endif
         enddo

         allocate(comms(ncomms))

         ! Second pass: Fill in the data structures
         ncomms = 1
         c => comms(ncomms)
         io = 1
         c%src = p1(io)
         c%dst = p2(io)
         c%i1  = isrc(io)
         c%i2  = idst(io)
         c%nitems = 1
         do io = 2, norbs
            if ((p1(io) /= p1(io-1)) .or. (p2(io) /= p2(io-1))) then
               ! end of group -- new communication
               ncomms = ncomms + 1
               c => comms(ncomms)
               c%src = p1(io)
               c%dst = p2(io)
               c%i1  = isrc(io)
               c%i2  = idst(io)
               c%nitems = 1
            else
               ! we stay in the same communication
               c%nitems = c%nitems + 1
            endif
         enddo

         if (myid == 0 .and. comms_not_printed) then
            do i = 1, ncomms
               c => comms(i)
               write(6,"(a,i5,a,2i5,2i7,i5)") &
                    "comm: ", i, " src, dst, i1, i2, n:", &
                    c%src, c%dst, c%i1, c%i2, c%nitems
            enddo
            comms_not_printed = .false.
         endif

         deallocate(p1,p2,isrc,idst)

       end subroutine analyze_comms

 end subroutine redistribute_spmatrix
 !--------------------------------------------------
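    ! Transfers integer data according to the comms(:) plan, using
    ! non-blocking point-to-point messages over bridge_comm. The group
    ! handles g1 and g2 are only used to translate the group ranks
    ! stored in comms into ranks in bridge_comm.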
    subroutine do_transfers_int(comms,data1,data2,g1,g2,bridge_comm)

      use mpi
      type(comm_t), intent(in), target     :: comms(:)
      integer, dimension(:), pointer  :: data1
      integer, dimension(:), pointer  :: data2
      integer, intent(in)                :: g1
      integer, intent(in)                :: g2
      integer, intent(in)                :: bridge_comm

      integer                 :: basegroup, nsize1, nsize2, ierr
      integer, allocatable    :: comm_rank1(:), comm_rank2(:)


      integer :: ncomms
      integer :: i
      integer :: nrecvs_local, nsends_local
      integer, allocatable :: statuses(:,:), local_reqR(:), local_reqS(:)
      integer :: src_in_comm, dst_in_comm
      integer :: myrank1, myrank2, myrank
      type(comm_t), pointer :: c


       ! Find the rank correspondences, in case
       ! there is implicit renumbering at the time of group creation

       call  MPI_Comm_group( bridge_comm, basegroup, ierr )
       call  MPI_Comm_Rank( bridge_comm, myrank, ierr )

       call  MPI_Group_Size( g1, nsize1, ierr )
       call  MPI_Group_Size( g2, nsize2, ierr )

       allocate(comm_rank1(0:nsize1-1))
       call MPI_Group_translate_ranks( g1, nsize1, (/ (i,i=0,nsize1-1) /), &
                                       basegroup, comm_rank1, ierr )
 !      print "(i4,a,10i3)", myrank, ":Ranks of g1 in base group:", comm_rank1

       allocate(comm_rank2(0:nsize2-1))
       call MPI_Group_translate_ranks( g2, nsize2, (/ (i,i=0,nsize2-1) /), &
                                       basegroup, comm_rank2, ierr )
 !      print "(i4,a,10i3)", myrank,":Ranks of g2 in base group:", comm_rank2

       call mpi_group_rank(g1,myrank1,ierr)
       call mpi_group_rank(g2,myrank2,ierr)

 !      print "(i4,a,2i6)", myrank,": Ranks in g1 and g2: ", myrank1, myrank2
 !      print "(i4,a,2i3)", myrank,": g1 and g2: ", g1, g2


       ! Do the actual transfers.
       ! This version with non-blocking communications

      ncomms = size(comms)

       ! Some bookkeeping for the requests
       nrecvs_local = 0
       nsends_local = 0
       do i=1,ncomms
          c => comms(i)
          if (myrank2 == c%dst) then
             nrecvs_local = nrecvs_local + 1
          endif
          if (myrank1 == c%src) then
             nsends_local = nsends_local + 1
          endif
       enddo
       allocate(local_reqR(nrecvs_local))
       allocate(local_reqS(nsends_local))
       allocate(statuses(mpi_status_size,nrecvs_local))

       ! First, post the receives
       nrecvs_local = 0
       do i=1,ncomms
          c => comms(i)
          if (myrank2 == c%dst) then
             nrecvs_local = nrecvs_local + 1
             src_in_comm = comm_rank1(c%src)
             call MPI_irecv(data2(c%i2),c%nitems,MPI_integer,src_in_comm, &
                            i,bridge_comm,local_reqR(nrecvs_local),ierr)
          endif
       enddo

       ! Post the sends
       nsends_local = 0
       do i=1,ncomms
          c => comms(i)
          if (myrank1 == c%src) then
             nsends_local = nsends_local + 1
             dst_in_comm = comm_rank2(c%dst)
             call MPI_isend(data1(c%i1),c%nitems,MPI_integer,dst_in_comm, &
                         i,bridge_comm,local_reqS(nsends_local),ierr)
          endif
       enddo

       ! A loop of individual waits can be replaced by a single "waitall",
       ! with each process keeping track of the actual number of
       ! requests in which it is involved.

       ! Should we also wait on the sends?

       call MPI_waitall(nrecvs_local, local_reqR, statuses, ierr)


       ! This barrier is needed, I think
       call MPI_Barrier(bridge_comm,ierr)

       deallocate(local_reqR, local_reqS, statuses)

     end subroutine do_transfers_int

 !--------------------------------------------------
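    ! Same as do_transfers_int, but for real(dp) data
    ! (sent as MPI_Double_Precision).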
    subroutine do_transfers_dp(comms,data1,data2,g1,g2,bridge_comm)

      use mpi
      integer, parameter :: dp = selected_real_kind(10,100)

      type(comm_t), intent(in), target     :: comms(:)
      real(dp), dimension(:), pointer :: data1
      real(dp), dimension(:), pointer :: data2
      integer, intent(in)                :: g1
      integer, intent(in)                :: g2
      integer, intent(in)                :: bridge_comm

      integer                 :: basegroup, nsize1, nsize2, ierr
      integer, allocatable    :: comm_rank1(:), comm_rank2(:)


      integer :: ncomms
      integer :: i
      integer :: nrecvs_local, nsends_local
      integer, allocatable :: statuses(:,:), local_reqR(:), local_reqS(:)
      integer :: src_in_comm, dst_in_comm
      integer :: myrank1, myrank2, myid
      type(comm_t), pointer :: c

      call  MPI_Comm_Rank( bridge_comm, myid, ierr )
 !     print *, "Entering transfer_dp"
 !     print *, "rank, Associated data1: ", myid, associated(data1)
 !     print *, "rank, Associated data2: ", myid, associated(data2)

       ! Find the rank correspondences, in case
       ! there is implicit renumbering at the time of group creation

       call  MPI_Comm_group( bridge_comm, basegroup, ierr )
       call  MPI_Group_Size( g1, nsize1, ierr )
       call  MPI_Group_Size( g2, nsize2, ierr )
       allocate(comm_rank1(0:nsize1-1))
       call MPI_Group_translate_ranks( g1, nsize1, (/ (i,i=0,nsize1-1) /), &
                                       basegroup, comm_rank1, ierr )
 !      print "(a,10i3)", "Ranks of g1 in base group:", comm_rank1
       allocate(comm_rank2(0:nsize2-1))
       call MPI_Group_translate_ranks( g2, nsize2, (/ (i,i=0,nsize2-1) /), &
                                       basegroup, comm_rank2, ierr )
 !      print "(a,10i3)", "Ranks of g2 in base group:", comm_rank2

       call mpi_group_rank(g1,myrank1,ierr)
       call mpi_group_rank(g2,myrank2,ierr)

       ! Do the actual transfers.
       ! This version with non-blocking communications

      ncomms = size(comms)

       ! Some bookkeeping for the requests
       nrecvs_local = 0
       nsends_local = 0
       do i=1,ncomms
          c => comms(i)
          if (myrank2 == c%dst) then
             nrecvs_local = nrecvs_local + 1
          endif
          if (myrank1 == c%src) then
             nsends_local = nsends_local + 1
          endif
       enddo
       allocate(local_reqR(nrecvs_local))
       allocate(local_reqS(nsends_local))
       allocate(statuses(mpi_status_size,nrecvs_local))

       ! First, post the receives
       nrecvs_local = 0
       do i=1,ncomms
          c => comms(i)
          if (myrank2 == c%dst) then
             nrecvs_local = nrecvs_local + 1
             src_in_comm = comm_rank1(c%src)
             call MPI_irecv(data2(c%i2),c%nitems,MPI_Double_Precision,src_in_comm, &
                            i,bridge_comm,local_reqR(nrecvs_local),ierr)
          endif
       enddo

       ! Post the sends
       nsends_local = 0
       do i=1,ncomms
          c => comms(i)
          if (myrank1 == c%src) then
             nsends_local = nsends_local + 1
             dst_in_comm = comm_rank2(c%dst)
             call MPI_isend(data1(c%i1),c%nitems,MPI_Double_Precision,dst_in_comm, &
                         i,bridge_comm,local_reqS(nsends_local),ierr)
          endif
       enddo

       ! A loop of individual waits can be replaced by a single "waitall",
       ! with each process keeping track of the actual number of
       ! requests in which it is involved.

       ! Should we also wait on the sends?

       call MPI_waitall(nrecvs_local, local_reqR, statuses, ierr)


       ! This barrier is needed, I think
       call MPI_Barrier(bridge_comm,ierr)

       deallocate(local_reqR, local_reqS, statuses)

     end subroutine do_transfers_dp
#endif
end module m_redist_spmatrix
! --- End of tangled code
549