1!--------------------------------------------------------------------------------------------------!
2! Copyright (C) by the DBCSR developers group - All rights reserved                                !
3! This file is part of the DBCSR library.                                                          !
4!                                                                                                  !
5! For information on the license, see the LICENSE file.                                            !
6! For further information please visit https://dbcsr.cp2k.org                                      !
7! SPDX-License-Identifier: GPL-2.0+                                                                !
8!--------------------------------------------------------------------------------------------------!
9
10MODULE dbcsr_mp_operations
11   !! Wrappers to message passing calls.
12
13   USE dbcsr_config, ONLY: has_MPI
14   USE dbcsr_data_methods, ONLY: dbcsr_data_get_type
15   USE dbcsr_mp_methods, ONLY: &
16      dbcsr_mp_get_process, dbcsr_mp_grid_setup, dbcsr_mp_group, dbcsr_mp_has_subgroups, &
17      dbcsr_mp_my_col_group, dbcsr_mp_my_row_group, dbcsr_mp_mynode, dbcsr_mp_mypcol, &
18      dbcsr_mp_myprow, dbcsr_mp_npcols, dbcsr_mp_nprows, dbcsr_mp_numnodes, dbcsr_mp_pgrid
19   USE dbcsr_ptr_util, ONLY: memory_copy
20   USE dbcsr_types, ONLY: dbcsr_data_obj, &
21                          dbcsr_mp_obj, &
22                          dbcsr_type_complex_4, &
23                          dbcsr_type_complex_8, &
24                          dbcsr_type_int_4, &
25                          dbcsr_type_real_4, &
26                          dbcsr_type_real_8
27   USE dbcsr_kinds, ONLY: real_4, &
28                          real_8
29   USE dbcsr_mpiwrap, ONLY: &
30      mp_allgather, mp_alltoall, mp_gatherv, mp_ibcast, mp_irecv, mp_iscatter, mp_isend, &
31      mp_isendrecv, mp_rget, mp_sendrecv, mp_type_descriptor_type, mp_type_indexed_make_c, &
32      mp_type_indexed_make_d, mp_type_indexed_make_r, mp_type_indexed_make_z, mp_type_make, &
33      mp_waitall, mp_win_create
34#include "base/dbcsr_base_uses.f90"
35
36   IMPLICIT NONE
37
38   PRIVATE
39
40   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbcsr_mp_operations'
41
42   ! MP routines
43   PUBLIC :: hybrid_alltoall_s1, hybrid_alltoall_d1, &
44             hybrid_alltoall_c1, hybrid_alltoall_z1, &
45             hybrid_alltoall_i1, hybrid_alltoall_any
46   PUBLIC :: dbcsr_allgatherv
47   PUBLIC :: dbcsr_sendrecv_any
48   PUBLIC :: dbcsr_isend_any, dbcsr_irecv_any
49   PUBLIC :: dbcsr_win_create_any, dbcsr_rget_any, dbcsr_ibcast_any
50   PUBLIC :: dbcsr_iscatterv_any, dbcsr_gatherv_any
51   PUBLIC :: dbcsr_isendrecv_any
52   ! Type helpers
53   PUBLIC :: dbcsr_mp_type_from_anytype
54
   ! Generic name for the hybrid all-to-all exchange. It resolves to one of
   ! the typed array routines (s1/d1/c1/z1 for the real and complex kinds,
   ! i1 for default integers) or to the encapsulated-data variant (any),
   ! depending on the type of the buffer arguments.
   ! NOTE(review): the generic name itself is not listed in any PUBLIC
   ! statement in the visible part of this module (default access is
   ! PRIVATE); only the specific procedures are exported -- confirm that
   ! this is intended.
   INTERFACE dbcsr_hybrid_alltoall
      MODULE PROCEDURE hybrid_alltoall_s1, hybrid_alltoall_d1, &
         hybrid_alltoall_c1, hybrid_alltoall_z1
      MODULE PROCEDURE hybrid_alltoall_i1
      MODULE PROCEDURE hybrid_alltoall_any
   END INTERFACE
