!
! Copyright (C) 2003-2013 Quantum ESPRESSO group
! This file is distributed under the terms of the
! GNU General Public License. See the file `License'
! in the root directory of the present distribution,
! or http://www.gnu.org/copyleft/gpl.txt .
!
!

SUBROUTINE laxlib_end()
  use laxlib_processors_grid
  CALL laxlib_end_drv ( )
END SUBROUTINE laxlib_end


SUBROUTINE laxlib_getval_ ( nproc_ortho, leg_ortho, np_ortho, me_ortho, ortho_comm, ortho_row_comm, ortho_col_comm, &
  ortho_comm_id, ortho_parent_comm, me_blacs, np_blacs, ortho_cntx, world_cntx, do_distr_diag_inside_bgrp  )
  use laxlib_processors_grid, ONLY : &
    nproc_ortho_ => nproc_ortho, &
    leg_ortho_   => leg_ortho, &
    np_ortho_    => np_ortho, &
    me_ortho_    => me_ortho, &
    ortho_comm_  => ortho_comm, &
    ortho_row_comm_ => ortho_row_comm, &
    ortho_col_comm_ => ortho_col_comm, &
    ortho_comm_id_  => ortho_comm_id, &
    ortho_parent_comm_ => ortho_parent_comm, &
    me_blacs_    => me_blacs,  &
    np_blacs_    => np_blacs, &
    ortho_cntx_  => ortho_cntx, &
    world_cntx_  => world_cntx, &
    do_distr_diag_inside_bgrp_ => do_distr_diag_inside_bgrp
  IMPLICIT NONE
  INTEGER, OPTIONAL, INTENT(OUT) :: nproc_ortho
  INTEGER, OPTIONAL, INTENT(OUT) :: leg_ortho
  INTEGER, OPTIONAL, INTENT(OUT) :: np_ortho(2)
  INTEGER, OPTIONAL, INTENT(OUT) :: me_ortho(2)
  INTEGER, OPTIONAL, INTENT(OUT) :: ortho_comm
  INTEGER, OPTIONAL, INTENT(OUT) :: ortho_row_comm
  INTEGER, OPTIONAL, INTENT(OUT) :: ortho_col_comm
  INTEGER, OPTIONAL, INTENT(OUT) :: ortho_comm_id
  INTEGER, OPTIONAL, INTENT(OUT) :: ortho_parent_comm
  INTEGER, OPTIONAL, INTENT(OUT) :: me_blacs
  INTEGER, OPTIONAL, INTENT(OUT) :: np_blacs
  INTEGER, OPTIONAL, INTENT(OUT) :: ortho_cntx
  INTEGER, OPTIONAL, INTENT(OUT) :: world_cntx
  LOGICAL, OPTIONAL, INTENT(OUT) :: do_distr_diag_inside_bgrp
  IF( PRESENT(nproc_ortho) ) nproc_ortho = nproc_ortho_
  IF( PRESENT(leg_ortho) ) leg_ortho = leg_ortho_
  IF( PRESENT(np_ortho) ) np_ortho = np_ortho_
  IF( PRESENT(me_ortho) ) me_ortho = me_ortho_
  IF( PRESENT(ortho_comm) ) ortho_comm = ortho_comm_
  IF( PRESENT(ortho_row_comm) ) ortho_row_comm = ortho_row_comm_
  IF( PRESENT(ortho_col_comm) ) ortho_col_comm = ortho_col_comm_
  IF( PRESENT(ortho_comm_id) ) ortho_comm_id = ortho_comm_id_
  IF( PRESENT(ortho_parent_comm) ) ortho_parent_comm = ortho_parent_comm_
  IF( PRESENT(me_blacs) ) me_blacs = me_blacs_
  IF( PRESENT(np_blacs) ) np_blacs = np_blacs_
  IF( PRESENT(ortho_cntx) ) ortho_cntx = ortho_cntx_
  IF( PRESENT(world_cntx) ) world_cntx = world_cntx_
  IF( PRESENT(do_distr_diag_inside_bgrp) ) do_distr_diag_inside_bgrp = do_distr_diag_inside_bgrp_
END SUBROUTINE laxlib_getval_
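
! Usage sketch (an illustration, not part of the library): callers normally
! reach this routine through the generic laxlib_getval interface, requesting
! only the values they need via keyword arguments, e.g.
!
!    INTEGER :: np(2), comm
!    CALL laxlib_getval( np_ortho = np, ortho_comm = comm )
!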
!
SUBROUTINE laxlib_get_status_x ( lax_status )
  use laxlib_processors_grid, ONLY : &
    nproc_ortho_ => nproc_ortho, &
    leg_ortho_   => leg_ortho, &
    np_ortho_    => np_ortho, &
    me_ortho_    => me_ortho, &
    ortho_comm_  => ortho_comm, &
    ortho_row_comm_ => ortho_row_comm, &
    ortho_col_comm_ => ortho_col_comm, &
    ortho_comm_id_  => ortho_comm_id, &
    ortho_parent_comm_ => ortho_parent_comm, &
    me_blacs_    => me_blacs,  &
    np_blacs_    => np_blacs, &
    ortho_cntx_  => ortho_cntx, &
    world_cntx_  => world_cntx, &
    do_distr_diag_inside_bgrp_ => do_distr_diag_inside_bgrp
  IMPLICIT NONE
  include 'laxlib_param.fh'
  INTEGER, INTENT(OUT) :: lax_status(:)
  lax_status(LAX_STATUS_NPROC)= nproc_ortho_
  lax_status(LAX_STATUS_LEG)= leg_ortho_
  lax_status(LAX_STATUS_NP1)= np_ortho_( 1 )
  lax_status(LAX_STATUS_NP2)= np_ortho_( 2 )
  lax_status(LAX_STATUS_ME1)= me_ortho_( 1 )
  lax_status(LAX_STATUS_ME2)= me_ortho_( 2 )
  lax_status(LAX_STATUS_COMM)= ortho_comm_
  lax_status(LAX_STATUS_ROWCOMM)= ortho_row_comm_
  lax_status(LAX_STATUS_COLCOMM)= ortho_col_comm_
  lax_status(LAX_STATUS_COMMID)= ortho_comm_id_
  lax_status(LAX_STATUS_PARENTCOMM)= ortho_parent_comm_
  lax_status(LAX_STATUS_MEBLACS)= me_blacs_
  lax_status(LAX_STATUS_NPBLACS)= np_blacs_
  lax_status(LAX_STATUS_ORTHOCNTX)= ortho_cntx_
  lax_status(LAX_STATUS_WORLDCNTX)= world_cntx_
  IF( do_distr_diag_inside_bgrp_ ) THEN
     lax_status(LAX_STATUS_DISTDIAG)= 1
  ELSE
     lax_status(LAX_STATUS_DISTDIAG)= 2
  END IF
END SUBROUTINE laxlib_get_status_x

!----------------------------------------------------------------------------

SUBROUTINE laxlib_start_drv( ndiag_, my_world_comm, parent_comm, do_distr_diag_inside_bgrp_  )
    !
    use laxlib_processors_grid
    USE laxlib_parallel_include
    !
    !
    ! ... Ortho/diag/linear algebra group initialization
    !
    IMPLICIT NONE
    !
    INTEGER, INTENT(INOUT) :: ndiag_  ! (IN) input number of procs in the diag group, (OUT) actual number
    INTEGER, INTENT(IN) :: my_world_comm ! parallel communicator of the "local" world
    INTEGER, INTENT(IN) :: parent_comm ! parallel communicator inside which the distributed linear algebra group
                                       ! communicators are created
    LOGICAL, INTENT(IN) :: do_distr_diag_inside_bgrp_  ! as the name suggests: if .TRUE., perform the
                                                       ! distributed diagonalization inside the band group
    !
    INTEGER :: mpime      =  0  ! the global MPI task index (used in clocks); can be set with a laxlib_rank call
    !
    INTEGER :: nproc_ortho_try
    INTEGER :: parent_nproc ! nproc of the parent group
    INTEGER :: world_nproc  ! nproc of the world group
    INTEGER :: my_parent_id ! id of the parent communicator
    INTEGER :: nparent_comm ! number of parent communicators
    INTEGER :: ierr = 0
    !
    IF( lax_is_initialized ) &
       CALL laxlib_end_drv ( )

    world_nproc  = laxlib_size( my_world_comm ) ! the global number of processors in world_comm
    mpime        = laxlib_rank( my_world_comm ) ! set the global MPI task index  (used in clocks)
    parent_nproc = laxlib_size( parent_comm )   ! the number of processors in the current parent communicator
    my_parent_id = mpime / parent_nproc  ! set the index of the current parent communicator
    nparent_comm = world_nproc/parent_nproc ! number of parent communicators

    ! save input value inside the module
    do_distr_diag_inside_bgrp = do_distr_diag_inside_bgrp_

    !
#if defined __SCALAPACK
    np_blacs     = laxlib_size( my_world_comm )
    me_blacs     = laxlib_rank( my_world_comm )
    !
    ! define a 1D grid containing all MPI tasks of the global communicator
    ! NOTE: world_cntx has the MPI communicator on entry and the BLACS context on exit
    !       BLACS_GRIDINIT() will create a copy of the communicator, which can be
    !       later retrieved using CALL BLACS_GET(world_cntx, 10, comm_copy)
    !
    world_cntx = my_world_comm
    CALL BLACS_GRIDINIT( world_cntx, 'Row', 1, np_blacs )
    !
#endif
    !
    IF( ndiag_ > 0 ) THEN
       ! command-line argument -ndiag N or -northo N set to a value N:
       ! use the command-line value, ensuring that it falls in the proper range
       nproc_ortho_try = MIN( ndiag_ , parent_nproc )
    ELSE
       ! no command-line argument -ndiag N or -northo N is present:
       ! insert here custom architecture-specific default definitions
#if defined(__SCALAPACK) && !defined(__CUDA)
       nproc_ortho_try = MAX( parent_nproc, 1 )
#else
       nproc_ortho_try = 1
#endif
    END IF
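    !
    ! illustrative values, not an additional code path: running with
    ! "-ndiag 8" on a parent group of 4 tasks gives
    ! nproc_ortho_try = MIN(8,4) = 4; with no flag and ScaLAPACK enabled,
    ! all parent tasks are tried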
    !
    ! the ortho group for parallel linear algebra is a sub-group of the pool,
    ! so there are as many ortho groups as pools.
    !
    CALL init_ortho_group ( nproc_ortho_try, my_world_comm, parent_comm, nparent_comm, my_parent_id )
    !
    ! set the number of processors in the diag group to the actual number used
    !
    ndiag_ = nproc_ortho
    !
    lax_is_initialized = .true.
    !
    RETURN
    !
CONTAINS

  SUBROUTINE init_ortho_group ( nproc_try_in, my_world_comm, comm_all, nparent_comm, my_parent_id )
    !
    IMPLICIT NONE

    INTEGER, INTENT(IN) :: nproc_try_in, comm_all
    INTEGER, INTENT(IN) :: my_world_comm ! parallel communicator of the "local" world
    INTEGER, INTENT(IN) :: nparent_comm
    INTEGER, INTENT(IN) :: my_parent_id ! id of the parent communicator

    INTEGER :: ierr, color, key, me_all, nproc_all, nproc_try

#if defined __SCALAPACK
    INTEGER, ALLOCATABLE :: blacsmap(:,:)
    INTEGER, ALLOCATABLE :: ortho_cntx_pe(:)
    INTEGER :: nprow, npcol, myrow, mycol, i, j, k
    INTEGER, EXTERNAL :: BLACS_PNUM
#endif

#if defined __MPI

    me_all    = laxlib_rank( comm_all )
    !
    nproc_all = laxlib_size( comm_all )
    !
    nproc_try = MIN( nproc_try_in, nproc_all )
    nproc_try = MAX( nproc_try, 1 )

    !  find the closest square not exceeding nproc_try
    !
    CALL grid2d_dims( 'S', nproc_try, np_ortho(1), np_ortho(2) )
    !
    !  now, and only now, it is possible to define the number of tasks
    !  in the ortho group for parallel linear algebra
    !
    nproc_ortho = np_ortho(1) * np_ortho(2)
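    !
    ! illustrative value: nproc_try = 10 would yield np_ortho = (3,3) and
    ! hence nproc_ortho = 9, the largest square not exceeding nproc_try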
    !
    IF( nproc_all >= 4*nproc_ortho ) THEN
       !
       !  here we choose a processor every 4, in order not to stress memory BW
       !  on multi-core procs, for which further performance enhancements are
       !  possible using OpenMP BLAS inside regter/cegter/rdiaghg/cdiaghg
       !  (to be implemented)
       !
       color = 0
       IF( me_all < 4*nproc_ortho .AND. MOD( me_all, 4 ) == 0 ) color = 1
       !
       leg_ortho = 4
       !
    ELSE IF( nproc_all >= 2*nproc_ortho ) THEN
       !
       !  here we choose a processor every 2, in order not to stress memory BW
       !
       color = 0
       IF( me_all < 2*nproc_ortho .AND. MOD( me_all, 2 ) == 0 ) color = 1
       !
       leg_ortho = 2
       !
    ELSE
       !
       !  here we choose the first processors
       !
       color = 0
       IF( me_all < nproc_ortho ) color = 1
       !
       leg_ortho = 1
       !
    END IF
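    !
    ! illustrative values: with nproc_all = 16 and nproc_ortho = 4 the
    ! first branch applies, so leg_ortho = 4 and tasks 0, 4, 8, 12 of
    ! comm_all get color = 1 and form the ortho group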
    !
    key   = me_all
    !
    !  initialize the communicator for the new group by splitting the input communicator
    !
    CALL laxlib_comm_split ( comm_all, color, key, ortho_comm )
    !
    ! and remember where it comes from
    !
    ortho_parent_comm = comm_all
    !
    !  compute the coordinates of the processors, in row-major order
    !
    me_ortho1   = laxlib_rank( ortho_comm )
    !
    IF( me_all == 0 .AND. me_ortho1 /= 0 ) &
         CALL lax_error__( " init_ortho_group ", " wrong root task in ortho group ", ierr )
    !
    if( color == 1 ) then
       ortho_comm_id = 1
       CALL GRID2D_COORDS( 'R', me_ortho1, np_ortho(1), np_ortho(2), me_ortho(1), me_ortho(2) )
       CALL GRID2D_RANK( 'R', np_ortho(1), np_ortho(2), me_ortho(1), me_ortho(2), ierr )
       IF( ierr /= me_ortho1 ) &
            CALL lax_error__( " init_ortho_group ", " wrong task coordinates in ortho group ", ierr )
       IF( me_ortho1*leg_ortho /= me_all ) &
            CALL lax_error__( " init_ortho_group ", " wrong rank assignment in ortho group ", ierr )

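       ! split ortho_comm into one communicator per grid column and one per
       ! grid row: tasks sharing the column index me_ortho(2) form
       ! ortho_col_comm (ranked by row index), and tasks sharing the row
       ! index me_ortho(1) form ortho_row_comm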
       CALL laxlib_comm_split( ortho_comm, me_ortho(2), me_ortho(1), ortho_col_comm)
       CALL laxlib_comm_split( ortho_comm, me_ortho(1), me_ortho(2), ortho_row_comm)

    else
       ortho_comm_id = 0
       me_ortho(1) = me_ortho1
       me_ortho(2) = me_ortho1
    endif

#if defined __SCALAPACK
    !
    !  This part is used to eliminate the image dependency from ortho groups:
    !  ScaLAPACK is now independent of whatever level of parallelization
    !  is present on top of pool parallelization
    !
    ALLOCATE( ortho_cntx_pe( nparent_comm ) )
    ALLOCATE( blacsmap( np_ortho(1), np_ortho(2) ) )

    DO j = 1, nparent_comm

         CALL BLACS_GET(world_cntx, 10, ortho_cntx_pe( j ) ) ! retrieve communicator of world context
         blacsmap = 0
         nprow = np_ortho(1)
         npcol = np_ortho(2)

         IF( ( j == ( my_parent_id + 1 ) ) .and. ( ortho_comm_id > 0 ) ) THEN

           blacsmap( me_ortho(1) + 1, me_ortho(2) + 1 ) = BLACS_PNUM( world_cntx, 0, me_blacs )

         END IF

         ! All MPI tasks defined in the global communicator take part in the definition of the BLACS grid

         CALL MPI_ALLREDUCE( MPI_IN_PLACE, blacsmap,  SIZE(blacsmap), MPI_INTEGER, MPI_SUM, my_world_comm, ierr )
         IF( ierr /= 0 ) &
            CALL lax_error__( ' init_ortho_group ', ' problem in MPI_ALLREDUCE of blacsmap ', ierr )

         CALL BLACS_GRIDMAP( ortho_cntx_pe( j ), blacsmap, nprow, nprow, npcol )

         CALL BLACS_GRIDINFO( ortho_cntx_pe( j ), nprow, npcol, myrow, mycol )

         IF( ( j == ( my_parent_id + 1 ) ) .and. ( ortho_comm_id > 0 ) ) THEN

            IF(  np_ortho(1) /= nprow ) &
               CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong no. of task rows ', 1 )
            IF(  np_ortho(2) /= npcol ) &
               CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong no. of task columns ', 1 )
            IF(  me_ortho(1) /= myrow ) &
               CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong task row ID ', 1 )
            IF(  me_ortho(2) /= mycol ) &
               CALL lax_error__( ' init_ortho_group ', ' problem with SCALAPACK, wrong task column ID ', 1 )

            ortho_cntx = ortho_cntx_pe( j )

         END IF

    END DO

    DEALLOCATE( blacsmap )
    DEALLOCATE( ortho_cntx_pe )

    !  end SCALAPACK code block

#endif

#else

    ortho_comm_id = 1

#endif

    RETURN
  END SUBROUTINE init_ortho_group

END SUBROUTINE laxlib_start_drv
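
! Usage sketch (a hedged illustration, not part of the library): a driver
! built on top of LAXlib would typically initialize and finalize as
!
!    ndiag = 0   ! 0 or negative: let the library choose the group size
!    CALL laxlib_start_drv( ndiag, world_comm, parent_comm, .TRUE. )
!    ! ... distributed diagonalizations ...
!    CALL laxlib_end()
!
! where world_comm and parent_comm are communicators owned by the caller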

!------------------------------------------------------------------------------!

SUBROUTINE print_lambda_x( lambda, idesc, n, nshow, nudx, ccc, ionode, iunit )
    IMPLICIT NONE
    include 'laxlib_low.fh'
    include 'laxlib_kinds.fh'
    real(DP), intent(in) :: lambda(:,:,:), ccc
    INTEGER, INTENT(IN) :: idesc(:,:)
    integer, intent(in) :: n, nshow, nudx
    logical, intent(in) :: ionode
    integer, intent(in) :: iunit
    !
    integer :: nnn, j, i, is
    real(DP), allocatable :: lambda_repl(:,:)
    nnn = min( nudx, nshow )
    ALLOCATE( lambda_repl( nudx, nudx ) )
    IF( ionode ) WRITE( iunit,*)
    DO is = 1, SIZE( lambda, 3 )
       CALL collect_lambda( lambda_repl, lambda(:,:,is), idesc(:,is) )
       IF( ionode ) THEN
          WRITE( iunit,3370) '    lambda   nudx, spin = ', nudx, is
          IF( nnn < n ) WRITE( iunit,3370) '    print only first ', nnn
          DO i=1,nnn
             WRITE( iunit,3380) (lambda_repl(i,j)*ccc,j=1,nnn)
          END DO
       END IF
    END DO
    DEALLOCATE( lambda_repl )
3370   FORMAT(26x,a,2i4)
3380   FORMAT(9f8.4)
    RETURN
END SUBROUTINE print_lambda_x

  SUBROUTINE laxlib_desc_init1( nsiz, nx, la_proc, idesc, rank_ip, idesc_ip )
     !
     IMPLICIT NONE
     include 'laxlib_low.fh'
     include 'laxlib_param.fh'
     include 'laxlib_kinds.fh'
     !
     INTEGER, INTENT(IN)  :: nsiz
     INTEGER, INTENT(OUT) :: nx
     LOGICAL, INTENT(OUT) :: la_proc
     INTEGER, INTENT(OUT) :: idesc(LAX_DESC_SIZE)
     INTEGER, INTENT(OUT), ALLOCATABLE :: rank_ip( :, : )
     INTEGER, INTENT(OUT), ALLOCATABLE :: idesc_ip(:,:,:)
     !
     INTEGER :: ortho_comm, np_ortho(2), me_ortho(2), ortho_comm_id, &
          leg_ortho
     !
     CALL laxlib_getval( np_ortho = np_ortho, me_ortho = me_ortho, &
          ortho_comm = ortho_comm, leg_ortho = leg_ortho, &
          ortho_comm_id = ortho_comm_id )
     !
     IF ( .NOT. ALLOCATED (idesc_ip) ) THEN
        ALLOCATE( idesc_ip( LAX_DESC_SIZE, np_ortho(1), np_ortho(2) ) )
     ELSE
        IF ( SIZE (idesc_ip,2) /= np_ortho(1) .OR. &
             SIZE (idesc_ip,3) /= np_ortho(2) ) &
             CALL lax_error__( " desc_init ", " inconsistent dimension ", 2 )
     END IF
     IF ( .NOT. ALLOCATED (rank_ip) ) &
          ALLOCATE( rank_ip( np_ortho(1), np_ortho(2) ) )
     !
     CALL laxlib_init_desc( idesc, idesc_ip, rank_ip, nsiz, nsiz )
     !
     nx = idesc(LAX_DESC_NRCX)
     !
     la_proc = .FALSE.
     IF( idesc(LAX_DESC_ACTIVE_NODE) > 0 ) la_proc = .TRUE.
     !
     RETURN
   END SUBROUTINE laxlib_desc_init1
   !
   SUBROUTINE laxlib_desc_init2( nsiz, nx, la_proc, idesc, rank_ip, irc_ip, nrc_ip )
     !
     IMPLICIT NONE
     include 'laxlib_low.fh'
     include 'laxlib_param.fh'
     include 'laxlib_kinds.fh'
     !
     INTEGER, INTENT(IN)  :: nsiz
     INTEGER, INTENT(OUT) :: nx
     LOGICAL, INTENT(OUT) :: la_proc
     INTEGER, INTENT(OUT) :: idesc(LAX_DESC_SIZE)
     INTEGER, INTENT(OUT), ALLOCATABLE :: rank_ip(:,:)
     INTEGER, INTENT(OUT), ALLOCATABLE :: irc_ip(:)
     INTEGER, INTENT(OUT), ALLOCATABLE :: nrc_ip(:)

     INTEGER :: i, j, rank
     INTEGER :: ortho_comm, np_ortho(2), me_ortho(2), ortho_comm_id, &
          leg_ortho, ortho_cntx
     !
     CALL laxlib_getval( np_ortho = np_ortho, me_ortho = me_ortho, &
          ortho_comm = ortho_comm, leg_ortho = leg_ortho, &
          ortho_comm_id = ortho_comm_id, ortho_cntx = ortho_cntx )
     !
     CALL laxlib_init_desc( idesc, nsiz, nsiz, np_ortho, me_ortho, &
          ortho_comm, ortho_cntx, ortho_comm_id )
     !
     nx = idesc(LAX_DESC_NRCX)
     !
     IF ( .NOT. ALLOCATED (rank_ip) ) THEN
        ALLOCATE( rank_ip( np_ortho(1), np_ortho(2) ) )
        ALLOCATE( irc_ip( np_ortho(1) ), nrc_ip (np_ortho(1) ) )
     ELSE
        IF ( SIZE (rank_ip,1) /= np_ortho(1) .OR. &
             SIZE (rank_ip,2) /= np_ortho(2) ) &
             CALL lax_error__( " desc_init ", " inconsistent dimension ", 1 )
     END IF
     DO j = 0, idesc(LAX_DESC_NPC) - 1
        CALL laxlib_local_dims( irc_ip( j + 1 ), nrc_ip( j + 1 ), &
             idesc(LAX_DESC_N), idesc(LAX_DESC_NX), np_ortho(1), j )
        DO i = 0, idesc(LAX_DESC_NPR) - 1
           CALL GRID2D_RANK( 'R', idesc(LAX_DESC_NPR), idesc(LAX_DESC_NPC), i, j, rank )
           rank_ip( i+1, j+1 ) = rank * leg_ortho
        END DO
     END DO
     !
     la_proc = .FALSE.
     IF( idesc(LAX_DESC_ACTIVE_NODE) > 0 ) la_proc = .TRUE.
     !
     RETURN
   END SUBROUTINE laxlib_desc_init2
  !
  !
SUBROUTINE laxlib_init_desc_x( idesc, n, nx, np, me, comm, cntx, comm_id )
    USE laxlib_descriptor,       ONLY: la_descriptor, descla_init, laxlib_desc_to_intarray
    IMPLICIT NONE
    include 'laxlib_param.fh'
    INTEGER, INTENT(OUT) :: idesc(LAX_DESC_SIZE)
    INTEGER, INTENT(IN)  :: n   !  the size of this matrix
    INTEGER, INTENT(IN)  :: nx  !  the maximum size among the matrices sharing this descriptor or the same data distribution
    INTEGER, INTENT(IN)  :: np(2), me(2), comm, cntx
    INTEGER, INTENT(IN)  :: comm_id
    !
    TYPE(la_descriptor) :: descla
    !
    CALL descla_init( descla, n, nx, np, me, comm, cntx, comm_id )
    CALL laxlib_desc_to_intarray( idesc, descla )
    RETURN
END SUBROUTINE laxlib_init_desc_x


SUBROUTINE laxlib_multi_init_desc_x( idesc, idesc_ip, rank_ip, n, nx  )
    USE laxlib_descriptor,       ONLY: la_descriptor, descla_init, laxlib_desc_to_intarray
    use laxlib_processors_grid,  ONLY: leg_ortho, np_ortho, me_ortho, ortho_comm, ortho_comm_id, ortho_cntx
    IMPLICIT NONE
    include 'laxlib_param.fh'
    INTEGER, INTENT(OUT) :: idesc(LAX_DESC_SIZE)
    INTEGER, INTENT(OUT) :: idesc_ip(:,:,:)
    INTEGER, INTENT(OUT) :: rank_ip(:,:)
    INTEGER, INTENT(IN)  :: n   !  the size of this matrix
    INTEGER, INTENT(IN)  :: nx  !  the maximum size among the matrices sharing this descriptor or the same data distribution

    INTEGER :: i, j, rank, includeme
    INTEGER :: coor_ip( 2 )
    !
    TYPE(la_descriptor) :: descla
    !
    CALL descla_init( descla, n, nx, np_ortho, me_ortho, ortho_comm, ortho_cntx, ortho_comm_id )
    !
    CALL laxlib_desc_to_intarray( idesc, descla )
    !
    includeme = 1
    !
    DO j = 0, idesc(LAX_DESC_NPC) - 1
       DO i = 0, idesc(LAX_DESC_NPR) - 1
          coor_ip( 1 ) = i
          coor_ip( 2 ) = j
          CALL descla_init( descla, idesc(LAX_DESC_N), idesc(LAX_DESC_NX), &
               np_ortho, coor_ip, ortho_comm, ortho_cntx, includeme )
          CALL laxlib_desc_to_intarray( idesc_ip(:,i+1,j+1), descla )
          CALL GRID2D_RANK( 'R', idesc(LAX_DESC_NPR), idesc(LAX_DESC_NPC), i, j, rank )
          rank_ip( i+1, j+1 ) = rank * leg_ortho
       END DO
    END DO
    !
    RETURN
END SUBROUTINE laxlib_multi_init_desc_x

   SUBROUTINE descla_local_dims( i2g, nl, n, nx, np, me )
      IMPLICIT NONE
      INTEGER, INTENT(OUT) :: i2g  !  global index of the first local element
      INTEGER, INTENT(OUT) :: nl   !  local number of elements
      INTEGER, INTENT(IN)  :: n    !  number of actual elements in the global array
      INTEGER, INTENT(IN)  :: nx   !  dimension of the global array (nx >= n) to be distributed
      INTEGER, INTENT(IN)  :: np   !  number of processors
      INTEGER, INTENT(IN)  :: me   !  task id for which i2g and nl are computed
      !
      !  note that we can distribute a global array larger than the
      !  number of actual elements. This could be required for performance
      !  reasons, and to have an equal partition of matrices having different
      !  sizes, like the spin-up and spin-down matrices
      !
      INTEGER, EXTERNAL ::  ldim_block, ldim_cyclic, ldim_block_sca
      INTEGER, EXTERNAL ::  gind_block, gind_cyclic, gind_block_sca
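      !
      !  worked example (illustrative): n = 5, nx = 6, np = 2 in blocks of 3
      !  gives task 0 -> i2g = 1, nl = 3 and task 1 -> i2g = 4, nl = 3,
      !  trimmed below to nl = 2 so that only the n actual elements remain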
      !
#if defined(__SCALAPACK)
      nl  = ldim_block_sca( nx, np, me )
      i2g = gind_block_sca( 1, nx, np, me )
#else
      nl  = ldim_block( nx, np, me )
      i2g = gind_block( 1, nx, np, me )
#endif
      ! This is to try to keep a matrix N * N in the same
      ! distribution as a matrix NX * NX, which is useful to have
      ! the spin-up matrix distributed in the same way
      ! as the spin-down matrix
      !
      IF( i2g + nl - 1 > n ) nl = n - i2g + 1
      IF( nl < 0 ) nl = 0
      RETURN
      !
   END SUBROUTINE descla_local_dims


!   ----------------------------------------------
!   Simplified driver

   SUBROUTINE diagonalize_parallel_x( n, rhos, rhod, s, idesc )

      USE dspev_module

      IMPLICIT NONE
      include 'laxlib_kinds.fh'
      include 'laxlib_param.fh'
      include 'laxlib_mid.fh'
      include 'laxlib_low.fh'
      REAL(DP), INTENT(IN)  :: rhos(:,:) !  input symmetric matrix
      REAL(DP)              :: rhod(:)   !  output eigenvalues
      REAL(DP)              :: s(:,:)    !  output eigenvectors
      INTEGER,  INTENT(IN) :: n         !  size of the global matrix
      INTEGER,  INTENT(IN) :: idesc(LAX_DESC_SIZE)

      IF( n < 1 ) RETURN

      !  the matrix is distributed on the same processor group
      !  used for parallel matrix multiplication
      !
      IF( SIZE(s,1) /= SIZE(rhos,1) .OR. SIZE(s,2) /= SIZE(rhos,2) ) &
         CALL lax_error__( " diagonalize_parallel ", " inconsistent dimension for s and rhos ", 1 )

      IF ( idesc(LAX_DESC_ACTIVE_NODE) > 0 ) THEN
         !
         IF( SIZE(s,1) /= idesc(LAX_DESC_NRCX) ) &
            CALL lax_error__( " diagonalize_parallel ", " inconsistent dimension ", 1)
         !
         !  compute the local dimension of the cyclically distributed matrix
         !
         s = rhos
         !
#if defined(__SCALAPACK)
         CALL pdsyevd_drv( .true. , n, idesc(LAX_DESC_NRCX), s, SIZE(s,1), rhod, idesc(LAX_DESC_CNTX), idesc(LAX_DESC_COMM) )
#else
         CALL laxlib_pdsyevd( .true., n, idesc, s, SIZE(s,1), rhod )
#endif
         !
      END IF

      RETURN

   END SUBROUTINE diagonalize_parallel_x


   SUBROUTINE diagonalize_serial_x( n, rhos, rhod )
      IMPLICIT NONE
      include 'laxlib_kinds.fh'
      include 'laxlib_low.fh'
      INTEGER,  INTENT(IN)  :: n
      REAL(DP)              :: rhos(:,:)
      REAL(DP)              :: rhod(:)
      !
      ! inputs:
      ! n     size of the eigenproblem
      ! rhos  the symmetric matrix
      ! outputs:
      ! rhos  eigenvectors
      ! rhod  eigenvalues
      !
      REAL(DP), ALLOCATABLE :: aux(:)
      INTEGER :: i, j, k

      IF( n < 1 ) RETURN

      ALLOCATE( aux( n * ( n + 1 ) / 2 ) )

      !  pack the lower triangle of rhos into aux, column by column
      !
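      !  e.g. for n = 3 the packed order is
      !  a(1,1), a(2,1), a(3,1), a(2,2), a(3,2), a(3,3)
      !  (LAPACK lower-triangle packed storage, as expected by dspev)
      !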
      k = 0
      DO j = 1, n
         DO i = j, n
            k = k + 1
            aux( k ) = rhos( i, j )
         END DO
      END DO

      CALL dspev_drv( 'V', 'L', n, aux, rhod, rhos, SIZE(rhos,1) )

      DEALLOCATE( aux )

      RETURN
   END SUBROUTINE diagonalize_serial_x

   SUBROUTINE diagonalize_serial_gpu( m, rhos, rhod, s, info )
#if defined(__CUDA)
      use cudafor
#if defined ( __USE_CUSOLVER )
      USE cusolverDn
#else
      use eigsolve_vars
      use nvtx_inters
      use dsyevd_gpu
#endif
      IMPLICIT NONE
      include 'laxlib_kinds.fh'
      INTEGER, INTENT(IN) :: m
      REAL(DP), DEVICE, INTENT(IN) :: rhos(:,:)
      REAL(DP), DEVICE, INTENT(OUT) :: rhod(:)
      REAL(DP), DEVICE, INTENT(OUT) :: s(:,:)
      INTEGER, INTENT(OUT) :: info
      !
      INTEGER :: lwork_d
      INTEGER :: i, j, lda
      !
#if defined (__USE_CUSOLVER)
      !
      INTEGER, DEVICE        :: devInfo
      TYPE(cusolverDnHandle) :: cuSolverHandle
      REAL(DP), ALLOCATABLE, DEVICE :: work_d(:)
      !
#else
      !
      REAL(DP), ALLOCATABLE :: work_d(:), a(:,:)
      ATTRIBUTES( DEVICE ) :: work_d, a
      REAL(DP), ALLOCATABLE :: b(:,:)
      REAL(DP), ALLOCATABLE :: work_h(:), w_h(:), z_h(:,:)
      ATTRIBUTES( PINNED ) :: work_h, w_h, z_h
      INTEGER, ALLOCATABLE :: iwork_h(:)
      ATTRIBUTES( PINNED ) :: iwork_h
      !
      INTEGER :: lwork_h, liwork_h
      !
#endif
      ! .... Subroutine Body
      !
#if defined (__USE_CUSOLVER)
      !
      s = rhos
      lda = SIZE( rhos, 1 )
      !
      info = cusolverDnCreate(cuSolverHandle)
      IF ( info /= CUSOLVER_STATUS_SUCCESS ) &
         CALL lax_error__( ' diagonalize_serial_gpu ', 'cusolverDnCreate',  ABS( info ) )

      info = cusolverDnDsyevd_bufferSize( &
             cuSolverHandle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, m, s, lda, rhod, lwork_d)
      IF( info /= CUSOLVER_STATUS_SUCCESS ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' error in solver 1 ', ABS( info ) )

      ALLOCATE( work_d ( lwork_d ), STAT=info )
      IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate work_d ', ABS( info ) )

      info = cusolverDnDsyevd( &
             cuSolverHandle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, m, s, lda, rhod, work_d, lwork_d, devInfo)
      IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' error in solver 2 ', ABS( info ) )

      info = cudaDeviceSynchronize()

      info = cusolverDnDestroy(cuSolverHandle)
      IF( info /= CUSOLVER_STATUS_SUCCESS ) CALL lax_error__( ' diagonalize_serial_gpu ', ' cusolverDnDestroy failed ', ABS( info ) )

      DEALLOCATE( work_d )

      !
#else
      !
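      ! host workspace sizes follow the LAPACK dsyevd requirements for the
      ! eigenvector path (lwork >= 1 + 6*N + 2*N**2, liwork >= 3 + 5*N);
      ! the device buffer size appears to match the custom solver's internal
      ! 64x64 blocking (an assumption based on the formula below)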
      info = 0
      lwork_d  = 2*64*64 + 66*SIZE(rhos,1)
      lwork_h = 1 + 6*SIZE(rhos,1) + 2*SIZE(rhos,1)*SIZE(rhos,1)
      liwork_h = 3 + 5*SIZE(rhos,1)
      ALLOCATE(work_d(lwork_d),STAT = info)
      IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate work_d ', ABS( info ) )
      ALLOCATE(a(SIZE(rhos,1),SIZE(rhos,2)),STAT = info)
      IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate a ', ABS( info ) )
      ALLOCATE(work_h(lwork_h),STAT = info)
      IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate work_h ', ABS( info ) )
      ALLOCATE(iwork_h(liwork_h),STAT = info)
      IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate iwork_h ', ABS( info ) )
      !
      ALLOCATE(w_h(SIZE(rhod)),STAT = info)
      IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate w_h ', ABS( info ) )
      ALLOCATE(z_h(SIZE(s,1),SIZE(s,2)),STAT = info)
      IF( info /= 0 ) CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' allocate z_h ', ABS( info ) )

      if(initialized == 0) call init_eigsolve_gpu

      info = cudaMemcpy(a, rhos, SIZE(rhos,1)*SIZE(rhos,2), cudaMemcpyDeviceToDevice)
      lda = SIZE(rhos,1)
      ! copy the strict lower triangle of a into s
      !$cuf kernel do(2) <<<*,*>>>
      do j = 1,m
        do i = 1,m
          if (i > j) then
            s(i,j) = a(i,j)
          endif
        end do
      end do

      call dsyevd_gpu('V', 'U', 1, m, m, a, lda, s, lda, rhod, work_d, lwork_d, &
                      work_h, lwork_h, iwork_h, liwork_h, z_h, lda, w_h, info)

      DEALLOCATE(z_h)
      DEALLOCATE(w_h)
      DEALLOCATE(iwork_h)
      DEALLOCATE(work_h)
      DEALLOCATE(a)
      DEALLOCATE(work_d)
#endif
#else
      CALL lax_error__( ' laxlib diagonalize_serial_gpu ', ' not compiled in this version ', 0 )
#endif
   END SUBROUTINE diagonalize_serial_gpu