!--------------------------------------------------------------------------------------------------!
! Copyright (C) by the DBCSR developers group - All rights reserved                                !
! This file is part of the DBCSR library.                                                          !
!                                                                                                  !
! For information on the license, see the LICENSE file.                                            !
! For further information please visit https://dbcsr.cp2k.org                                      !
! SPDX-License-Identifier: GPL-2.0+                                                                !
!--------------------------------------------------------------------------------------------------!

MODULE dbcsr_mp_operations
   !! Wrappers to message passing calls.
   !!
   !! Two families of routines live here:
   !!
   !! * `hybrid_alltoall_*`: all-to-all exchanges that can route traffic through the
   !!   row/column communicators of a process grid and the global communicator.
   !! * `dbcsr_*_any` / `dbcsr_allgatherv`: thin dispatchers that look at the data type
   !!   stored in a `dbcsr_data_obj` and forward the matching typed array to the
   !!   corresponding `mp_*` wrapper.

   USE dbcsr_config, ONLY: has_MPI
   USE dbcsr_data_methods, ONLY: dbcsr_data_get_type
   USE dbcsr_mp_methods, ONLY: &
      dbcsr_mp_get_process, dbcsr_mp_grid_setup, dbcsr_mp_group, dbcsr_mp_has_subgroups, &
      dbcsr_mp_my_col_group, dbcsr_mp_my_row_group, dbcsr_mp_mynode, dbcsr_mp_mypcol, &
      dbcsr_mp_myprow, dbcsr_mp_npcols, dbcsr_mp_nprows, dbcsr_mp_numnodes, dbcsr_mp_pgrid
   USE dbcsr_ptr_util, ONLY: memory_copy
   USE dbcsr_types, ONLY: dbcsr_data_obj, &
                          dbcsr_mp_obj, &
                          dbcsr_type_complex_4, &
                          dbcsr_type_complex_8, &
                          dbcsr_type_int_4, &
                          dbcsr_type_real_4, &
                          dbcsr_type_real_8
   USE dbcsr_kinds, ONLY: real_4, &
                          real_8
   USE dbcsr_mpiwrap, ONLY: &
      mp_allgather, mp_alltoall, mp_gatherv, mp_ibcast, mp_irecv, mp_iscatter, mp_isend, &
      mp_isendrecv, mp_rget, mp_sendrecv, mp_type_descriptor_type, mp_type_indexed_make_c, &
      mp_type_indexed_make_d, mp_type_indexed_make_r, mp_type_indexed_make_z, mp_type_make, &
      mp_waitall, mp_win_create
#include "base/dbcsr_base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbcsr_mp_operations'

   ! MP routines
   PUBLIC :: hybrid_alltoall_s1, hybrid_alltoall_d1, &
             hybrid_alltoall_c1, hybrid_alltoall_z1, &
             hybrid_alltoall_i1, hybrid_alltoall_any
   PUBLIC :: dbcsr_allgatherv
   PUBLIC :: dbcsr_sendrecv_any
   PUBLIC :: dbcsr_isend_any, dbcsr_irecv_any
   PUBLIC :: dbcsr_win_create_any, dbcsr_rget_any, dbcsr_ibcast_any
   PUBLIC :: dbcsr_iscatterv_any, dbcsr_gatherv_any
   PUBLIC :: dbcsr_isendrecv_any
   ! Type helpers
   PUBLIC :: dbcsr_mp_type_from_anytype

   ! NOTE(review): this generic name is not listed in any PUBLIC statement above, so with
   ! the module-wide PRIVATE default it is only reachable from inside this module; callers
   ! elsewhere have to use the specific hybrid_alltoall_* names.
   INTERFACE dbcsr_hybrid_alltoall
      MODULE PROCEDURE hybrid_alltoall_s1, hybrid_alltoall_d1, &
         hybrid_alltoall_c1, hybrid_alltoall_z1
      MODULE PROCEDURE hybrid_alltoall_i1
      MODULE PROCEDURE hybrid_alltoall_any
   END INTERFACE

CONTAINS

   SUBROUTINE hybrid_alltoall_any(sb, scount, sdispl, &
                                  rb, rcount, rdispl, mp_env, most_ptp, remainder_ptp, no_hybrid)
      !! Hybrid all-to-all of encapsulated data.
      !!
      !! Selects the typed send/receive arrays according to the data type of `sb` and
      !! forwards all arguments unchanged to the matching `hybrid_alltoall_[sdcz]1`
      !! routine. Only the real and complex types are dispatched; any other type code
      !! (including the integer one) aborts.
      !!
      !! NOTE(review): the data type of `rb` is not compared with that of `sb` here,
      !! unlike e.g. `dbcsr_sendrecv_any`.

      TYPE(dbcsr_data_obj), INTENT(IN)                   :: sb
         !! Send buffer
      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: scount, sdispl
         !! Send counts and send displacements, passed through unchanged
      TYPE(dbcsr_data_obj), INTENT(INOUT)                :: rb
         !! Receive buffer
      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: rcount, rdispl
         !! Receive counts and receive displacements, passed through unchanged
      TYPE(dbcsr_mp_obj), INTENT(IN)                     :: mp_env
         !! MP Environment
      LOGICAL, INTENT(in), OPTIONAL                      :: most_ptp, remainder_ptp, no_hybrid
         !! Optional switches, passed through (present or absent) to the typed routine

      CHARACTER(len=*), PARAMETER :: routineN = 'hybrid_alltoall_any'

      INTEGER                                            :: error_handle