61
62CONTAINS
63
64   SUBROUTINE hybrid_alltoall_any(sb, scount, sdispl, &
65                                  rb, rcount, rdispl, mp_env, most_ptp, remainder_ptp, no_hybrid)
66      TYPE(dbcsr_data_obj), INTENT(IN)                   :: sb
67      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: scount, sdispl
68      TYPE(dbcsr_data_obj), INTENT(INOUT)                :: rb
69      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: rcount, rdispl
70      TYPE(dbcsr_mp_obj), INTENT(IN)                     :: mp_env
71      LOGICAL, INTENT(in), OPTIONAL                      :: most_ptp, remainder_ptp, no_hybrid
72
73      CHARACTER(len=*), PARAMETER :: routineN = 'hybrid_alltoall_any'
74
75      INTEGER                                            :: error_handle
76
77!   ---------------------------------------------------------------------------
78
79      CALL timeset(routineN, error_handle)
80
81      SELECT CASE (dbcsr_data_get_type(sb))
82      CASE (dbcsr_type_real_4)
83         CALL hybrid_alltoall_s1(sb%d%r_sp, scount, sdispl, &
84                                 rb%d%r_sp, rcount, rdispl, mp_env, &
85                                 most_ptp, remainder_ptp, no_hybrid)
86      CASE (dbcsr_type_real_8)
87         CALL hybrid_alltoall_d1(sb%d%r_dp, scount, sdispl, &
88                                 rb%d%r_dp, rcount, rdispl, mp_env, &
89                                 most_ptp, remainder_ptp, no_hybrid)
90      CASE (dbcsr_type_complex_4)
91         CALL hybrid_alltoall_c1(sb%d%c_sp, scount, sdispl, &
92                                 rb%d%c_sp, rcount, rdispl, mp_env, &
93                                 most_ptp, remainder_ptp, no_hybrid)
94      CASE (dbcsr_type_complex_8)
95         CALL hybrid_alltoall_z1(sb%d%c_dp, scount, sdispl, &
96                                 rb%d%c_dp, rcount, rdispl, mp_env, &
97                                 most_ptp, remainder_ptp, no_hybrid)
98      CASE default
99         DBCSR_ABORT("Invalid data type")
100      END SELECT
101
102      CALL timestop(error_handle)
103   END SUBROUTINE hybrid_alltoall_any
104
   SUBROUTINE hybrid_alltoall_i1(sb, scount, sdispl, &
                                 rb, rcount, rdispl, mp_env, most_ptp, remainder_ptp, no_hybrid)
      !! Row/column and global all-to-all
      !!
      !! Communicator selection
      !! Uses row and column communicators for row/column
      !! sends. Remaining sends are performed using the global
      !! communicator.  Point-to-point isend/irecv are used if ptp is
      !! set, otherwise an alltoall collective call is issued.
      !! see mp_alltoall
      !!
      !! If the environment has no subgroups, MPI is not available, or
      !! no_hybrid is set, a single mp_alltoall on the global group is
      !! issued instead.

      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(in), &
         TARGET                                 :: sb
         !! Send buffer
      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN) :: scount, sdispl
         !! Send counts and 0-based send displacements; entry p+1 belongs to process p
      INTEGER, DIMENSION(:), CONTIGUOUS, &
         INTENT(INOUT), TARGET                  :: rb
         !! Receive buffer
      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN) :: rcount, rdispl
         !! Receive counts and 0-based receive displacements; entry p+1 belongs to process p
      TYPE(dbcsr_mp_obj), INTENT(IN)           :: mp_env
         !! MP Environment
      LOGICAL, INTENT(IN), OPTIONAL            :: most_ptp, remainder_ptp, &
                                                  no_hybrid
         !! Use point-to-point for row/column; default is no
         !! Use point-to-point for remaining; default is no
         !! Use regular global collective; default is no

      INTEGER :: all_group, mynode, mypcol, myprow, nall_rr, nall_sr, ncol_rr, &
                 ncol_sr, npcols, nprows, nrow_rr, nrow_sr, numnodes, dst, src, &
                 prow, pcol, send_cnt, recv_cnt, tag, grp, i
      ! Request arrays handed to mp_isend/mp_irecv and mp_waitall
      ! (row_*, col_*, all_*; *_sr for sends, *_rr for receives) and the
      ! re-indexed counts/displacements used for the collective calls (new_*).
      INTEGER, ALLOCATABLE, DIMENSION(:) :: all_rr, all_sr, col_rr, col_sr, &
                                            new_rcount, new_rdispl, new_scount, new_sdispl, row_rr, row_sr
      INTEGER, DIMENSION(:, :), POINTER        :: pgrid
      LOGICAL                                  :: most_collective, &
                                                  remainder_collective, no_h
      INTEGER, DIMENSION(:), POINTER, CONTIGUOUS :: send_data_p, recv_data_p
      TYPE(dbcsr_mp_obj)                       :: mpe

      ! mp_env is INTENT(IN), so the grid setup is invoked on a local copy.
      ! mpe is not read again in this routine.
      ! NOTE(review): this presumably relies on the copy sharing its %mp
      ! target with mp_env, so that the setup is visible through mp_env
      ! below -- confirm against dbcsr_mp_obj / dbcsr_mp_grid_setup.
      IF (.NOT. dbcsr_mp_has_subgroups(mp_env)) THEN
         mpe = mp_env
         CALL dbcsr_mp_grid_setup(mpe)
      ENDIF
      ! Defaults: collectives for both parts, hybrid scheme enabled.
      most_collective = .TRUE.
      remainder_collective = .TRUE.
      no_h = .FALSE.
      IF (PRESENT(most_ptp)) most_collective = .NOT. most_ptp
      IF (PRESENT(remainder_ptp)) remainder_collective = .NOT. remainder_ptp
      IF (PRESENT(no_hybrid)) no_h = no_hybrid
      all_group = dbcsr_mp_group(mp_env)
      ! Don't use subcommunicators if they're not defined.
      no_h = no_h .OR. .NOT. dbcsr_mp_has_subgroups(mp_env) .OR. .NOT. has_MPI
      subgrouped: IF (mp_env%mp%subgroups_defined .AND. .NOT. no_h) THEN
         mynode = dbcsr_mp_mynode(mp_env)
         numnodes = dbcsr_mp_numnodes(mp_env)
         nprows = dbcsr_mp_nprows(mp_env)
         npcols = dbcsr_mp_npcols(mp_env)
         myprow = dbcsr_mp_myprow(mp_env)
         mypcol = dbcsr_mp_mypcol(mp_env)
         pgrid => dbcsr_mp_pgrid(mp_env)
         ! Request arrays are 0-based; each n*_sr/n*_rr counts the requests
         ! actually issued. The row/column arrays are filled starting at
         ! index 0, the all_* arrays starting at index 1 (see the matching
         ! mp_waitall slices below).
         ALLOCATE (row_sr(0:npcols - 1)); nrow_sr = 0
         ALLOCATE (row_rr(0:npcols - 1)); nrow_rr = 0
         ALLOCATE (col_sr(0:nprows - 1)); ncol_sr = 0
         ALLOCATE (col_rr(0:nprows - 1)); ncol_rr = 0
         ALLOCATE (all_sr(0:numnodes - 1)); nall_sr = 0
         ALLOCATE (all_rr(0:numnodes - 1)); nall_rr = 0
         ALLOCATE (new_scount(numnodes), new_rcount(numnodes))
         ALLOCATE (new_sdispl(numnodes), new_rdispl(numnodes))
         ! Point-to-point remainder messages are posted first, before the
         ! row/column exchange is started.
         IF (.NOT. remainder_collective) THEN
            CALL remainder_point_to_point()
         ENDIF
         IF (.NOT. most_collective) THEN
            CALL most_point_to_point()
         ELSE
            CALL most_alltoall()
         ENDIF
         IF (remainder_collective) THEN
            CALL remainder_alltoall()
         ENDIF
         ! Wait for all issued sends and receives.
         IF (.NOT. most_collective) THEN
            CALL mp_waitall(row_sr(0:nrow_sr - 1))
            CALL mp_waitall(col_sr(0:ncol_sr - 1))
            CALL mp_waitall(row_rr(0:nrow_rr - 1))
            CALL mp_waitall(col_rr(0:ncol_rr - 1))
         END IF
         IF (.NOT. remainder_collective) THEN
            CALL mp_waitall(all_sr(1:nall_sr))
            CALL mp_waitall(all_rr(1:nall_rr))
         ENDIF
      ELSE
         ! No hybrid scheme: one collective over the global group.
         CALL mp_alltoall(sb, scount, sdispl, &
                          rb, rcount, rdispl, &
                          all_group)
      ENDIF subgrouped
   CONTAINS
      SUBROUTINE most_alltoall()
         ! Collective exchange within my process row and then within my
         ! process column. The counts/displacements of the processes in the
         ! row (column) are gathered into new_*, ordered by process column
         ! (row) index. Note that my own process belongs to both the row and
         ! the column, so its entry takes part in both calls.
         DO pcol = 0, npcols - 1
            new_scount(1 + pcol) = scount(1 + pgrid(myprow, pcol))
            new_rcount(1 + pcol) = rcount(1 + pgrid(myprow, pcol))
            new_sdispl(1 + pcol) = sdispl(1 + pgrid(myprow, pcol))
            new_rdispl(1 + pcol) = rdispl(1 + pgrid(myprow, pcol))
         END DO
         CALL mp_alltoall(sb, new_scount(1:npcols), new_sdispl(1:npcols), &
                          rb, new_rcount(1:npcols), new_rdispl(1:npcols), &
                          dbcsr_mp_my_row_group(mp_env))
         DO prow = 0, nprows - 1
            new_scount(1 + prow) = scount(1 + pgrid(prow, mypcol))
            new_rcount(1 + prow) = rcount(1 + pgrid(prow, mypcol))
            new_sdispl(1 + prow) = sdispl(1 + pgrid(prow, mypcol))
            new_rdispl(1 + prow) = rdispl(1 + pgrid(prow, mypcol))
         END DO
         CALL mp_alltoall(sb, new_scount(1:nprows), new_sdispl(1:nprows), &
                          rb, new_rcount(1:nprows), new_rdispl(1:nprows), &
                          dbcsr_mp_my_col_group(mp_env))
      END SUBROUTINE most_alltoall
      SUBROUTINE most_point_to_point()
         ! Non-blocking exchange within my process row and process column.
         ! In step i the send goes to the process i positions ahead of me
         ! and the receive comes from the process i positions behind me
         ! (cyclically); step 0 addresses my own process. Only non-empty
         ! messages are posted. Row messages use tag 4*<sender pcol>,
         ! column messages use tag 4*<sender prow> + 1.
         ! Go through my prow and exchange.
         DO i = 0, npcols - 1
            pcol = MOD(mypcol + i, npcols)
            grp = dbcsr_mp_my_row_group(mp_env)
            !
            dst = dbcsr_mp_get_process(mp_env, myprow, pcol)
            send_cnt = scount(dst + 1)
            ! Unlike in the column loop below, the pointer is associated
            ! before the count is tested; it is only used if send_cnt > 0.
            send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
            tag = 4*mypcol
            IF (send_cnt .GT. 0) THEN
               ! The peer is addressed by its process column index.
               CALL mp_isend(send_data_p, pcol, grp, row_sr(nrow_sr), tag)
               nrow_sr = nrow_sr + 1
            ENDIF
            !
            ! MODULO (not MOD) keeps the index non-negative for mypcol < i.
            pcol = MODULO(mypcol - i, npcols)
            src = dbcsr_mp_get_process(mp_env, myprow, pcol)
            recv_cnt = rcount(src + 1)
            recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
            tag = 4*pcol
            IF (recv_cnt .GT. 0) THEN
               CALL mp_irecv(recv_data_p, pcol, grp, row_rr(nrow_rr), tag)
               nrow_rr = nrow_rr + 1
            ENDIF
         ENDDO
         ! go through my pcol and exchange
         DO i = 0, nprows - 1
            prow = MOD(myprow + i, nprows)
            grp = dbcsr_mp_my_col_group(mp_env)
            !
            dst = dbcsr_mp_get_process(mp_env, prow, mypcol)
            send_cnt = scount(dst + 1)
            IF (send_cnt .GT. 0) THEN
               send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
               tag = 4*myprow + 1
               ! The peer is addressed by its process row index.
               CALL mp_isend(send_data_p, prow, grp, col_sr(ncol_sr), tag)
               ncol_sr = ncol_sr + 1
            ENDIF
            !
            prow = MODULO(myprow - i, nprows)
            src = dbcsr_mp_get_process(mp_env, prow, mypcol)
            recv_cnt = rcount(src + 1)
            IF (recv_cnt .GT. 0) THEN
               recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
               tag = 4*prow + 1
               CALL mp_irecv(recv_data_p, prow, grp, col_rr(ncol_rr), tag)
               ncol_rr = ncol_rr + 1
            ENDIF
         ENDDO
      END SUBROUTINE most_point_to_point
      SUBROUTINE remainder_alltoall()
         ! Global collective for everything outside my process row and
         ! column: the counts of all processes in my row and in my column
         ! are set to zero, the original displacements are kept.
         new_scount(:) = scount(:)
         new_rcount(:) = rcount(:)
         DO prow = 0, nprows - 1
            new_scount(1 + pgrid(prow, mypcol)) = 0
            new_rcount(1 + pgrid(prow, mypcol)) = 0
         END DO
         DO pcol = 0, npcols - 1
            new_scount(1 + pgrid(myprow, pcol)) = 0
            new_rcount(1 + pgrid(myprow, pcol)) = 0
         END DO
         CALL mp_alltoall(sb, new_scount, sdispl, &
                          rb, new_rcount, rdispl, all_group)
      END SUBROUTINE remainder_alltoall
      SUBROUTINE remainder_point_to_point()
         ! Non-blocking exchange on the global group with every process
         ! that is neither in my process row nor in my process column.
         ! The grid is traversed starting from my own position; for each
         ! visited process both the send to it and the receive from it are
         ! posted (if non-empty). Messages use tag 4*<sender node> + 2.
         INTEGER                                            :: col, row

         DO row = 0, nprows - 1
            prow = MOD(row + myprow, nprows)
            IF (prow .EQ. myprow) CYCLE
            DO col = 0, npcols - 1
               pcol = MOD(col + mypcol, npcols)
               IF (pcol .EQ. mypcol) CYCLE
               dst = dbcsr_mp_get_process(mp_env, prow, pcol)
               send_cnt = scount(dst + 1)
               IF (send_cnt .GT. 0) THEN
                  tag = 4*mynode + 2
                  send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
                  ! all_sr/all_rr are filled from index 1 upwards.
                  CALL mp_isend(send_data_p, dst, all_group, all_sr(nall_sr + 1), tag)
                  nall_sr = nall_sr + 1
               ENDIF
               !
               ! Same partner as for the send (src equals dst here).
               src = dbcsr_mp_get_process(mp_env, prow, pcol)
               recv_cnt = rcount(src + 1)
               IF (recv_cnt .GT. 0) THEN
                  recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
                  tag = 4*src + 2
                  CALL mp_irecv(recv_data_p, src, all_group, all_rr(nall_rr + 1), tag)
                  nall_rr = nall_rr + 1
               ENDIF
            ENDDO
         ENDDO
      END SUBROUTINE remainder_point_to_point
   END SUBROUTINE hybrid_alltoall_i1
