!
! Copyright (C) 2003-2013 Quantum ESPRESSO group
! This file is distributed under the terms of the
! GNU General Public License. See the file `License'
! in the root directory of the present distribution,
! or http://www.gnu.org/copyleft/gpl.txt .
!
!

SUBROUTINE laxlib_end()
  use laxlib_processors_grid
  CALL laxlib_end_drv ( )
END SUBROUTINE laxlib_end


SUBROUTINE laxlib_getval_ ( nproc_ortho, leg_ortho, np_ortho, me_ortho, ortho_comm, ortho_row_comm, ortho_col_comm, &
    ortho_comm_id, ortho_parent_comm, me_blacs, np_blacs, ortho_cntx, world_cntx, do_distr_diag_inside_bgrp )
  use laxlib_processors_grid, ONLY : &
    nproc_ortho_ => nproc_ortho, &
    leg_ortho_ => leg_ortho, &
    np_ortho_ => np_ortho, &
    me_ortho_ => me_ortho, &
    ortho_comm_ => ortho_comm, &
    ortho_row_comm_ => ortho_row_comm, &
    ortho_col_comm_ => ortho_col_comm, &
    ortho_comm_id_ => ortho_comm_id, &
    ortho_parent_comm_ => ortho_parent_comm, &
    me_blacs_ => me_blacs, &
    np_blacs_ => np_blacs, &
    ortho_cntx_ => ortho_cntx, &
    world_cntx_ => world_cntx, &
    do_distr_diag_inside_bgrp_ => do_distr_diag_inside_bgrp
  IMPLICIT NONE
  INTEGER, OPTIONAL, INTENT(OUT) :: nproc_ortho
  INTEGER, OPTIONAL, INTENT(OUT) :: leg_ortho
  INTEGER, OPTIONAL, INTENT(OUT) :: np_ortho(2)
  INTEGER, OPTIONAL, INTENT(OUT) :: me_ortho(2)
  INTEGER, OPTIONAL, INTENT(OUT) :: ortho_comm
  INTEGER, OPTIONAL, INTENT(OUT) :: ortho_row_comm
  INTEGER, OPTIONAL, INTENT(OUT) :: ortho_col_comm
  INTEGER, OPTIONAL, INTENT(OUT) :: ortho_comm_id
  INTEGER, OPTIONAL, INTENT(OUT) :: ortho_parent_comm
  INTEGER, OPTIONAL, INTENT(OUT) :: me_blacs
  INTEGER, OPTIONAL, INTENT(OUT) :: np_blacs
  INTEGER, OPTIONAL, INTENT(OUT) :: ortho_cntx
  INTEGER, OPTIONAL, INTENT(OUT) :: world_cntx
  LOGICAL, OPTIONAL, INTENT(OUT) :: do_distr_diag_inside_bgrp
  IF( PRESENT(nproc_ortho) )        nproc_ortho = nproc_ortho_
  IF( PRESENT(leg_ortho) )          leg_ortho = leg_ortho_
  IF( PRESENT(np_ortho) )           np_ortho = np_ortho_
  IF( PRESENT(me_ortho) )           me_ortho = me_ortho_
  IF( PRESENT(ortho_comm) )         ortho_comm = ortho_comm_
  IF( PRESENT(ortho_row_comm) )     ortho_row_comm = ortho_row_comm_
  IF( PRESENT(ortho_col_comm) )     ortho_col_comm = ortho_col_comm_
  IF( PRESENT(ortho_comm_id) )      ortho_comm_id = ortho_comm_id_
  IF( PRESENT(ortho_parent_comm) )  ortho_parent_comm = ortho_parent_comm_
  IF( PRESENT(me_blacs) )           me_blacs = me_blacs_
  IF( PRESENT(np_blacs) )           np_blacs = np_blacs_
  IF( PRESENT(ortho_cntx) )         ortho_cntx = ortho_cntx_
  IF( PRESENT(world_cntx) )         world_cntx = world_cntx_
  IF( PRESENT(do_distr_diag_inside_bgrp) ) do_distr_diag_inside_bgrp = do_distr_diag_inside_bgrp_
END SUBROUTINE laxlib_getval_
!
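!
! Usage note (illustrative): all arguments of laxlib_getval_ are OPTIONAL,
! so callers request only the values they need by keyword, through the
! generic name laxlib_getval used elsewhere in this file, e.g.
!
!    INTEGER :: np(2), comm
!    CALL laxlib_getval( np_ortho = np, ortho_comm = comm )
!
! Module variables that are not requested are simply not copied out.
!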
SUBROUTINE laxlib_get_status_x ( lax_status )
  use laxlib_processors_grid, ONLY : &
    nproc_ortho_ => nproc_ortho, &
    leg_ortho_ => leg_ortho, &
    np_ortho_ => np_ortho, &
    me_ortho_ => me_ortho, &
    ortho_comm_ => ortho_comm, &
    ortho_row_comm_ => ortho_row_comm, &
    ortho_col_comm_ => ortho_col_comm, &
    ortho_comm_id_ => ortho_comm_id, &
    ortho_parent_comm_ => ortho_parent_comm, &
    me_blacs_ => me_blacs, &
    np_blacs_ => np_blacs, &
    ortho_cntx_ => ortho_cntx, &
    world_cntx_ => world_cntx, &
    do_distr_diag_inside_bgrp_ => do_distr_diag_inside_bgrp
  IMPLICIT NONE
  include 'laxlib_param.fh'
  INTEGER, INTENT(OUT) :: lax_status(:)
  lax_status(LAX_STATUS_NPROC)      = nproc_ortho_
  lax_status(LAX_STATUS_LEG)        = leg_ortho_
  lax_status(LAX_STATUS_NP1)        = np_ortho_( 1 )
  lax_status(LAX_STATUS_NP2)        = np_ortho_( 2 )
  lax_status(LAX_STATUS_ME1)        = me_ortho_( 1 )
  lax_status(LAX_STATUS_ME2)        = me_ortho_( 2 )
  lax_status(LAX_STATUS_COMM)       = ortho_comm_
  lax_status(LAX_STATUS_ROWCOMM)    = ortho_row_comm_
  lax_status(LAX_STATUS_COLCOMM)    = ortho_col_comm_
  lax_status(LAX_STATUS_COMMID)     = ortho_comm_id_
  lax_status(LAX_STATUS_PARENTCOMM) = ortho_parent_comm_
  lax_status(LAX_STATUS_MEBLACS)    = me_blacs_
  lax_status(LAX_STATUS_NPBLACS)    = np_blacs_
  lax_status(LAX_STATUS_ORTHOCNTX)  = ortho_cntx_
  lax_status(LAX_STATUS_WORLDCNTX)  = world_cntx_
  IF( do_distr_diag_inside_bgrp_ ) THEN
     lax_status(LAX_STATUS_DISTDIAG) = 1
  ELSE
     lax_status(LAX_STATUS_DISTDIAG) = 2
  END IF
END SUBROUTINE laxlib_get_status_x

!----------------------------------------------------------------------------

SUBROUTINE laxlib_start_drv( ndiag_, my_world_comm, parent_comm, do_distr_diag_inside_bgrp_ )
  !
  use laxlib_processors_grid
  USE laxlib_parallel_include
  !
  !
  ! ... Ortho/diag/linear algebra group initialization
  !
  IMPLICIT NONE
  !
  INTEGER, INTENT(INOUT) :: ndiag_      ! (IN) input number of procs in the diag group, (OUT) actual number
  INTEGER, INTENT(IN) :: my_world_comm  ! parallel communicator of the "local" world
  INTEGER, INTENT(IN) :: parent_comm    ! parallel communicator inside which the distributed linear algebra group
                                        ! communicators are created
  LOGICAL, INTENT(IN) :: do_distr_diag_inside_bgrp_  ! as the name says: do the distributed diagonalization inside the band group
  !
  INTEGER :: mpime = 0  ! the global MPI task index (used in clocks); set below with a laxlib_rank call
  !
  INTEGER :: nproc_ortho_try
  INTEGER :: parent_nproc  ! nproc of the parent group
  INTEGER :: world_nproc   ! nproc of the world group
  INTEGER :: my_parent_id  ! id of the parent communicator
  INTEGER :: nparent_comm  ! number of parent communicators
  INTEGER :: ierr = 0
  !
  IF( lax_is_initialized ) &
     CALL laxlib_end_drv ( )

  world_nproc  = laxlib_size( my_world_comm )  ! the global number of processors in world_comm
  mpime        = laxlib_rank( my_world_comm )  ! set the global MPI task index (used in clocks)
  parent_nproc = laxlib_size( parent_comm )    ! the number of processors in the current parent communicator
  my_parent_id = mpime / parent_nproc          ! set the index of the current parent communicator
  nparent_comm = world_nproc / parent_nproc    ! number of parent communicators

  ! save the input value inside the module
  do_distr_diag_inside_bgrp = do_distr_diag_inside_bgrp_
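
  !
  ! Worked example (illustrative): with world_nproc = 8 MPI tasks split into
  ! parent communicators of parent_nproc = 4 tasks each, there are
  ! nparent_comm = 8/4 = 2 parent communicators, and task mpime = 5 belongs
  ! to parent communicator my_parent_id = 5/4 = 1 (integer division).
  !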
  !
#if defined __SCALAPACK
  np_blacs = laxlib_size( my_world_comm )
  me_blacs = laxlib_rank( my_world_comm )
  !
  ! define a 1D grid containing all MPI tasks of the global communicator
  ! NOTE: world_cntx has the MPI communicator on entry and the BLACS context on exit
  !       BLACS_GRIDINIT() will create a copy of the communicator, which can be
  !       later retrieved using CALL BLACS_GET(world_cntx, 10, comm_copy)
  !
  world_cntx = my_world_comm
  CALL BLACS_GRIDINIT( world_cntx, 'Row', 1, np_blacs )
  !
#endif
  !
  IF( ndiag_ > 0 ) THEN
     ! the command-line argument -ndiag N or -northo N was set to a value N:
     ! use the command-line value, ensuring that it falls in the proper range
     nproc_ortho_try = MIN( ndiag_, parent_nproc )
  ELSE
     ! no command-line argument -ndiag N or -northo N is present:
     ! insert here custom architecture-specific default definitions
#if defined(__SCALAPACK) && !defined(__CUDA)
     nproc_ortho_try = MAX( parent_nproc, 1 )
#else
     nproc_ortho_try = 1
#endif
  END IF
  !
  ! the ortho group for parallel linear algebra is a sub-group of the pool,
  ! so there are as many ortho groups as pools.
  !
  CALL init_ortho_group ( nproc_ortho_try, my_world_comm, parent_comm, nparent_comm, my_parent_id )
  !
  ! set the number of processors in the diag group to the actual number used
  !
  ndiag_ = nproc_ortho
  !
  lax_is_initialized = .true.
  !
  RETURN
  !
CONTAINS

  SUBROUTINE init_ortho_group ( nproc_try_in, my_world_comm, comm_all, nparent_comm, my_parent_id )
     !
     IMPLICIT NONE

     INTEGER, INTENT(IN) :: nproc_try_in, comm_all
     INTEGER, INTENT(IN) :: my_world_comm  ! parallel communicator of the "local" world
     INTEGER, INTENT(IN) :: nparent_comm
     INTEGER, INTENT(IN) :: my_parent_id   ! id of the parent communicator

     INTEGER :: ierr, color, key, me_all, nproc_all, nproc_try

#if defined __SCALAPACK
     INTEGER, ALLOCATABLE :: blacsmap(:,:)
     INTEGER, ALLOCATABLE :: ortho_cntx_pe(:)
     INTEGER :: nprow, npcol, myrow, mycol, i, j, k
     INTEGER, EXTERNAL :: BLACS_PNUM
#endif

#if defined __MPI

     me_all = laxlib_rank( comm_all )
     !
     nproc_all = laxlib_size( comm_all )
     !
     nproc_try = MIN( nproc_try_in, nproc_all )
     nproc_try = MAX( nproc_try, 1 )

     ! find the largest (nearly) square 2D grid not exceeding nproc_try tasks
     !
     CALL grid2d_dims( 'S', nproc_try, np_ortho(1), np_ortho(2) )
     !
     ! now, and only now, it is possible to define the number of tasks
     ! in the ortho group for parallel linear algebra
     !
     nproc_ortho = np_ortho(1) * np_ortho(2)
     !
     IF( nproc_all >= 4*nproc_ortho ) THEN
        !
        ! here we choose a processor every 4, in order not to stress memory BW
        ! on multi-core procs, for which further performance enhancements are
        ! possible using OpenMP BLAS inside regterg/cegterg/rdiaghg/cdiaghg
        ! (to be implemented)
        !
        color = 0
        IF( me_all < 4*nproc_ortho .AND. MOD( me_all, 4 ) == 0 ) color = 1
        !
        leg_ortho = 4
        !
     ELSE IF( nproc_all >= 2*nproc_ortho ) THEN
        !
        ! here we choose a processor every 2, in order not to stress memory BW
        !
        color = 0
        IF( me_all < 2*nproc_ortho .AND. MOD( me_all, 2 ) == 0 ) color = 1
        !
        leg_ortho = 2
        !
     ELSE
        !
        ! here we choose the first processors
        !
        color = 0
        IF( me_all < nproc_ortho ) color = 1
        !
        leg_ortho = 1
        !
     END IF
     !
     key = me_all
     !
     ! initialize the communicator for the new group by splitting the input communicator
     !
     CALL laxlib_comm_split ( comm_all, color, key, ortho_comm )
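     !
     ! Worked example (illustrative): with nproc_all = 16 and nproc_ortho = 4,
     ! the first branch above applies (16 >= 4*4), so leg_ortho = 4 and only
     ! tasks 0, 4, 8, 12 get color = 1; since key = me_all, the split gives
     ! them ranks 0..3 in the active ortho communicator, while the remaining
     ! tasks end up in a color = 0 communicator not used for linear algebra.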
     !
     ! and remember where it comes from
     !
     ortho_parent_comm = comm_all
     !
     ! compute the coordinates of the processors, in row-major order
     !
     me_ortho1 = laxlib_rank( ortho_comm )
     !
     IF( me_all == 0 .AND. me_ortho1 /= 0 ) &
        CALL lax_error__( " init_ortho_group ", " wrong root task in ortho group ", ierr )
     !
     if( color == 1 ) then
        ortho_comm_id = 1
        CALL GRID2D_COORDS( 'R', me_ortho1, np_ortho(1), np_ortho(2), me_ortho(1), me_ortho(2) )
        CALL GRID2D_RANK( 'R', np_ortho(1), np_ortho(2), me_ortho(1), me_ortho(2), ierr )
        IF( ierr /= me_ortho1 ) &
           CALL lax_error__( " init_ortho_group ", " wrong task coordinates in ortho group ", ierr )
        IF( me_ortho1*leg_ortho /= me_all ) &
           CALL lax_error__( " init_ortho_group ", " wrong rank assignment in ortho group ", ierr )

        CALL laxlib_comm_split( ortho_comm, me_ortho(2), me_ortho(1), ortho_col_comm )
        CALL laxlib_comm_split( ortho_comm, me_ortho(1), me_ortho(2), ortho_row_comm )

     else
        ortho_comm_id = 0
        me_ortho(1) = me_ortho1
        me_ortho(2) = me_ortho1
     endif

#if defined __SCALAPACK
     !
     ! This part eliminates the image dependency of the ortho groups:
     ! ScaLAPACK is now independent of whatever level of parallelization
     ! is present on top of pool parallelization
     !
     ALLOCATE( ortho_cntx_pe( nparent_comm ) )
     ALLOCATE( blacsmap( np_ortho(1), np_ortho(2) ) )

     DO j = 1, nparent_comm

        CALL BLACS_GET( world_cntx, 10, ortho_cntx_pe( j ) ) ! retrieve the communicator of the world context
        blacsmap = 0
        nprow = np_ortho(1)
        npcol = np_ortho(2)

        IF( ( j == ( my_parent_id + 1 ) ) .and. ( ortho_comm_id > 0 ) ) THEN

           blacsmap( me_ortho(1) + 1, me_ortho(2) + 1 ) = BLACS_PNUM( world_cntx, 0, me_blacs )

        END IF

        ! All MPI tasks defined in the global communicator take part in the definition of the BLACS grid

        CALL MPI_ALLREDUCE( MPI_IN_PLACE, blacsmap, SIZE(blacsmap), MPI_INTEGER, MPI_SUM, my_world_comm, ierr )
        IF( ierr /= 0 ) &
           CALL lax_error__( ' init_ortho_group ', ' problem in MPI_ALLREDUCE of blacsmap ', ierr )

        CALL BLACS_GRIDMAP( ortho_cntx_pe( j ), blacsmap, nprow, nprow, npcol )

        CALL BLACS_GRIDINFO( ortho_cntx_pe( j ), nprow, npcol, myrow, mycol )

        IF( ( j == ( my_parent_id + 1 ) ) .and. ( ortho_comm_id > 0 ) ) THEN

           IF( np_ortho(1) /= nprow ) &
              CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong no. of task rows ', 1 )
           IF( np_ortho(2) /= npcol ) &
              CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong no. of task columns ', 1 )
           IF( me_ortho(1) /= myrow ) &
              CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong task row ID ', 1 )
           IF( me_ortho(2) /= mycol ) &
              CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong task column ID ', 1 )

           ortho_cntx = ortho_cntx_pe( j )

        END IF

     END DO

     DEALLOCATE( blacsmap )
     DEALLOCATE( ortho_cntx_pe )

     ! end SCALAPACK code block

#endif

#else

     ortho_comm_id = 1

#endif

     RETURN
  END SUBROUTINE init_ortho_group

END SUBROUTINE laxlib_start_drv

!------------------------------------------------------------------------------!
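!
! Usage sketch (illustrative; the pool communicator name is hypothetical):
! a typical driver initializes MPI first, then starts the linear-algebra
! layer inside the communicator that must host the ortho group, e.g.
!
!    ndiag = 0   ! <= 0 : let the library pick an architecture default
!    CALL laxlib_start_drv( ndiag, MPI_COMM_WORLD, intra_pool_comm, .TRUE. )
!    ! on exit ndiag holds the number of tasks actually used in the diag group
!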

SUBROUTINE print_lambda_x( lambda, idesc, n, nshow, nudx, ccc, ionode, iunit )
  IMPLICIT NONE
  include 'laxlib_low.fh'
  include 'laxlib_kinds.fh'
  real(DP), intent(in) :: lambda(:,:,:), ccc
  INTEGER, INTENT(IN) :: idesc(:,:)
  integer, intent(in) :: n, nshow, nudx
  logical, intent(in) :: ionode
  integer, intent(in) :: iunit
  !
  integer :: nnn, j, i, is
  real(DP), allocatable :: lambda_repl(:,:)
  nnn = min( nudx, nshow )
  ALLOCATE( lambda_repl( nudx, nudx ) )
  IF( ionode ) WRITE( iunit, * )
  DO is = 1, SIZE( lambda, 3 )
     CALL collect_lambda( lambda_repl, lambda(:,:,is), idesc(:,is) )
     IF( ionode ) THEN
        WRITE( iunit, 3370 ) ' lambda nudx, spin = ', nudx, is
        IF( nnn < n ) WRITE( iunit, 3370 ) ' print only first ', nnn
        DO i = 1, nnn
           WRITE( iunit, 3380 ) ( lambda_repl(i,j)*ccc, j = 1, nnn )
        END DO
     END IF
  END DO
  DEALLOCATE( lambda_repl )
3370 FORMAT(26x,a,2i4)
3380 FORMAT(9f8.4)
  RETURN
END SUBROUTINE print_lambda_x

SUBROUTINE laxlib_desc_init1( nsiz, nx, la_proc, idesc, rank_ip, idesc_ip )
  !
  IMPLICIT NONE
  include 'laxlib_low.fh'
  include 'laxlib_param.fh'
  include 'laxlib_kinds.fh'
  !
  INTEGER, INTENT(IN)  :: nsiz
  INTEGER, INTENT(OUT) :: nx
  LOGICAL, INTENT(OUT) :: la_proc
  INTEGER, INTENT(OUT) :: idesc(LAX_DESC_SIZE)
  ! INTENT(INOUT), not OUT: an INTENT(OUT) allocatable would be deallocated
  ! automatically on entry, so a previous allocation must survive for the
  ! consistency check below to apply
  INTEGER, INTENT(INOUT), ALLOCATABLE :: rank_ip( :, : )
  INTEGER, INTENT(INOUT), ALLOCATABLE :: idesc_ip(:,:,:)
  !
  INTEGER :: ortho_comm, np_ortho(2), me_ortho(2), ortho_comm_id, &
             leg_ortho
  !
  CALL laxlib_getval( np_ortho = np_ortho, me_ortho = me_ortho, &
                      ortho_comm = ortho_comm, leg_ortho = leg_ortho, &
                      ortho_comm_id = ortho_comm_id )
  !
  IF ( .NOT. ALLOCATED (idesc_ip) ) THEN
     ALLOCATE( idesc_ip( LAX_DESC_SIZE, np_ortho(1), np_ortho(2) ) )
  ELSE
     IF ( SIZE (idesc_ip,2) /= np_ortho(1) .OR. &
          SIZE (idesc_ip,3) /= np_ortho(2) ) &
        CALL lax_error__( " desc_init ", " inconsistent dimension ", 2 )
  END IF
  IF ( .NOT. ALLOCATED (rank_ip) ) &
     ALLOCATE( rank_ip( np_ortho(1), np_ortho(2) ) )
  !
  CALL laxlib_init_desc( idesc, idesc_ip, rank_ip, nsiz, nsiz )
  !
  nx = idesc(LAX_DESC_NRCX)
  !
  la_proc = .FALSE.
  IF( idesc(LAX_DESC_ACTIVE_NODE) > 0 ) la_proc = .TRUE.
  !
  RETURN
END SUBROUTINE laxlib_desc_init1
!
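! Usage sketch of laxlib_desc_init1 above (illustrative; nbnd is a
! hypothetical caller-side name for the global matrix dimension):
!
!    INTEGER :: idesc(LAX_DESC_SIZE), nx
!    LOGICAL :: la_proc
!    INTEGER, ALLOCATABLE :: rank_ip(:,:), idesc_ip(:,:,:)
!    CALL laxlib_desc_init1( nbnd, nx, la_proc, idesc, rank_ip, idesc_ip )
!
! On exit nx is the leading dimension of the local blocks (LAX_DESC_NRCX)
! and la_proc tells whether this task owns a block of the distributed matrix.
!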
SUBROUTINE laxlib_desc_init2( nsiz, nx, la_proc, idesc, rank_ip, irc_ip, nrc_ip )
  !
  IMPLICIT NONE
  include 'laxlib_low.fh'
  include 'laxlib_param.fh'
  include 'laxlib_kinds.fh'
  !
  INTEGER, INTENT(IN)  :: nsiz
  INTEGER, INTENT(OUT) :: nx
  LOGICAL, INTENT(OUT) :: la_proc
  INTEGER, INTENT(OUT) :: idesc(LAX_DESC_SIZE)
  ! INTENT(INOUT), not OUT: keep a previous allocation on entry so that the
  ! dimension check below can apply (see laxlib_desc_init1)
  INTEGER, INTENT(INOUT), ALLOCATABLE :: rank_ip(:,:)
  INTEGER, INTENT(INOUT), ALLOCATABLE :: irc_ip(:)
  INTEGER, INTENT(INOUT), ALLOCATABLE :: nrc_ip(:)

  INTEGER :: i, j, rank
  INTEGER :: ortho_comm, np_ortho(2), me_ortho(2), ortho_comm_id, &
             leg_ortho, ortho_cntx
  !
  CALL laxlib_getval( np_ortho = np_ortho, me_ortho = me_ortho, &
                      ortho_comm = ortho_comm, leg_ortho = leg_ortho, &
                      ortho_comm_id = ortho_comm_id, ortho_cntx = ortho_cntx )
  !
  CALL laxlib_init_desc( idesc, nsiz, nsiz, np_ortho, me_ortho, &
                         ortho_comm, ortho_cntx, ortho_comm_id )
  !
  nx = idesc(LAX_DESC_NRCX)
  !
  IF ( .NOT. ALLOCATED (rank_ip) ) THEN
     ALLOCATE( rank_ip( np_ortho(1), np_ortho(2) ) )
     ALLOCATE( irc_ip( np_ortho(1) ), nrc_ip( np_ortho(1) ) )
  ELSE
     IF ( SIZE (rank_ip,1) /= np_ortho(1) .OR. &
          SIZE (rank_ip,2) /= np_ortho(2) ) &
        CALL lax_error__( " desc_init ", " inconsistent dimension ", 1 )
  END IF
  DO j = 0, idesc(LAX_DESC_NPC) - 1
     CALL laxlib_local_dims( irc_ip( j + 1 ), nrc_ip( j + 1 ), &
                             idesc(LAX_DESC_N), idesc(LAX_DESC_NX), np_ortho(1), j )
     DO i = 0, idesc(LAX_DESC_NPR) - 1
        CALL GRID2D_RANK( 'R', idesc(LAX_DESC_NPR), idesc(LAX_DESC_NPC), i, j, rank )
        rank_ip( i+1, j+1 ) = rank * leg_ortho
     END DO
  END DO
  !
  la_proc = .FALSE.
  IF( idesc(LAX_DESC_ACTIVE_NODE) > 0 ) la_proc = .TRUE.
  !
  RETURN
END SUBROUTINE laxlib_desc_init2
!
!
SUBROUTINE laxlib_init_desc_x( idesc, n, nx, np, me, comm, cntx, comm_id )
  USE laxlib_descriptor, ONLY : la_descriptor, descla_init, laxlib_desc_to_intarray
  IMPLICIT NONE
  include 'laxlib_param.fh'
  INTEGER, INTENT(OUT) :: idesc(LAX_DESC_SIZE)
  INTEGER, INTENT(IN)  :: n   ! the size of this matrix
  INTEGER, INTENT(IN)  :: nx  ! the max among the different matrices sharing this descriptor or the same data distribution
  INTEGER, INTENT(IN)  :: np(2), me(2), comm, cntx
  INTEGER, INTENT(IN)  :: comm_id
  !
  TYPE(la_descriptor) :: descla
  !
  CALL descla_init( descla, n, nx, np, me, comm, cntx, comm_id )
  CALL laxlib_desc_to_intarray( idesc, descla )
  RETURN
END SUBROUTINE laxlib_init_desc_x


SUBROUTINE laxlib_multi_init_desc_x( idesc, idesc_ip, rank_ip, n, nx )
  USE laxlib_descriptor, ONLY : la_descriptor, descla_init, laxlib_desc_to_intarray
  use laxlib_processors_grid, ONLY : leg_ortho, np_ortho, me_ortho, ortho_comm, ortho_comm_id, ortho_cntx
  IMPLICIT NONE
  include 'laxlib_param.fh'
  INTEGER, INTENT(OUT) :: idesc(LAX_DESC_SIZE)
  INTEGER, INTENT(OUT) :: idesc_ip(:,:,:)
  INTEGER, INTENT(OUT) :: rank_ip(:,:)
  INTEGER, INTENT(IN)  :: n   ! the size of this matrix
  INTEGER, INTENT(IN)  :: nx  ! the max among the different matrices sharing this descriptor or the same data distribution

  INTEGER :: i, j, rank, includeme
  INTEGER :: coor_ip( 2 )
  !
  TYPE(la_descriptor) :: descla
  !
  CALL descla_init( descla, n, nx, np_ortho, me_ortho, ortho_comm, ortho_cntx, ortho_comm_id )
  !
  CALL laxlib_desc_to_intarray( idesc, descla )
  !
  includeme = 1
  !
  DO j = 0, idesc(LAX_DESC_NPC) - 1
     DO i = 0, idesc(LAX_DESC_NPR) - 1
        coor_ip( 1 ) = i
        coor_ip( 2 ) = j
        CALL descla_init( descla, idesc(LAX_DESC_N), idesc(LAX_DESC_NX), &
                          np_ortho, coor_ip, ortho_comm, ortho_cntx, includeme )
        CALL laxlib_desc_to_intarray( idesc_ip(:,i+1,j+1), descla )
        CALL GRID2D_RANK( 'R', idesc(LAX_DESC_NPR), idesc(LAX_DESC_NPC), i, j, rank )
        rank_ip( i+1, j+1 ) = rank * leg_ortho
     END DO
  END DO
  !
  RETURN
END SUBROUTINE laxlib_multi_init_desc_x
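
!
! Worked example for descla_local_dims below (illustrative, assuming the
! non-ScaLAPACK block distribution, which assigns the remainder to the first
! tasks): distributing nx = 10 elements over np = 3 tasks gives local sizes
! (4, 3, 3) starting at global indices (1, 5, 8); if the actual size is
! n = 8, the clamping at the end of the routine shrinks the last task to
! nl = 1, so an n x n matrix keeps exactly the same layout as the nx x nx one.
!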

SUBROUTINE descla_local_dims( i2g, nl, n, nx, np, me )
  IMPLICIT NONE
  INTEGER, INTENT(OUT) :: i2g  ! global index of the first local element
  INTEGER, INTENT(OUT) :: nl   ! local number of elements
  INTEGER, INTENT(IN)  :: n    ! number of actual elements in the global array
  INTEGER, INTENT(IN)  :: nx   ! dimension of the global array (nx >= n) to be distributed
  INTEGER, INTENT(IN)  :: np   ! number of processors
  INTEGER, INTENT(IN)  :: me   ! taskid for which i2g and nl are computed
  !
  ! note that we can distribute a global array larger than the
  ! number of actual elements. This could be required for performance
  ! reasons, and to have an equal partition of matrices having different
  ! sizes, like the matrices of spin-up and spin-down
  !
  INTEGER, EXTERNAL :: ldim_block, ldim_cyclic, ldim_block_sca
  INTEGER, EXTERNAL :: gind_block, gind_cyclic, gind_block_sca
  !
#if defined __SCALAPACK
  nl  = ldim_block_sca( nx, np, me )
  i2g = gind_block_sca( 1, nx, np, me )
#else
  nl  = ldim_block( nx, np, me )
  i2g = gind_block( 1, nx, np, me )
#endif
  ! This is to try to keep an N x N matrix in the same
  ! distribution as an NX x NX matrix, useful to have
  ! the matrix of spin-up distributed in the same way
  ! as the matrix of spin-down
  !
  IF( i2g + nl - 1 > n ) nl = n - i2g + 1
  IF( nl < 0 ) nl = 0
  RETURN
  !
END SUBROUTINE descla_local_dims


! ----------------------------------------------
! Simplified driver

SUBROUTINE diagonalize_parallel_x( n, rhos, rhod, s, idesc )

  USE dspev_module

  IMPLICIT NONE
  include 'laxlib_kinds.fh'
  include 'laxlib_param.fh'
  include 'laxlib_mid.fh'
  include 'laxlib_low.fh'
  REAL(DP), INTENT(IN) :: rhos(:,:)  ! input symmetric matrix
  REAL(DP)             :: rhod(:)    ! output eigenvalues
  REAL(DP)             :: s(:,:)     ! output eigenvectors
  INTEGER,  INTENT(IN) :: n          ! size of the global matrix
  INTEGER,  INTENT(IN) :: idesc(LAX_DESC_SIZE)

  IF( n < 1 ) RETURN

  ! The matrix is distributed on the same processor group
  ! used for parallel matrix multiplication
  !
  IF( SIZE(s,1) /= SIZE(rhos,1) .OR. SIZE(s,2) /= SIZE(rhos,2) ) &
     CALL lax_error__( " diagonalize_parallel ", " inconsistent dimension for s and rhos ", 1 )

  IF ( idesc(LAX_DESC_ACTIVE_NODE) > 0 ) THEN
     !
     IF( SIZE(s,1) /= idesc(LAX_DESC_NRCX) ) &
        CALL lax_error__( " diagonalize_parallel ", " inconsistent dimension ", 1 )
     !
     ! Compute local dimension of the cyclically distributed matrix
     !
     s = rhos
     !
#if defined(__SCALAPACK)
     CALL pdsyevd_drv( .true., n, idesc(LAX_DESC_NRCX), s, SIZE(s,1), rhod, idesc(LAX_DESC_CNTX), idesc(LAX_DESC_COMM) )
#else
     CALL laxlib_pdsyevd( .true., n, idesc, s, SIZE(s,1), rhod )
#endif
     !
  END IF

  RETURN

END SUBROUTINE diagonalize_parallel_x


SUBROUTINE diagonalize_serial_x( n, rhos, rhod )
  IMPLICIT NONE
  include 'laxlib_kinds.fh'
  include 'laxlib_low.fh'
  INTEGER, INTENT(IN) :: n
  REAL(DP) :: rhos(:,:)
  REAL(DP) :: rhod(:)
  !
  ! inputs:
  !   n     size of the eigenproblem
  !   rhos  the symmetric matrix
  ! outputs:
  !   rhos  eigenvectors
  !   rhod  eigenvalues
  !
  REAL(DP), ALLOCATABLE :: aux(:)
  INTEGER :: i, j, k

  IF( n < 1 ) RETURN

  ALLOCATE( aux( n * ( n + 1 ) / 2 ) )

  ! pack the lower triangle of rhos into aux
  !
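  ! Layout note: this is LAPACK lower-triangle packed storage, in which the
  ! columns of the lower triangle are stored one after the other, so element
  ! (i,j), with j <= i <= n, lands at position (j-1)*(2*n-j)/2 + i of aux.
  ! For n = 3 the packing order is (1,1),(2,1),(3,1),(2,2),(3,2),(3,3).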
  k = 0
  DO j = 1, n
     DO i = j, n
        k = k + 1
        aux( k ) = rhos( i, j )
     END DO
  END DO

  CALL dspev_drv( 'V', 'L', n, aux, rhod, rhos, SIZE(rhos,1) )

  DEALLOCATE( aux )

  RETURN
END SUBROUTINE diagonalize_serial_x

SUBROUTINE diagonalize_serial_gpu( m, rhos, rhod, s, info )
#if defined(__CUDA)
  use cudafor
#if defined ( __USE_CUSOLVER )
  USE cusolverDn
#else
  use eigsolve_vars
  use nvtx_inters
  use dsyevd_gpu
#endif
  IMPLICIT NONE
  include 'laxlib_kinds.fh'
  INTEGER,  INTENT(IN) :: m
  REAL(DP), DEVICE, INTENT(IN)  :: rhos(:,:)
  REAL(DP), DEVICE, INTENT(OUT) :: rhod(:)
  REAL(DP), DEVICE, INTENT(OUT) :: s(:,:)
  INTEGER,  INTENT(OUT) :: info
  !
  INTEGER :: lwork_d
  INTEGER :: i, j, lda
  !
#if defined (__USE_CUSOLVER)
  !
  INTEGER, DEVICE :: devInfo
  TYPE(cusolverDnHandle) :: cuSolverHandle
  REAL(DP), ALLOCATABLE, DEVICE :: work_d(:)
  !
#else
  !
  REAL(DP), ALLOCATABLE :: work_d(:), a(:,:)
  ATTRIBUTES( DEVICE ) :: work_d, a
  REAL(DP), ALLOCATABLE :: b(:,:)
  REAL(DP), ALLOCATABLE :: work_h(:), w_h(:), z_h(:,:)
  ATTRIBUTES( PINNED ) :: work_h, w_h, z_h
  INTEGER, ALLOCATABLE :: iwork_h(:)
  ATTRIBUTES( PINNED ) :: iwork_h
  !
  INTEGER :: lwork_h, liwork_h
  !
#endif
  ! .... Subroutine Body
  !
#if defined (__USE_CUSOLVER)
  !
  s = rhos
  lda = SIZE( rhos, 1 )
  !
  info = cusolverDnCreate(cuSolverHandle)
  IF ( info /= CUSOLVER_STATUS_SUCCESS ) &
     CALL lax_error__( ' diagonalize_serial_gpu ', ' cusolverDnCreate ', ABS( info ) )

  info = cusolverDnDsyevd_bufferSize( &
         cuSolverHandle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, m, s, lda, rhod, lwork_d )
  IF( info /= CUSOLVER_STATUS_SUCCESS ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' error in solver 1 ', ABS( info ) )

  ALLOCATE( work_d( lwork_d ), STAT=info )
  IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate work_d ', ABS( info ) )

  info = cusolverDnDsyevd( &
         cuSolverHandle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, m, s, lda, rhod, work_d, lwork_d, devInfo )
  IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' error in solver 2 ', ABS( info ) )

  info = cudaDeviceSynchronize()

  info = cusolverDnDestroy(cuSolverHandle)
  IF( info /= CUSOLVER_STATUS_SUCCESS ) CALL lax_error__( ' diagonalize_serial_gpu ', ' cusolverDnDestroy failed ', ABS( info ) )

  DEALLOCATE( work_d )

  !
#else
  !
  info = 0
  lwork_d  = 2*64*64 + 66*SIZE(rhos,1)
  lwork_h  = 1 + 6*SIZE(rhos,1) + 2*SIZE(rhos,1)*SIZE(rhos,1)
  liwork_h = 3 + 5*SIZE(rhos,1)
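  ! The host workspace sizes match the LAPACK dsyevd requirements for the
  ! eigenvector case (lwork >= 1 + 6*n + 2*n**2, liwork >= 3 + 5*n); the
  ! device workspace size is assumed here to be the fixed blocked-solver
  ! buffer expected by dsyevd_gpu.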
  ALLOCATE( work_d(lwork_d), STAT=info )
  IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate work_d ', ABS( info ) )
  ALLOCATE( a(SIZE(rhos,1),SIZE(rhos,2)), STAT=info )
  IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate a ', ABS( info ) )
  ALLOCATE( work_h(lwork_h), STAT=info )
  IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate work_h ', ABS( info ) )
  ALLOCATE( iwork_h(liwork_h), STAT=info )
  IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate iwork_h ', ABS( info ) )
  !
  ALLOCATE( w_h(SIZE(rhod)), STAT=info )
  IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate w_h ', ABS( info ) )
  ALLOCATE( z_h(SIZE(s,1),SIZE(s,2)), STAT=info )
  IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate z_h ', ABS( info ) )

  if( initialized == 0 ) call init_eigsolve_gpu

  info = cudaMemcpy( a, rhos, SIZE(rhos,1)*SIZE(rhos,2), cudaMemcpyDeviceToDevice )
  lda = SIZE(rhos,1)
  ! copy the strict lower triangle of a into s before calling the solver
  !$cuf kernel do(2) <<<*,*>>>
  do j = 1, m
     do i = 1, m
        if ( i > j ) then
           s(i,j) = a(i,j)
        endif
     end do
  end do

  call dsyevd_gpu( 'V', 'U', 1, m, m, a, lda, s, lda, rhod, work_d, lwork_d, &
                   work_h, lwork_h, iwork_h, liwork_h, z_h, lda, w_h, info )

  DEALLOCATE( z_h )
  DEALLOCATE( w_h )
  DEALLOCATE( iwork_h )
  DEALLOCATE( work_h )
  DEALLOCATE( a )
  DEALLOCATE( work_d )
#endif
#else
  CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' not compiled in this version ', 0 )
#endif
END SUBROUTINE diagonalize_serial_gpu