!   ---------------------------------------------------------------------------

      CALL timeset(routineN, error_handle)

      SELECT CASE (dbcsr_data_get_type(sb))
      CASE (dbcsr_type_real_4)
         CALL hybrid_alltoall_s1(sb%d%r_sp, scount, sdispl, &
                                 rb%d%r_sp, rcount, rdispl, mp_env, &
                                 most_ptp, remainder_ptp, no_hybrid)
      CASE (dbcsr_type_real_8)
         CALL hybrid_alltoall_d1(sb%d%r_dp, scount, sdispl, &
                                 rb%d%r_dp, rcount, rdispl, mp_env, &
                                 most_ptp, remainder_ptp, no_hybrid)
      CASE (dbcsr_type_complex_4)
         CALL hybrid_alltoall_c1(sb%d%c_sp, scount, sdispl, &
                                 rb%d%c_sp, rcount, rdispl, mp_env, &
                                 most_ptp, remainder_ptp, no_hybrid)
      CASE (dbcsr_type_complex_8)
         CALL hybrid_alltoall_z1(sb%d%c_dp, scount, sdispl, &
                                 rb%d%c_dp, rcount, rdispl, mp_env, &
                                 most_ptp, remainder_ptp, no_hybrid)
      CASE default
         DBCSR_ABORT("Invalid data type")
      END SELECT

      CALL timestop(error_handle)
   END SUBROUTINE hybrid_alltoall_any

   SUBROUTINE hybrid_alltoall_i1(sb, scount, sdispl, &
                                 rb, rcount, rdispl, mp_env, most_ptp, remainder_ptp, no_hybrid)
      !! Row/column and global all-to-all (default-integer data)
      !!
      !! Communicator selection
      !! Uses row and column communicators for row/column
      !! sends. Remaining sends are performed using the global
      !! communicator. Point-to-point isend/irecv are used if ptp is
      !! set, otherwise an all-to-all collective call is issued.
      !! see mp_alltoall
      !!
      !! The count/displacement arrays are indexed by `process number + 1`; in the
      !! point-to-point paths the data for process `p` is the section
      !! `buffer(1 + displ(p + 1) : displ(p + 1) + count(p + 1))`, i.e. displacements
      !! are zero-based offsets into the buffer.
      !!
      !! If the hybrid path is not taken (no sub-communicators, `no_hybrid` set, or the
      !! library was built without MPI) a single `mp_alltoall` on the global group is used.
      !!
      !! Unlike the templated real/complex variants, the row/column point-to-point path of
      !! this routine also exchanges the process' own segment through isend/irecv instead
      !! of a local copy.

      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(in), &
         TARGET                                          :: sb
         !! Send buffer
      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: scount, sdispl
         !! Per-process send counts and send displacements
      INTEGER, DIMENSION(:), CONTIGUOUS, &
         INTENT(INOUT), TARGET                           :: rb
         !! Receive buffer
      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: rcount, rdispl
         !! Per-process receive counts and receive displacements
      TYPE(dbcsr_mp_obj), INTENT(IN)                     :: mp_env
         !! MP Environment
      LOGICAL, INTENT(IN), OPTIONAL                      :: most_ptp
         !! Use point-to-point for row/column; default is no
      LOGICAL, INTENT(IN), OPTIONAL                      :: remainder_ptp
         !! Use point-to-point for remaining; default is no
      LOGICAL, INTENT(IN), OPTIONAL                      :: no_hybrid
         !! Use regular global collective; default is no

      INTEGER :: all_group, mynode, mypcol, myprow, nall_rr, nall_sr, ncol_rr, &
                 ncol_sr, npcols, nprows, nrow_rr, nrow_sr, numnodes, dst, src, &
                 prow, pcol, send_cnt, recv_cnt, tag, grp, i
      INTEGER, ALLOCATABLE, DIMENSION(:) :: all_rr, all_sr, col_rr, col_sr, &
                                            new_rcount, new_rdispl, new_scount, new_sdispl, row_rr, row_sr
      INTEGER, DIMENSION(:, :), POINTER                  :: pgrid
      LOGICAL                                            :: most_collective, &
                                                            remainder_collective, no_h
      INTEGER, DIMENSION(:), POINTER, CONTIGUOUS         :: send_data_p, recv_data_p
      TYPE(dbcsr_mp_obj)                                 :: mpe

      ! mp_env is INTENT(IN), so the grid setup is invoked on a local copy.
      ! NOTE(review): the code below keeps querying mp_env, not mpe; presumably the copy
      ! shares its underlying data with mp_env so the new subgroups become visible
      ! through mp_env as well -- confirm against dbcsr_mp_obj / dbcsr_mp_grid_setup.
      IF (.NOT. dbcsr_mp_has_subgroups(mp_env)) THEN
         mpe = mp_env
         CALL dbcsr_mp_grid_setup(mpe)
      ENDIF
      most_collective = .TRUE.
      remainder_collective = .TRUE.
      no_h = .FALSE.
      IF (PRESENT(most_ptp)) most_collective = .NOT. most_ptp
      IF (PRESENT(remainder_ptp)) remainder_collective = .NOT. remainder_ptp
      IF (PRESENT(no_hybrid)) no_h = no_hybrid
      all_group = dbcsr_mp_group(mp_env)
      ! Don't use subcommunicators if they're not defined.
      no_h = no_h .OR. .NOT. dbcsr_mp_has_subgroups(mp_env) .OR. .NOT. has_MPI
      subgrouped: IF (mp_env%mp%subgroups_defined .AND. .NOT. no_h) THEN
         mynode = dbcsr_mp_mynode(mp_env)
         numnodes = dbcsr_mp_numnodes(mp_env)
         nprows = dbcsr_mp_nprows(mp_env)
         npcols = dbcsr_mp_npcols(mp_env)
         myprow = dbcsr_mp_myprow(mp_env)
         mypcol = dbcsr_mp_mypcol(mp_env)
         pgrid => dbcsr_mp_pgrid(mp_env)
         ! Request arrays (*_sr: sends, *_rr: receives) with their fill counters.
         ! The row/column arrays are filled from index 0, the all_* arrays from index 1.
         ALLOCATE (row_sr(0:npcols - 1)); nrow_sr = 0
         ALLOCATE (row_rr(0:npcols - 1)); nrow_rr = 0
         ALLOCATE (col_sr(0:nprows - 1)); ncol_sr = 0
         ALLOCATE (col_rr(0:nprows - 1)); ncol_rr = 0
         ALLOCATE (all_sr(0:numnodes - 1)); nall_sr = 0
         ALLOCATE (all_rr(0:numnodes - 1)); nall_rr = 0
         ALLOCATE (new_scount(numnodes), new_rcount(numnodes))
         ALLOCATE (new_sdispl(numnodes), new_rdispl(numnodes))
         ! Non-blocking remainder traffic is posted first, then the row/column exchange,
         ! and a collective remainder exchange (if requested) comes last.
         IF (.NOT. remainder_collective) THEN
            CALL remainder_point_to_point()
         ENDIF
         IF (.NOT. most_collective) THEN
            CALL most_point_to_point()
         ELSE
            CALL most_alltoall()
         ENDIF
         IF (remainder_collective) THEN
            CALL remainder_alltoall()
         ENDIF
         ! Wait for all issued sends and receives.
         IF (.NOT. most_collective) THEN
            CALL mp_waitall(row_sr(0:nrow_sr - 1))
            CALL mp_waitall(col_sr(0:ncol_sr - 1))
            CALL mp_waitall(row_rr(0:nrow_rr - 1))
            CALL mp_waitall(col_rr(0:ncol_rr - 1))
         END IF
         IF (.NOT. remainder_collective) THEN
            CALL mp_waitall(all_sr(1:nall_sr))
            CALL mp_waitall(all_rr(1:nall_rr))
         ENDIF
      ELSE
         CALL mp_alltoall(sb, scount, sdispl, &
                          rb, rcount, rdispl, &
                          all_group)
      ENDIF subgrouped
   CONTAINS
      SUBROUTINE most_alltoall()
         !! Collective exchange within my process row, then within my process column.
         !! The counts/displacements of the grid neighbours are gathered into the leading
         !! entries of the new_* arrays, ordered by grid column (resp. grid row).
         DO pcol = 0, npcols - 1
            new_scount(1 + pcol) = scount(1 + pgrid(myprow, pcol))
            new_rcount(1 + pcol) = rcount(1 + pgrid(myprow, pcol))
            new_sdispl(1 + pcol) = sdispl(1 + pgrid(myprow, pcol))
            new_rdispl(1 + pcol) = rdispl(1 + pgrid(myprow, pcol))
         END DO
         CALL mp_alltoall(sb, new_scount(1:npcols), new_sdispl(1:npcols), &
                          rb, new_rcount(1:npcols), new_rdispl(1:npcols), &
                          dbcsr_mp_my_row_group(mp_env))
         DO prow = 0, nprows - 1
            new_scount(1 + prow) = scount(1 + pgrid(prow, mypcol))
            new_rcount(1 + prow) = rcount(1 + pgrid(prow, mypcol))
            new_sdispl(1 + prow) = sdispl(1 + pgrid(prow, mypcol))
            new_rdispl(1 + prow) = rdispl(1 + pgrid(prow, mypcol))
         END DO
         CALL mp_alltoall(sb, new_scount(1:nprows), new_sdispl(1:nprows), &
                          rb, new_rcount(1:nprows), new_rdispl(1:nprows), &
                          dbcsr_mp_my_col_group(mp_env))
      END SUBROUTINE most_alltoall
      SUBROUTINE most_point_to_point()
         !! Non-blocking exchange within my process row and my process column.
         !! Requests are stored in row_sr/row_rr/col_sr/col_rr and completed by the caller.
         !! Row messages are tagged 4*(sender's grid column), column messages
         !! 4*(sender's grid row) + 1.

         ! Go through my prow and exchange.
         DO i = 0, npcols - 1
            ! Send to the column i steps ahead ...
            pcol = MOD(mypcol + i, npcols)
            grp = dbcsr_mp_my_row_group(mp_env)
            !
            dst = dbcsr_mp_get_process(mp_env, myprow, pcol)
            send_cnt = scount(dst + 1)
            IF (send_cnt .GT. 0) THEN
               send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
               tag = 4*mypcol
               CALL mp_isend(send_data_p, pcol, grp, row_sr(nrow_sr), tag)
               nrow_sr = nrow_sr + 1
            ENDIF
            ! ... and receive from the column i steps behind.
            pcol = MODULO(mypcol - i, npcols)
            src = dbcsr_mp_get_process(mp_env, myprow, pcol)
            recv_cnt = rcount(src + 1)
            IF (recv_cnt .GT. 0) THEN
               recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
               tag = 4*pcol
               CALL mp_irecv(recv_data_p, pcol, grp, row_rr(nrow_rr), tag)
               nrow_rr = nrow_rr + 1
            ENDIF
         ENDDO
         ! go through my pcol and exchange
         DO i = 0, nprows - 1
            prow = MOD(myprow + i, nprows)
            grp = dbcsr_mp_my_col_group(mp_env)
            !
            dst = dbcsr_mp_get_process(mp_env, prow, mypcol)
            send_cnt = scount(dst + 1)
            IF (send_cnt .GT. 0) THEN
               send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
               tag = 4*myprow + 1
               CALL mp_isend(send_data_p, prow, grp, col_sr(ncol_sr), tag)
               ncol_sr = ncol_sr + 1
            ENDIF
            !
            prow = MODULO(myprow - i, nprows)
            src = dbcsr_mp_get_process(mp_env, prow, mypcol)
            recv_cnt = rcount(src + 1)
            IF (recv_cnt .GT. 0) THEN
               recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
               tag = 4*prow + 1
               CALL mp_irecv(recv_data_p, prow, grp, col_rr(ncol_rr), tag)
               ncol_rr = ncol_rr + 1
            ENDIF
         ENDDO
      END SUBROUTINE most_point_to_point
      SUBROUTINE remainder_alltoall()
         !! Collective exchange on the global group with every process that shares
         !! neither my grid row nor my grid column; counts for row/column peers
         !! (myself included) are zeroed, the displacements are used unchanged.
         new_scount(:) = scount(:)
         new_rcount(:) = rcount(:)
         DO prow = 0, nprows - 1
            new_scount(1 + pgrid(prow, mypcol)) = 0
            new_rcount(1 + pgrid(prow, mypcol)) = 0
         END DO
         DO pcol = 0, npcols - 1
            new_scount(1 + pgrid(myprow, pcol)) = 0
            new_rcount(1 + pgrid(myprow, pcol)) = 0
         END DO
         CALL mp_alltoall(sb, new_scount, sdispl, &
                          rb, new_rcount, rdispl, all_group)
      END SUBROUTINE remainder_alltoall
      SUBROUTINE remainder_point_to_point()
         !! Non-blocking exchange on the global group with every process that shares
         !! neither my grid row nor my grid column. Requests go to all_sr/all_rr
         !! (filled from index 1) and are completed by the caller.
         !! Messages are tagged 4*(sender's process number) + 2.
         INTEGER                                         :: col, row

         DO row = 0, nprows - 1
            prow = MOD(row + myprow, nprows)
            IF (prow .EQ. myprow) CYCLE
            DO col = 0, npcols - 1
               pcol = MOD(col + mypcol, npcols)
               IF (pcol .EQ. mypcol) CYCLE
               dst = dbcsr_mp_get_process(mp_env, prow, pcol)
               send_cnt = scount(dst + 1)
               IF (send_cnt .GT. 0) THEN
                  tag = 4*mynode + 2
                  send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
                  CALL mp_isend(send_data_p, dst, all_group, all_sr(nall_sr + 1), tag)
                  nall_sr = nall_sr + 1
               ENDIF
               ! The same peer is both destination and source in this iteration.
               src = dbcsr_mp_get_process(mp_env, prow, pcol)
               recv_cnt = rcount(src + 1)
               IF (recv_cnt .GT. 0) THEN
                  recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
                  tag = 4*src + 2
                  CALL mp_irecv(recv_data_p, src, all_group, all_rr(nall_rr + 1), tag)
                  nall_rr = nall_rr + 1
               ENDIF
            ENDDO
         ENDDO
      END SUBROUTINE remainder_point_to_point
   END SUBROUTINE hybrid_alltoall_i1

   FUNCTION dbcsr_mp_type_from_anytype(data_area) RESULT(mp_type)
      !! Creates an MPI combined type from the given anytype.
      !!
      !! The typed array matching the data type code stored in the data area is handed
      !! to `mp_type_make`. Unsupported type codes abort, like in the other `*_any`
      !! dispatchers of this module, instead of returning an undefined descriptor.

      TYPE(dbcsr_data_obj), INTENT(IN)                   :: data_area
         !! Data area of any type
      TYPE(mp_type_descriptor_type)                      :: mp_type
         !! Type descriptor

      SELECT CASE (data_area%d%data_type)
      CASE (dbcsr_type_int_4)
         mp_type = mp_type_make(data_area%d%i4)
      CASE (dbcsr_type_real_4)
         mp_type = mp_type_make(data_area%d%r_sp)
      CASE (dbcsr_type_real_8)
         mp_type = mp_type_make(data_area%d%r_dp)
      CASE (dbcsr_type_complex_4)
         mp_type = mp_type_make(data_area%d%c_sp)
      CASE (dbcsr_type_complex_8)
         mp_type = mp_type_make(data_area%d%c_dp)
      CASE default
         ! Without this branch the function result would be left undefined.
         DBCSR_ABORT("Invalid data type")
      END SELECT
   END FUNCTION dbcsr_mp_type_from_anytype

   SUBROUTINE dbcsr_sendrecv_any(msgin, dest, msgout, source, comm)
      !! sendrecv of encapsulated data.
      !! Aborts if the two data areas have different data types or if the type is not
      !! one of the real/complex types.
      !! @note see mp_sendrecv

      TYPE(dbcsr_data_obj), INTENT(IN)                   :: msgin
         !! Data to send
      INTEGER, INTENT(IN)                                :: dest
         !! Passed to mp_sendrecv as the destination
      TYPE(dbcsr_data_obj), INTENT(INOUT)                :: msgout
         !! Data area that receives
      INTEGER, INTENT(IN)                                :: source, comm
         !! Source and communicator, passed to mp_sendrecv

      IF (dbcsr_data_get_type(msgin) .NE. dbcsr_data_get_type(msgout)) &
         DBCSR_ABORT("Different data type for msgin and msgout")

      SELECT CASE (dbcsr_data_get_type(msgin))
      CASE (dbcsr_type_real_4)
         CALL mp_sendrecv(msgin%d%r_sp, dest, msgout%d%r_sp, source, comm)
      CASE (dbcsr_type_real_8)
         CALL mp_sendrecv(msgin%d%r_dp, dest, msgout%d%r_dp, source, comm)
      CASE (dbcsr_type_complex_4)
         CALL mp_sendrecv(msgin%d%c_sp, dest, msgout%d%c_sp, source, comm)
      CASE (dbcsr_type_complex_8)
         CALL mp_sendrecv(msgin%d%c_dp, dest, msgout%d%c_dp, source, comm)
      CASE default
         DBCSR_ABORT("Incorrect data type")
      END SELECT
   END SUBROUTINE dbcsr_sendrecv_any

   SUBROUTINE dbcsr_isend_any(msgin, dest, comm, request, tag)
      !! Non-blocking send of encapsulated data.
      !! Aborts for data types other than the real/complex ones.
      !! @note see mp_isend_iv

      TYPE(dbcsr_data_obj), INTENT(IN)                   :: msgin
         !! Data to send
      INTEGER, INTENT(IN)                                :: dest, comm
         !! Destination and communicator, passed to mp_isend
      INTEGER, INTENT(OUT)                               :: request
         !! Request handle returned by mp_isend
      INTEGER, INTENT(IN), OPTIONAL                      :: tag
         !! Optional message tag, passed through

      SELECT CASE (dbcsr_data_get_type(msgin))
      CASE (dbcsr_type_real_4)
         CALL mp_isend(msgin%d%r_sp, dest, comm, request, tag)
      CASE (dbcsr_type_real_8)
         CALL mp_isend(msgin%d%r_dp, dest, comm, request, tag)
      CASE (dbcsr_type_complex_4)
         CALL mp_isend(msgin%d%c_sp, dest, comm, request, tag)
      CASE (dbcsr_type_complex_8)
         CALL mp_isend(msgin%d%c_dp, dest, comm, request, tag)
      CASE default
         DBCSR_ABORT("Incorrect data type")
      END SELECT
   END SUBROUTINE dbcsr_isend_any

   SUBROUTINE dbcsr_irecv_any(msgin, source, comm, request, tag)
      !! Non-blocking recv of encapsulated data.
      !! Aborts for data types other than the real/complex ones.
      !! @note see mp_irecv_iv

      TYPE(dbcsr_data_obj), INTENT(IN)                   :: msgin
         !! Data area whose typed array is handed to mp_irecv as the receive buffer
         !! (the object itself is INTENT(IN); only the array it refers to is the target)
      INTEGER, INTENT(IN)                                :: source, comm
         !! Source and communicator, passed to mp_irecv
      INTEGER, INTENT(OUT)                               :: request
         !! Request handle returned by mp_irecv
      INTEGER, INTENT(IN), OPTIONAL                      :: tag
         !! Optional message tag, passed through

      SELECT CASE (dbcsr_data_get_type(msgin))
      CASE (dbcsr_type_real_4)
         CALL mp_irecv(msgin%d%r_sp, source, comm, request, tag)
      CASE (dbcsr_type_real_8)
         CALL mp_irecv(msgin%d%r_dp, source, comm, request, tag)
      CASE (dbcsr_type_complex_4)
         CALL mp_irecv(msgin%d%c_sp, source, comm, request, tag)
      CASE (dbcsr_type_complex_8)
         CALL mp_irecv(msgin%d%c_dp, source, comm, request, tag)
      CASE default
         DBCSR_ABORT("Incorrect data type")
      END SELECT
   END SUBROUTINE dbcsr_irecv_any

   SUBROUTINE dbcsr_win_create_any(base, comm, win)
      !! Window initialization function of encapsulated data.
      !! Aborts for data types other than the real/complex ones.
      TYPE(dbcsr_data_obj), INTENT(IN)                   :: base
         !! Data area whose typed array is handed to mp_win_create
      INTEGER, INTENT(IN)                                :: comm
         !! Communicator, passed to mp_win_create
      INTEGER, INTENT(OUT)                               :: win
         !! Window handle returned by mp_win_create

      SELECT CASE (dbcsr_data_get_type(base))
      CASE (dbcsr_type_real_4)
         CALL mp_win_create(base%d%r_sp, comm, win)
      CASE (dbcsr_type_real_8)
         CALL mp_win_create(base%d%r_dp, comm, win)
      CASE (dbcsr_type_complex_4)
         CALL mp_win_create(base%d%c_sp, comm, win)
      CASE (dbcsr_type_complex_8)
         CALL mp_win_create(base%d%c_dp, comm, win)
      CASE default
         DBCSR_ABORT("Incorrect data type")
      END SELECT
   END SUBROUTINE dbcsr_win_create_any

   SUBROUTINE dbcsr_rget_any(base, source, win, win_data, myproc, disp, request, &
                             origin_datatype, target_datatype)
      !! Single-sided Get function of encapsulated data.
      !! Aborts if `base` and `win_data` have different data types or if the type is
      !! not one of the real/complex types. All arguments are forwarded to mp_rget.
      TYPE(dbcsr_data_obj), INTENT(IN)                   :: base
         !! Data area handed to mp_rget as its first (buffer) argument
      INTEGER, INTENT(IN)                                :: source, win
         !! Source and window handle, passed to mp_rget
      TYPE(dbcsr_data_obj), INTENT(IN)                   :: win_data
         !! Data area associated with the window
      INTEGER, INTENT(IN), OPTIONAL                      :: myproc, disp
         !! Optional, passed through to mp_rget
      INTEGER, INTENT(OUT)                               :: request
         !! Request handle returned by mp_rget
      TYPE(mp_type_descriptor_type), INTENT(IN), &
         OPTIONAL                                        :: origin_datatype, target_datatype
         !! Optional type descriptors, passed through to mp_rget

      IF (dbcsr_data_get_type(base) /= dbcsr_data_get_type(win_data)) &
         DBCSR_ABORT("Mismatch data type between buffer and window")

      SELECT CASE (dbcsr_data_get_type(base))
      CASE (dbcsr_type_real_4)
         CALL mp_rget(base%d%r_sp, source, win, win_data%d%r_sp, myproc, &
                      disp, request, origin_datatype, target_datatype)
      CASE (dbcsr_type_real_8)
         CALL mp_rget(base%d%r_dp, source, win, win_data%d%r_dp, myproc, &
                      disp, request, origin_datatype, target_datatype)
      CASE (dbcsr_type_complex_4)
         CALL mp_rget(base%d%c_sp, source, win, win_data%d%c_sp, myproc, &
                      disp, request, origin_datatype, target_datatype)
      CASE (dbcsr_type_complex_8)
         CALL mp_rget(base%d%c_dp, source, win, win_data%d%c_dp, myproc, &
                      disp, request, origin_datatype, target_datatype)
      CASE default
         DBCSR_ABORT("Incorrect data type")
      END SELECT
   END SUBROUTINE dbcsr_rget_any

   SUBROUTINE dbcsr_ibcast_any(base, source, grp, request)
      !! Bcast function of encapsulated data.
      !! Aborts for data types other than the real/complex ones.
      TYPE(dbcsr_data_obj), INTENT(IN)                   :: base
         !! Data area whose typed array is handed to mp_ibcast
      INTEGER, INTENT(IN)                                :: source, grp
         !! Source and group, passed to mp_ibcast
      INTEGER, INTENT(INOUT)                             :: request
         !! Request handle, passed to mp_ibcast

      SELECT CASE (dbcsr_data_get_type(base))
      CASE (dbcsr_type_real_4)
         CALL mp_ibcast(base%d%r_sp, source, grp, request)
      CASE (dbcsr_type_real_8)
         CALL mp_ibcast(base%d%r_dp, source, grp, request)
      CASE (dbcsr_type_complex_4)
         CALL mp_ibcast(base%d%c_sp, source, grp, request)
      CASE (dbcsr_type_complex_8)
         CALL mp_ibcast(base%d%c_dp, source, grp, request)
      CASE default
         DBCSR_ABORT("Incorrect data type")
      END SELECT
   END SUBROUTINE dbcsr_ibcast_any

   SUBROUTINE dbcsr_iscatterv_any(base, counts, displs, msg, recvcount, root, grp, request)
      !! Scatter function of encapsulated data.
      !! Aborts if `base` and `msg` have different data types or if the type is not one
      !! of the real/complex types.
      TYPE(dbcsr_data_obj), INTENT(IN)                   :: base
         !! Data area handed to mp_iscatter as the data to scatter
      INTEGER, DIMENSION(:), INTENT(IN), CONTIGUOUS      :: counts, displs
         !! Counts and displacements, passed to mp_iscatter
      TYPE(dbcsr_data_obj), INTENT(INOUT)                :: msg
         !! Data area that receives
      INTEGER, INTENT(IN)                                :: recvcount, root, grp
         !! Receive count, root and group, passed to mp_iscatter
      INTEGER, INTENT(INOUT)                             :: request
         !! Request handle, passed to mp_iscatter

      IF (dbcsr_data_get_type(base) .NE. dbcsr_data_get_type(msg)) &
         DBCSR_ABORT("Different data type for msgin and msgout")

      SELECT CASE (dbcsr_data_get_type(base))
      CASE (dbcsr_type_real_4)
         CALL mp_iscatter(base%d%r_sp, counts, displs, msg%d%r_sp, recvcount, root, grp, request)
      CASE (dbcsr_type_real_8)
         CALL mp_iscatter(base%d%r_dp, counts, displs, msg%d%r_dp, recvcount, root, grp, request)
      CASE (dbcsr_type_complex_4)
         CALL mp_iscatter(base%d%c_sp, counts, displs, msg%d%c_sp, recvcount, root, grp, request)
      CASE (dbcsr_type_complex_8)
         CALL mp_iscatter(base%d%c_dp, counts, displs, msg%d%c_dp, recvcount, root, grp, request)
      CASE default
         DBCSR_ABORT("Incorrect data type")
      END SELECT
   END SUBROUTINE dbcsr_iscatterv_any

   SUBROUTINE dbcsr_gatherv_any(base, ub_base, msg, counts, displs, root, grp)
      !! Gather function of encapsulated data.
      !! Only the leading `ub_base` elements of `base` are handed to mp_gatherv.
      !! Aborts if `base` and `msg` have different data types or if the type is not one
      !! of the real/complex types.
      TYPE(dbcsr_data_obj), INTENT(IN)                   :: base
         !! Data area to gather from
      INTEGER, INTENT(IN)                                :: ub_base
         !! Upper bound of the section of `base` that is sent
      TYPE(dbcsr_data_obj), INTENT(INOUT)                :: msg
         !! Data area that receives
      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: counts, displs
         !! Counts and displacements, passed to mp_gatherv
      INTEGER, INTENT(IN)                                :: root, grp
         !! Root and group, passed to mp_gatherv

      IF (dbcsr_data_get_type(base) .NE. dbcsr_data_get_type(msg)) &
         DBCSR_ABORT("Different data type for msgin and msgout")

      SELECT CASE (dbcsr_data_get_type(base))
      CASE (dbcsr_type_real_4)
         CALL mp_gatherv(base%d%r_sp(:ub_base), msg%d%r_sp, counts, displs, root, grp)
      CASE (dbcsr_type_real_8)
         CALL mp_gatherv(base%d%r_dp(:ub_base), msg%d%r_dp, counts, displs, root, grp)
      CASE (dbcsr_type_complex_4)
         CALL mp_gatherv(base%d%c_sp(:ub_base), msg%d%c_sp, counts, displs, root, grp)
      CASE (dbcsr_type_complex_8)
         CALL mp_gatherv(base%d%c_dp(:ub_base), msg%d%c_dp, counts, displs, root, grp)
      CASE default
         DBCSR_ABORT("Incorrect data type")
      END SELECT
   END SUBROUTINE dbcsr_gatherv_any

   SUBROUTINE dbcsr_isendrecv_any(msgin, dest, msgout, source, grp, send_request, recv_request)
      !! Send/Recv function of encapsulated data.
      !! Aborts if the two data areas have different data types or if the type is not
      !! one of the real/complex types.
      TYPE(dbcsr_data_obj), INTENT(IN)                   :: msgin
         !! Data to send
      INTEGER, INTENT(IN)                                :: dest
         !! Passed to mp_isendrecv as the destination
      TYPE(dbcsr_data_obj), INTENT(INOUT)                :: msgout
         !! Data area that receives
      INTEGER, INTENT(IN)                                :: source, grp
         !! Source and group, passed to mp_isendrecv
      INTEGER, INTENT(OUT)                               :: send_request, recv_request
         !! Request handles returned by mp_isendrecv

      IF (dbcsr_data_get_type(msgin) .NE. dbcsr_data_get_type(msgout)) &
         DBCSR_ABORT("Different data type for msgin and msgout")

      SELECT CASE (dbcsr_data_get_type(msgin))
      CASE (dbcsr_type_real_4)
         CALL mp_isendrecv(msgin%d%r_sp, dest, &
                           msgout%d%r_sp, source, &
                           grp, send_request, recv_request)
      CASE (dbcsr_type_real_8)
         CALL mp_isendrecv(msgin%d%r_dp, dest, &
                           msgout%d%r_dp, source, &
                           grp, send_request, recv_request)
      CASE (dbcsr_type_complex_4)
         CALL mp_isendrecv(msgin%d%c_sp, dest, &
                           msgout%d%c_sp, source, &
                           grp, send_request, recv_request)
      CASE (dbcsr_type_complex_8)
         CALL mp_isendrecv(msgin%d%c_dp, dest, &
                           msgout%d%c_dp, source, &
                           grp, send_request, recv_request)
      CASE default
         DBCSR_ABORT("Incorrect data type")
      END SELECT
   END SUBROUTINE dbcsr_isendrecv_any

   SUBROUTINE dbcsr_allgatherv(send_data, scount, recv_data, recv_count, recv_displ, gid)
      !! Allgather of encapsulated data
      !! Only the leading `scount` elements of `send_data` are handed to mp_allgather.
      !! Aborts if the two data areas have different data types or if the type is not
      !! one of the real/complex types.
      !! @note see mp_allgatherv_dv

      TYPE(dbcsr_data_obj), INTENT(IN)                   :: send_data
         !! Data to send
      INTEGER, INTENT(IN)                                :: scount
         !! Number of leading elements of `send_data` to send
      TYPE(dbcsr_data_obj), INTENT(INOUT)                :: recv_data
         !! Data area that receives
      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: recv_count, recv_displ
         !! Receive counts and displacements, passed to mp_allgather
      INTEGER, INTENT(IN)                                :: gid
         !! Group, passed to mp_allgather

      IF (dbcsr_data_get_type(send_data) /= dbcsr_data_get_type(recv_data)) &
         DBCSR_ABORT("Data type mismatch")
      SELECT CASE (dbcsr_data_get_type(send_data))
      CASE (dbcsr_type_real_4)
         CALL mp_allgather(send_data%d%r_sp(1:scount), recv_data%d%r_sp, &
                           recv_count, recv_displ, gid)
      CASE (dbcsr_type_real_8)
         CALL mp_allgather(send_data%d%r_dp(1:scount), recv_data%d%r_dp, &
                           recv_count, recv_displ, gid)
      CASE (dbcsr_type_complex_4)
         CALL mp_allgather(send_data%d%c_sp(1:scount), recv_data%d%c_sp, &
                           recv_count, recv_displ, gid)
      CASE (dbcsr_type_complex_8)
         CALL mp_allgather(send_data%d%c_dp(1:scount), recv_data%d%c_dp, &
                           recv_count, recv_displ, gid)
      CASE default
         DBCSR_ABORT("Invalid data type")
      END SELECT
   END SUBROUTINE dbcsr_allgatherv

#:include '../data/dbcsr.fypp'
#:for n, nametype1, base1, prec1, kind1, type1, dkind1 in inst_params_float
   SUBROUTINE hybrid_alltoall_${nametype1}$1(sb, scount, sdispl, &
                                             rb, rcount, rdispl, mp_env, most_ptp, remainder_ptp, no_hybrid)
      !! Row/column and global all-to-all
      !!
      !! Communicator selection
      !! Uses row and column communicators for row/column
      !! sends. Remaining sends are performed using the global
      !! communicator. Point-to-point isend/irecv are used if ptp is
      !! set, otherwise an all-to-all collective call is issued.
      !! see mp_alltoall
      !!
      !! The count/displacement arrays are indexed by `process number + 1`; in the
      !! point-to-point paths the data for process `p` is the section
      !! `buffer(1 + displ(p + 1) : displ(p + 1) + count(p + 1))`, i.e. displacements
      !! are zero-based offsets into the buffer.
      !!
      !! If the hybrid path is not taken (no sub-communicators, `no_hybrid` set, or the
      !! library was built without MPI) a single `mp_alltoall` on the global group is used.

      ${type1}$, DIMENSION(:), &
         CONTIGUOUS, INTENT(in), TARGET                  :: sb
         !! Send buffer
      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: scount, sdispl
         !! Per-process send counts and send displacements
      ${type1}$, DIMENSION(:), &
         CONTIGUOUS, INTENT(INOUT), TARGET               :: rb
         !! Receive buffer
      INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: rcount, rdispl
         !! Per-process receive counts and receive displacements
      TYPE(dbcsr_mp_obj), INTENT(IN)                     :: mp_env
         !! MP Environment
      LOGICAL, INTENT(in), OPTIONAL                      :: most_ptp
         !! Use point-to-point for row/column; default is no
      LOGICAL, INTENT(in), OPTIONAL                      :: remainder_ptp
         !! Use point-to-point for remaining; default is no
      LOGICAL, INTENT(in), OPTIONAL                      :: no_hybrid
         !! Use regular global collective; default is no

      INTEGER :: all_group, mynode, mypcol, myprow, nall_rr, nall_sr, ncol_rr, &
                 ncol_sr, npcols, nprows, nrow_rr, nrow_sr, numnodes, dst, src, &
                 prow, pcol, send_cnt, recv_cnt, tag, grp, i
      INTEGER, ALLOCATABLE, DIMENSION(:) :: all_rr, all_sr, col_rr, col_sr, &
                                            new_rcount, new_rdispl, new_scount, new_sdispl, row_rr, row_sr
      INTEGER, DIMENSION(:, :), CONTIGUOUS, POINTER      :: pgrid
      LOGICAL                                            :: most_collective, &
                                                            remainder_collective, no_h
      ${type1}$, DIMENSION(:), CONTIGUOUS, POINTER       :: send_data_p, recv_data_p
      TYPE(dbcsr_mp_obj)                                 :: mpe

      ! mp_env is INTENT(IN), so the grid setup is invoked on a local copy.
      ! NOTE(review): the code below keeps querying mp_env, not mpe; presumably the copy
      ! shares its underlying data with mp_env so the new subgroups become visible
      ! through mp_env as well -- confirm against dbcsr_mp_obj / dbcsr_mp_grid_setup.
      IF (.NOT. dbcsr_mp_has_subgroups(mp_env)) THEN
         mpe = mp_env
         CALL dbcsr_mp_grid_setup(mpe)
      ENDIF
      most_collective = .TRUE.
      remainder_collective = .TRUE.
      no_h = .FALSE.
      IF (PRESENT(most_ptp)) most_collective = .NOT. most_ptp
      IF (PRESENT(remainder_ptp)) remainder_collective = .NOT. remainder_ptp
      IF (PRESENT(no_hybrid)) no_h = no_hybrid
      all_group = dbcsr_mp_group(mp_env)
      ! Don't use subcommunicators if they're not defined.
      no_h = no_h .OR. .NOT. dbcsr_mp_has_subgroups(mp_env) .OR. .NOT. has_MPI
      subgrouped: IF (mp_env%mp%subgroups_defined .AND. .NOT. no_h) THEN
         mynode = dbcsr_mp_mynode(mp_env)
         numnodes = dbcsr_mp_numnodes(mp_env)
         nprows = dbcsr_mp_nprows(mp_env)
         npcols = dbcsr_mp_npcols(mp_env)
         myprow = dbcsr_mp_myprow(mp_env)
         mypcol = dbcsr_mp_mypcol(mp_env)
         pgrid => dbcsr_mp_pgrid(mp_env)
         ! Request arrays (*_sr: sends, *_rr: receives) with their fill counters.
         ! The row/column arrays are filled from index 0, the all_* arrays from index 1.
         ALLOCATE (row_sr(0:npcols - 1)); nrow_sr = 0
         ALLOCATE (row_rr(0:npcols - 1)); nrow_rr = 0
         ALLOCATE (col_sr(0:nprows - 1)); ncol_sr = 0
         ALLOCATE (col_rr(0:nprows - 1)); ncol_rr = 0
         ALLOCATE (all_sr(0:numnodes - 1)); nall_sr = 0
         ALLOCATE (all_rr(0:numnodes - 1)); nall_rr = 0
         ALLOCATE (new_scount(numnodes), new_rcount(numnodes))
         ALLOCATE (new_sdispl(numnodes), new_rdispl(numnodes))
         ! Non-blocking remainder traffic is posted first, then the row/column exchange,
         ! and a collective remainder exchange (if requested) comes last.
         IF (.NOT. remainder_collective) THEN
            CALL remainder_point_to_point()
         ENDIF
         IF (.NOT. most_collective) THEN
            CALL most_point_to_point()
         ELSE
            CALL most_alltoall()
         ENDIF
         IF (remainder_collective) THEN
            CALL remainder_alltoall()
         ENDIF
         ! Wait for all issued sends and receives.
         IF (.NOT. most_collective) THEN
            CALL mp_waitall(row_sr(0:nrow_sr - 1))
            CALL mp_waitall(col_sr(0:ncol_sr - 1))
            CALL mp_waitall(row_rr(0:nrow_rr - 1))
            CALL mp_waitall(col_rr(0:ncol_rr - 1))
         END IF
         IF (.NOT. remainder_collective) THEN
            CALL mp_waitall(all_sr(1:nall_sr))
            CALL mp_waitall(all_rr(1:nall_rr))
         ENDIF
      ELSE
         CALL mp_alltoall(sb, scount, sdispl, &
                          rb, rcount, rdispl, &
                          all_group)
      ENDIF subgrouped
   CONTAINS
      SUBROUTINE most_alltoall()
         !! Collective exchange within my process row, then within my process column.
         !! The counts/displacements of the grid neighbours are gathered into the leading
         !! entries of the new_* arrays, ordered by grid column (resp. grid row).
         DO pcol = 0, npcols - 1
            new_scount(1 + pcol) = scount(1 + pgrid(myprow, pcol))
            new_rcount(1 + pcol) = rcount(1 + pgrid(myprow, pcol))
            new_sdispl(1 + pcol) = sdispl(1 + pgrid(myprow, pcol))
            new_rdispl(1 + pcol) = rdispl(1 + pgrid(myprow, pcol))
         END DO
         CALL mp_alltoall(sb, new_scount(1:npcols), new_sdispl(1:npcols), &
                          rb, new_rcount(1:npcols), new_rdispl(1:npcols), &
                          dbcsr_mp_my_row_group(mp_env))
         DO prow = 0, nprows - 1
            new_scount(1 + prow) = scount(1 + pgrid(prow, mypcol))
            new_rcount(1 + prow) = rcount(1 + pgrid(prow, mypcol))
            new_sdispl(1 + prow) = sdispl(1 + pgrid(prow, mypcol))
            new_rdispl(1 + prow) = rdispl(1 + pgrid(prow, mypcol))
         END DO
         CALL mp_alltoall(sb, new_scount(1:nprows), new_sdispl(1:nprows), &
                          rb, new_rcount(1:nprows), new_rdispl(1:nprows), &
                          dbcsr_mp_my_col_group(mp_env))
      END SUBROUTINE most_alltoall
      SUBROUTINE most_point_to_point()
         !! Non-blocking exchange within my process row and my process column.
         !! Requests are stored in row_sr/row_rr/col_sr/col_rr and completed by the caller.
         !! Row messages are tagged 4*(sender's grid column), column messages
         !! 4*(sender's grid row) + 1. The segment addressed to myself is not sent
         !! through MPI but copied with memory_copy.

         ! Go through my prow and exchange.
         DO i = 0, npcols - 1
            pcol = MOD(mypcol + i, npcols)
            grp = dbcsr_mp_my_row_group(mp_env)
            !
            dst = dbcsr_mp_get_process(mp_env, myprow, pcol)
            send_cnt = scount(dst + 1)
            IF (send_cnt .GT. 0) THEN
               send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
               IF (pcol .NE. mypcol) THEN
                  tag = 4*mypcol
                  CALL mp_isend(send_data_p, pcol, grp, row_sr(nrow_sr), tag)
                  nrow_sr = nrow_sr + 1
               ENDIF
            ENDIF
            !
            pcol = MODULO(mypcol - i, npcols)
            src = dbcsr_mp_get_process(mp_env, myprow, pcol)
            recv_cnt = rcount(src + 1)
            IF (recv_cnt .GT. 0) THEN
               recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
               IF (pcol .NE. mypcol) THEN
                  tag = 4*pcol
                  CALL mp_irecv(recv_data_p, pcol, grp, row_rr(nrow_rr), tag)
                  nrow_rr = nrow_rr + 1
               ELSE
                  ! Only reached for i == 0, where send_data_p was set just above to my
                  ! own send segment (provided my own send count is positive as well).
                  CALL memory_copy(recv_data_p, send_data_p, recv_cnt)
               ENDIF
            ENDIF
         ENDDO
         ! go through my pcol and exchange
         DO i = 0, nprows - 1
            prow = MOD(myprow + i, nprows)
            grp = dbcsr_mp_my_col_group(mp_env)
            !
            dst = dbcsr_mp_get_process(mp_env, prow, mypcol)
            send_cnt = scount(dst + 1)
            IF (send_cnt .GT. 0) THEN
               send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
               IF (prow .NE. myprow) THEN
                  tag = 4*myprow + 1
                  CALL mp_isend(send_data_p, prow, grp, col_sr(ncol_sr), tag)
                  ncol_sr = ncol_sr + 1
               ENDIF
            ENDIF
            !
            prow = MODULO(myprow - i, nprows)
            src = dbcsr_mp_get_process(mp_env, prow, mypcol)
            recv_cnt = rcount(src + 1)
            IF (recv_cnt .GT. 0) THEN
               recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
               IF (prow .NE. myprow) THEN
                  tag = 4*prow + 1
                  CALL mp_irecv(recv_data_p, prow, grp, col_rr(ncol_rr), tag)
                  ncol_rr = ncol_rr + 1
               ELSE
                  ! Same self-copy as in the row loop (i == 0 only).
                  CALL memory_copy(recv_data_p, send_data_p, recv_cnt)
               ENDIF
            ENDIF
         ENDDO
      END SUBROUTINE most_point_to_point
      SUBROUTINE remainder_alltoall()
         !! Collective exchange on the global group with every process that shares
         !! neither my grid row nor my grid column; counts for row/column peers
         !! (myself included) are zeroed, the displacements are used unchanged.
         new_scount(:) = scount(:)
         new_rcount(:) = rcount(:)
         DO prow = 0, nprows - 1
            new_scount(1 + pgrid(prow, mypcol)) = 0
            new_rcount(1 + pgrid(prow, mypcol)) = 0
         END DO
         DO pcol = 0, npcols - 1
            new_scount(1 + pgrid(myprow, pcol)) = 0
            new_rcount(1 + pgrid(myprow, pcol)) = 0
         END DO
         CALL mp_alltoall(sb, new_scount, sdispl, &
                          rb, new_rcount, rdispl, all_group)
      END SUBROUTINE remainder_alltoall
      SUBROUTINE remainder_point_to_point()
         !! Non-blocking exchange on the global group with every process that shares
         !! neither my grid row nor my grid column. Requests go to all_sr/all_rr
         !! (filled from index 1) and are completed by the caller.
         !! Messages are tagged 4*(sender's process number) + 2.
         INTEGER                                         :: col, row

         DO row = 0, nprows - 1
            prow = MOD(row + myprow, nprows)
            IF (prow .EQ. myprow) CYCLE
            DO col = 0, npcols - 1
               pcol = MOD(col + mypcol, npcols)
               IF (pcol .EQ. mypcol) CYCLE
               dst = dbcsr_mp_get_process(mp_env, prow, pcol)
               send_cnt = scount(dst + 1)
               IF (send_cnt .GT. 0) THEN
                  send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
                  tag = 4*mynode + 2
                  CALL mp_isend(send_data_p, dst, all_group, all_sr(nall_sr + 1), tag)
                  nall_sr = nall_sr + 1
               ENDIF
               ! The same peer is both destination and source in this iteration.
               src = dbcsr_mp_get_process(mp_env, prow, pcol)
               recv_cnt = rcount(src + 1)
               IF (recv_cnt .GT. 0) THEN
                  recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
                  tag = 4*src + 2
                  CALL mp_irecv(recv_data_p, src, all_group, all_rr(nall_rr + 1), tag)
                  nall_rr = nall_rr + 1
               ENDIF
            ENDDO
         ENDDO
      END SUBROUTINE remainder_point_to_point
   END SUBROUTINE hybrid_alltoall_${nametype1}$1
#:endfor

END MODULE dbcsr_mp_operations