312
313   FUNCTION dbcsr_mp_type_from_anytype(data_area) RESULT(mp_type)
314      !! Creates an MPI combined type from the given anytype.
315
316      TYPE(dbcsr_data_obj), INTENT(IN)                   :: data_area
317         !! Data area of any type
318      TYPE(mp_type_descriptor_type)                      :: mp_type
319         !! Type descriptor
320
321      SELECT CASE (data_area%d%data_type)
322      CASE (dbcsr_type_int_4)
323         mp_type = mp_type_make(data_area%d%i4)
324      CASE (dbcsr_type_real_4)
325         mp_type = mp_type_make(data_area%d%r_sp)
326      CASE (dbcsr_type_real_8)
327         mp_type = mp_type_make(data_area%d%r_dp)
328      CASE (dbcsr_type_complex_4)
329         mp_type = mp_type_make(data_area%d%c_sp)
330      CASE (dbcsr_type_complex_8)
331         mp_type = mp_type_make(data_area%d%c_dp)
332      END SELECT
333   END FUNCTION dbcsr_mp_type_from_anytype
334
335   SUBROUTINE dbcsr_sendrecv_any(msgin, dest, msgout, source, comm)
336      !! sendrecv of encapsulated data.
337      !! @note see mp_sendrecv
338
339      TYPE(dbcsr_data_obj), INTENT(IN)                   :: msgin
340      INTEGER, INTENT(IN)                                :: dest
341      TYPE(dbcsr_data_obj), INTENT(INOUT)                :: msgout
342      INTEGER, INTENT(IN)                                :: source, comm
343
344      IF (dbcsr_data_get_type(msgin) .NE. dbcsr_data_get_type(msgout)) &
345         DBCSR_ABORT("Different data type for msgin and msgout")
346
347      SELECT CASE (dbcsr_data_get_type(msgin))
348      CASE (dbcsr_type_real_4)
349         CALL mp_sendrecv(msgin%d%r_sp, dest, msgout%d%r_sp, source, comm)
350      CASE (dbcsr_type_real_8)
351         CALL mp_sendrecv(msgin%d%r_dp, dest, msgout%d%r_dp, source, comm)
352      CASE (dbcsr_type_complex_4)
353         CALL mp_sendrecv(msgin%d%c_sp, dest, msgout%d%c_sp, source, comm)
354      CASE (dbcsr_type_complex_8)
355         CALL mp_sendrecv(msgin%d%c_dp, dest, msgout%d%c_dp, source, comm)
356      CASE default
357         DBCSR_ABORT("Incorrect data type")
358      END SELECT
359   END SUBROUTINE dbcsr_sendrecv_any
360
361   SUBROUTINE dbcsr_isend_any(msgin, dest, comm, request, tag)
362      !! Non-blocking send of encapsulated data.
363      !! @note see mp_isend_iv
364
365      TYPE(dbcsr_data_obj), INTENT(IN)                   :: msgin
366      INTEGER, INTENT(IN)                                :: dest, comm
367      INTEGER, INTENT(OUT)                               :: request
368      INTEGER, INTENT(IN), OPTIONAL                      :: tag
369
370      SELECT CASE (dbcsr_data_get_type(msgin))
371      CASE (dbcsr_type_real_4)
372         CALL mp_isend(msgin%d%r_sp, dest, comm, request, tag)
373      CASE (dbcsr_type_real_8)
374         CALL mp_isend(msgin%d%r_dp, dest, comm, request, tag)
375      CASE (dbcsr_type_complex_4)
376         CALL mp_isend(msgin%d%c_sp, dest, comm, request, tag)
377      CASE (dbcsr_type_complex_8)
378         CALL mp_isend(msgin%d%c_dp, dest, comm, request, tag)
379      CASE default
380         DBCSR_ABORT("Incorrect data type")
381      END SELECT
382   END SUBROUTINE dbcsr_isend_any
383
384   SUBROUTINE dbcsr_irecv_any(msgin, source, comm, request, tag)
385      !! Non-blocking recv of encapsulated data.
386      !! @note see mp_irecv_iv
387
388      TYPE(dbcsr_data_obj), INTENT(IN)                   :: msgin
389      INTEGER, INTENT(IN)                                :: source, comm
390      INTEGER, INTENT(OUT)                               :: request
391      INTEGER, INTENT(IN), OPTIONAL                      :: tag
392
393      SELECT CASE (dbcsr_data_get_type(msgin))
394      CASE (dbcsr_type_real_4)
395         CALL mp_irecv(msgin%d%r_sp, source, comm, request, tag)
396      CASE (dbcsr_type_real_8)
397         CALL mp_irecv(msgin%d%r_dp, source, comm, request, tag)
398      CASE (dbcsr_type_complex_4)
399         CALL mp_irecv(msgin%d%c_sp, source, comm, request, tag)
400      CASE (dbcsr_type_complex_8)
401         CALL mp_irecv(msgin%d%c_dp, source, comm, request, tag)
402      CASE default
403         DBCSR_ABORT("Incorrect data type")
404      END SELECT
405   END SUBROUTINE dbcsr_irecv_any
406
407   SUBROUTINE dbcsr_win_create_any(base, comm, win)
408      !! Window initialization function of encapsulated data.
409      TYPE(dbcsr_data_obj), INTENT(IN)                   :: base
410      INTEGER, INTENT(IN)                                :: comm
411      INTEGER, INTENT(OUT)                               :: win
412
413      SELECT CASE (dbcsr_data_get_type(base))
414      CASE (dbcsr_type_real_4)
415         CALL mp_win_create(base%d%r_sp, comm, win)
416      CASE (dbcsr_type_real_8)
417         CALL mp_win_create(base%d%r_dp, comm, win)
418      CASE (dbcsr_type_complex_4)
419         CALL mp_win_create(base%d%c_sp, comm, win)
420      CASE (dbcsr_type_complex_8)
421         CALL mp_win_create(base%d%c_dp, comm, win)
422      CASE default
423         DBCSR_ABORT("Incorrect data type")
424      END SELECT
425   END SUBROUTINE dbcsr_win_create_any
426
427   SUBROUTINE dbcsr_rget_any(base, source, win, win_data, myproc, disp, request, &
428      !! Single-sided Get function of encapsulated data.
429                             origin_datatype, target_datatype)
430      TYPE(dbcsr_data_obj), INTENT(IN)                   :: base
431      INTEGER, INTENT(IN)                                :: source, win
432      TYPE(dbcsr_data_obj), INTENT(IN)                   :: win_data
433      INTEGER, INTENT(IN), OPTIONAL                      :: myproc, disp
434      INTEGER, INTENT(OUT)                               :: request
435      TYPE(mp_type_descriptor_type), INTENT(IN), &
436         OPTIONAL                                        :: origin_datatype, target_datatype
437
438      IF (dbcsr_data_get_type(base) /= dbcsr_data_get_type(win_data)) &
439         DBCSR_ABORT("Mismatch data type between buffer and window")
440
441      SELECT CASE (dbcsr_data_get_type(base))
442      CASE (dbcsr_type_real_4)
443         CALL mp_rget(base%d%r_sp, source, win, win_data%d%r_sp, myproc, &
444                      disp, request, origin_datatype, target_datatype)
445      CASE (dbcsr_type_real_8)
446         CALL mp_rget(base%d%r_dp, source, win, win_data%d%r_dp, myproc, &
447                      disp, request, origin_datatype, target_datatype)
448      CASE (dbcsr_type_complex_4)
449         CALL mp_rget(base%d%c_sp, source, win, win_data%d%c_sp, myproc, &
450                      disp, request, origin_datatype, target_datatype)
451      CASE (dbcsr_type_complex_8)
452         CALL mp_rget(base%d%c_dp, source, win, win_data%d%c_dp, myproc, &
453                      disp, request, origin_datatype, target_datatype)
454      CASE default
455         DBCSR_ABORT("Incorrect data type")
456      END SELECT
457   END SUBROUTINE dbcsr_rget_any
458
459   SUBROUTINE dbcsr_ibcast_any(base, source, grp, request)
460      !! Bcast function of encapsulated data.
461      TYPE(dbcsr_data_obj), INTENT(IN)                   :: base
462      INTEGER, INTENT(IN)                                :: source, grp
463      INTEGER, INTENT(INOUT)                             :: request
464
465      SELECT CASE (dbcsr_data_get_type(base))
466      CASE (dbcsr_type_real_4)
467         CALL mp_ibcast(base%d%r_sp, source, grp, request)
468      CASE (dbcsr_type_real_8)
469         CALL mp_ibcast(base%d%r_dp, source, grp, request)
470      CASE (dbcsr_type_complex_4)
471         CALL mp_ibcast(base%d%c_sp, source, grp, request)
472      CASE (dbcsr_type_complex_8)
473         CALL mp_ibcast(base%d%c_dp, source, grp, request)
474      CASE default
475         DBCSR_ABORT("Incorrect data type")
476      END SELECT
477   END SUBROUTINE dbcsr_ibcast_any
478
479   SUBROUTINE dbcsr_iscatterv_any(base, counts, displs, msg, recvcount, root, grp, request)
480      !! Scatter function of encapsulated data.
481      TYPE(dbcsr_data_obj), INTENT(IN)                   :: base
482      INTEGER, DIMENSION(:), INTENT(IN), CONTIGUOUS      :: counts, displs
483      TYPE(dbcsr_data_obj), INTENT(INOUT)                :: msg
484      INTEGER, INTENT(IN)                                :: recvcount, root, grp
485      INTEGER, INTENT(INOUT)                             :: request
486
487      IF (dbcsr_data_get_type(base) .NE. dbcsr_data_get_type(msg)) &
488         DBCSR_ABORT("Different data type for msgin and msgout")
489
490      SELECT CASE (dbcsr_data_get_type(base))
491      CASE (dbcsr_type_real_4)
492         CALL mp_iscatter(base%d%r_sp, counts, displs, msg%d%r_sp, recvcount, root, grp, request)
493      CASE (dbcsr_type_real_8)
494         CALL mp_iscatter(base%d%r_dp, counts, displs, msg%d%r_dp, recvcount, root, grp, request)
495      CASE (dbcsr_type_complex_4)
496         CALL mp_iscatter(base%d%c_sp, counts, displs, msg%d%c_sp, recvcount, root, grp, request)
497      CASE (dbcsr_type_complex_8)
498         CALL mp_iscatter(base%d%c_dp, counts, displs, msg%d%c_dp, recvcount, root, grp, request)
499      CASE default
500         DBCSR_ABORT("Incorrect data type")
501      END SELECT
502   END SUBROUTINE dbcsr_iscatterv_any
503
504   SUBROUTINE dbcsr_gatherv_any(base, ub_base, msg, counts, displs, root, grp)
505      !! Gather function of encapsulated data.
506      TYPE(dbcsr_data_obj), INTENT(IN)                   :: base
507      INTEGER, INTENT(IN)                                :: ub_base
508      TYPE(dbcsr_data_obj), INTENT(INOUT)                :: msg
509      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: counts, displs
510      INTEGER, INTENT(IN)                                :: root, grp
511
512      IF (dbcsr_data_get_type(base) .NE. dbcsr_data_get_type(msg)) &
513         DBCSR_ABORT("Different data type for msgin and msgout")
514
515      SELECT CASE (dbcsr_data_get_type(base))
516      CASE (dbcsr_type_real_4)
517         CALL mp_gatherv(base%d%r_sp(:ub_base), msg%d%r_sp, counts, displs, root, grp)
518      CASE (dbcsr_type_real_8)
519         CALL mp_gatherv(base%d%r_dp(:ub_base), msg%d%r_dp, counts, displs, root, grp)
520      CASE (dbcsr_type_complex_4)
521         CALL mp_gatherv(base%d%c_sp(:ub_base), msg%d%c_sp, counts, displs, root, grp)
522      CASE (dbcsr_type_complex_8)
523         CALL mp_gatherv(base%d%c_dp(:ub_base), msg%d%c_dp, counts, displs, root, grp)
524      CASE default
525         DBCSR_ABORT("Incorrect data type")
526      END SELECT
527   END SUBROUTINE dbcsr_gatherv_any
528
529   SUBROUTINE dbcsr_isendrecv_any(msgin, dest, msgout, source, grp, send_request, recv_request)
530      !! Send/Recv function of encapsulated data.
531      TYPE(dbcsr_data_obj), INTENT(IN)                   :: msgin
532      INTEGER, INTENT(IN)                                :: dest
533      TYPE(dbcsr_data_obj), INTENT(INOUT)                :: msgout
534      INTEGER, INTENT(IN)                                :: source, grp
535      INTEGER, INTENT(OUT)                               :: send_request, recv_request
536
537      IF (dbcsr_data_get_type(msgin) .NE. dbcsr_data_get_type(msgout)) &
538         DBCSR_ABORT("Different data type for msgin and msgout")
539
540      SELECT CASE (dbcsr_data_get_type(msgin))
541      CASE (dbcsr_type_real_4)
542         CALL mp_isendrecv(msgin%d%r_sp, dest, &
543                           msgout%d%r_sp, source, &
544                           grp, send_request, recv_request)
545      CASE (dbcsr_type_real_8)
546         CALL mp_isendrecv(msgin%d%r_dp, dest, &
547                           msgout%d%r_dp, source, &
548                           grp, send_request, recv_request)
549      CASE (dbcsr_type_complex_4)
550         CALL mp_isendrecv(msgin%d%c_sp, dest, &
551                           msgout%d%c_sp, source, &
552                           grp, send_request, recv_request)
553      CASE (dbcsr_type_complex_8)
554         CALL mp_isendrecv(msgin%d%c_dp, dest, &
555                           msgout%d%c_dp, source, &
556                           grp, send_request, recv_request)
557      CASE default
558         DBCSR_ABORT("Incorrect data type")
559      END SELECT
560   END SUBROUTINE dbcsr_isendrecv_any
561
562   SUBROUTINE dbcsr_allgatherv(send_data, scount, recv_data, recv_count, recv_displ, gid)
563      !! Allgather of encapsulated data
564      !! @note see mp_allgatherv_dv
565
566      TYPE(dbcsr_data_obj), INTENT(IN)                   :: send_data
567      INTEGER, INTENT(IN)                                :: scount
568      TYPE(dbcsr_data_obj), INTENT(INOUT)                :: recv_data
569      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: recv_count, recv_displ
570      INTEGER, INTENT(IN)                                :: gid
571
572      IF (dbcsr_data_get_type(send_data) /= dbcsr_data_get_type(recv_data)) &
573         DBCSR_ABORT("Data type mismatch")
574      SELECT CASE (dbcsr_data_get_type(send_data))
575      CASE (dbcsr_type_real_4)
576         CALL mp_allgather(send_data%d%r_sp(1:scount), recv_data%d%r_sp, &
577                           recv_count, recv_displ, gid)
578      CASE (dbcsr_type_real_8)
579         CALL mp_allgather(send_data%d%r_dp(1:scount), recv_data%d%r_dp, &
580                           recv_count, recv_displ, gid)
581      CASE (dbcsr_type_complex_4)
582         CALL mp_allgather(send_data%d%c_sp(1:scount), recv_data%d%c_sp, &
583                           recv_count, recv_displ, gid)
584      CASE (dbcsr_type_complex_8)
585         CALL mp_allgather(send_data%d%c_dp(1:scount), recv_data%d%c_dp, &
586                           recv_count, recv_displ, gid)
587      CASE default
588         DBCSR_ABORT("Invalid data type")
589      END SELECT
590   END SUBROUTINE dbcsr_allgatherv
591
592#:include '../data/dbcsr.fypp'
593#:for n, nametype1, base1, prec1, kind1, type1, dkind1 in inst_params_float
594   SUBROUTINE hybrid_alltoall_${nametype1}$1(sb, scount, sdispl, &
595                                             rb, rcount, rdispl, mp_env, most_ptp, remainder_ptp, no_hybrid)
596      !! Row/column and global all-to-all
597      !!
598      !! Communicator selection
599      !! Uses row and column communicators for row/column
600      !! sends. Remaining sends are performed using the global
601      !! communicator.  Point-to-point isend/irecv are used if ptp is
602      !! set, otherwise a alltoall collective call is issued.
603      !! see mp_alltoall
604
605      ${type1}$, DIMENSION(:), &
606         CONTIGUOUS, INTENT(in), TARGET        :: sb
607      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN) :: scount, sdispl
608      ${type1}$, DIMENSION(:), &
609         CONTIGUOUS, INTENT(INOUT), TARGET     :: rb
610      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN) :: rcount, rdispl
611      TYPE(dbcsr_mp_obj), INTENT(IN)           :: mp_env
612         !! MP Environment
613      LOGICAL, INTENT(in), OPTIONAL            :: most_ptp, remainder_ptp, &
614                                                  no_hybrid
615         !! Use point-to-point for row/column; default is no
616         !! Use point-to-point for remaining; default is no
617         !! Use regular global collective; default is no
618
619      INTEGER :: all_group, mynode, mypcol, myprow, nall_rr, nall_sr, ncol_rr, &
620                 ncol_sr, npcols, nprows, nrow_rr, nrow_sr, numnodes, dst, src, &
621                 prow, pcol, send_cnt, recv_cnt, tag, grp, i
622      INTEGER, ALLOCATABLE, DIMENSION(:) :: all_rr, all_sr, col_rr, col_sr, &
623                                            new_rcount, new_rdispl, new_scount, new_sdispl, row_rr, row_sr
624      INTEGER, DIMENSION(:, :), CONTIGUOUS, POINTER :: pgrid
625      LOGICAL                                  :: most_collective, &
626                                                  remainder_collective, no_h
627      ${type1}$, DIMENSION(:), CONTIGUOUS, POINTER :: send_data_p, recv_data_p
628      TYPE(dbcsr_mp_obj)                       :: mpe
629
630      IF (.NOT. dbcsr_mp_has_subgroups(mp_env)) THEN
631         mpe = mp_env
632         CALL dbcsr_mp_grid_setup(mpe)
633      ENDIF
634      most_collective = .TRUE.
635      remainder_collective = .TRUE.
636      no_h = .FALSE.
637      IF (PRESENT(most_ptp)) most_collective = .NOT. most_ptp
638      IF (PRESENT(remainder_ptp)) remainder_collective = .NOT. remainder_ptp
639      IF (PRESENT(no_hybrid)) no_h = no_hybrid
640      all_group = dbcsr_mp_group(mp_env)
641      ! Don't use subcommunicators if they're not defined.
642      no_h = no_h .OR. .NOT. dbcsr_mp_has_subgroups(mp_env) .OR. .NOT. has_MPI
643      subgrouped: IF (mp_env%mp%subgroups_defined .AND. .NOT. no_h) THEN
644         mynode = dbcsr_mp_mynode(mp_env)
645         numnodes = dbcsr_mp_numnodes(mp_env)
646         nprows = dbcsr_mp_nprows(mp_env)
647         npcols = dbcsr_mp_npcols(mp_env)
648         myprow = dbcsr_mp_myprow(mp_env)
649         mypcol = dbcsr_mp_mypcol(mp_env)
650         pgrid => dbcsr_mp_pgrid(mp_env)
651         ALLOCATE (row_sr(0:npcols - 1)); nrow_sr = 0
652         ALLOCATE (row_rr(0:npcols - 1)); nrow_rr = 0
653         ALLOCATE (col_sr(0:nprows - 1)); ncol_sr = 0
654         ALLOCATE (col_rr(0:nprows - 1)); ncol_rr = 0
655         ALLOCATE (all_sr(0:numnodes - 1)); nall_sr = 0
656         ALLOCATE (all_rr(0:numnodes - 1)); nall_rr = 0
657         ALLOCATE (new_scount(numnodes), new_rcount(numnodes))
658         ALLOCATE (new_sdispl(numnodes), new_rdispl(numnodes))
659         IF (.NOT. remainder_collective) THEN
660            CALL remainder_point_to_point()
661         ENDIF
662         IF (.NOT. most_collective) THEN
663            CALL most_point_to_point()
664         ELSE
665            CALL most_alltoall()
666         ENDIF
667         IF (remainder_collective) THEN
668            CALL remainder_alltoall()
669         ENDIF
670         ! Wait for all issued sends and receives.
671         IF (.NOT. most_collective) THEN
672            CALL mp_waitall(row_sr(0:nrow_sr - 1))
673            CALL mp_waitall(col_sr(0:ncol_sr - 1))
674            CALL mp_waitall(row_rr(0:nrow_rr - 1))
675            CALL mp_waitall(col_rr(0:ncol_rr - 1))
676         END IF
677         IF (.NOT. remainder_collective) THEN
678            CALL mp_waitall(all_sr(1:nall_sr))
679            CALL mp_waitall(all_rr(1:nall_rr))
680         ENDIF
681      ELSE
682         CALL mp_alltoall(sb, scount, sdispl, &
683                          rb, rcount, rdispl, &
684                          all_group)
685      ENDIF subgrouped
686   CONTAINS
687      SUBROUTINE most_alltoall()
688         DO pcol = 0, npcols - 1
689            new_scount(1 + pcol) = scount(1 + pgrid(myprow, pcol))
690            new_rcount(1 + pcol) = rcount(1 + pgrid(myprow, pcol))
691            new_sdispl(1 + pcol) = sdispl(1 + pgrid(myprow, pcol))
692            new_rdispl(1 + pcol) = rdispl(1 + pgrid(myprow, pcol))
693         END DO
694         CALL mp_alltoall(sb, new_scount(1:npcols), new_sdispl(1:npcols), &
695                          rb, new_rcount(1:npcols), new_rdispl(1:npcols), &
696                          dbcsr_mp_my_row_group(mp_env))
697         DO prow = 0, nprows - 1
698            new_scount(1 + prow) = scount(1 + pgrid(prow, mypcol))
699            new_rcount(1 + prow) = rcount(1 + pgrid(prow, mypcol))
700            new_sdispl(1 + prow) = sdispl(1 + pgrid(prow, mypcol))
701            new_rdispl(1 + prow) = rdispl(1 + pgrid(prow, mypcol))
702         END DO
703         CALL mp_alltoall(sb, new_scount(1:nprows), new_sdispl(1:nprows), &
704                          rb, new_rcount(1:nprows), new_rdispl(1:nprows), &
705                          dbcsr_mp_my_col_group(mp_env))
706      END SUBROUTINE most_alltoall
707      SUBROUTINE most_point_to_point()
708         ! Go through my prow and exchange.
709         DO i = 0, npcols - 1
710            pcol = MOD(mypcol + i, npcols)
711            grp = dbcsr_mp_my_row_group(mp_env)
712            !
713            dst = dbcsr_mp_get_process(mp_env, myprow, pcol)
714            send_cnt = scount(dst + 1)
715            IF (send_cnt .GT. 0) THEN
716               send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
717               IF (pcol .NE. mypcol) THEN
718                  tag = 4*mypcol
719                  CALL mp_isend(send_data_p, pcol, grp, row_sr(nrow_sr), tag)
720                  nrow_sr = nrow_sr + 1
721               ENDIF
722            ENDIF
723            !
724            pcol = MODULO(mypcol - i, npcols)
725            src = dbcsr_mp_get_process(mp_env, myprow, pcol)
726            recv_cnt = rcount(src + 1)
727            IF (recv_cnt .GT. 0) THEN
728               recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
729               IF (pcol .NE. mypcol) THEN
730                  tag = 4*pcol
731                  CALL mp_irecv(recv_data_p, pcol, grp, row_rr(nrow_rr), tag)
732                  nrow_rr = nrow_rr + 1
733               ELSE
734                  CALL memory_copy(recv_data_p, send_data_p, recv_cnt)
735               ENDIF
736            ENDIF
737         ENDDO
738         ! go through my pcol and exchange
739         DO i = 0, nprows - 1
740            prow = MOD(myprow + i, nprows)
741            grp = dbcsr_mp_my_col_group(mp_env)
742            !
743            dst = dbcsr_mp_get_process(mp_env, prow, mypcol)
744            send_cnt = scount(dst + 1)
745            IF (send_cnt .GT. 0) THEN
746               send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
747               IF (prow .NE. myprow) THEN
748                  tag = 4*myprow + 1
749                  CALL mp_isend(send_data_p, prow, grp, col_sr(ncol_sr), tag)
750                  ncol_sr = ncol_sr + 1
751               ENDIF
752            ENDIF
753            !
754            prow = MODULO(myprow - i, nprows)
755            src = dbcsr_mp_get_process(mp_env, prow, mypcol)
756            recv_cnt = rcount(src + 1)
757            IF (recv_cnt .GT. 0) THEN
758               recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
759               IF (prow .NE. myprow) THEN
760                  tag = 4*prow + 1
761                  CALL mp_irecv(recv_data_p, prow, grp, col_rr(ncol_rr), tag)
762                  ncol_rr = ncol_rr + 1
763               ELSE
764                  CALL memory_copy(recv_data_p, send_data_p, recv_cnt)
765               ENDIF
766            ENDIF
767         ENDDO
768      END SUBROUTINE most_point_to_point
769      SUBROUTINE remainder_alltoall()
770         new_scount(:) = scount(:)
771         new_rcount(:) = rcount(:)
772         DO prow = 0, nprows - 1
773            new_scount(1 + pgrid(prow, mypcol)) = 0
774            new_rcount(1 + pgrid(prow, mypcol)) = 0
775         END DO
776         DO pcol = 0, npcols - 1
777            new_scount(1 + pgrid(myprow, pcol)) = 0
778            new_rcount(1 + pgrid(myprow, pcol)) = 0
779         END DO
780         CALL mp_alltoall(sb, new_scount, sdispl, &
781                          rb, new_rcount, rdispl, all_group)
782      END SUBROUTINE remainder_alltoall
783      SUBROUTINE remainder_point_to_point()
784         INTEGER                                  :: col, row
785
786         DO row = 0, nprows - 1
787            prow = MOD(row + myprow, nprows)
788            IF (prow .EQ. myprow) CYCLE
789            DO col = 0, npcols - 1
790               pcol = MOD(col + mypcol, npcols)
791               IF (pcol .EQ. mypcol) CYCLE
792               dst = dbcsr_mp_get_process(mp_env, prow, pcol)
793               send_cnt = scount(dst + 1)
794               IF (send_cnt .GT. 0) THEN
795                  send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
796                  tag = 4*mynode + 2
797                  CALL mp_isend(send_data_p, dst, all_group, all_sr(nall_sr + 1), tag)
798                  nall_sr = nall_sr + 1
799               ENDIF
800               !
801               src = dbcsr_mp_get_process(mp_env, prow, pcol)
802               recv_cnt = rcount(src + 1)
803               IF (recv_cnt .GT. 0) THEN
804                  recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
805                  tag = 4*src + 2
806                  CALL mp_irecv(recv_data_p, src, all_group, all_rr(nall_rr + 1), tag)
807                  nall_rr = nall_rr + 1
808               ENDIF
809            ENDDO
810         ENDDO
811      END SUBROUTINE remainder_point_to_point
812   END SUBROUTINE hybrid_alltoall_${nametype1}$1
813#:endfor
814
815END MODULE dbcsr_mp_operations
816