1!
2! Copyright (C) 1996-2016	The SIESTA group
3!  This file is distributed under the terms of the
4!  GNU General Public License: see COPYING in the top directory
5!  or http://www.gnu.org/copyleft/gpl.txt.
6! See Docs/Contributors.txt for a list of contributors.
7!
8! General routine for inverting a matrix of arbitrary sizes.
! The algorithm follows: 10.1088/1749-4699/5/1/014009
! which contains an apparently general algorithm (not developed by the authors)
11!
12! The idea is that we calculate in the following way (the order is important):
13!   1. calculate Xn/Cn+1 and Yn/Bn-1.
14!   2. calculate all Mnn
15!   3. calculate all Mn-1n Mn+1n
16
17! The algorithm has been developed so that it does not need any-more
18! memory.
19
20! It has fully been developed by Nick Papior Andersen, 2013
21! Please contact the author of the code: nickpapior@gmail.com
22! before use elsewhere!!!
23
24! The routine can easily be converted to use another precision.
25
! When called you can specify an optional logical array, calc_parts
! (one entry per part), which tells the module which parts you wish
! to calculate the inverse of.
! I.e.
!   call invert_TriMat(M,Minv,calc_parts=(/.false.,.true.,.true./))
! where parts == 3. This will save a lot of computations,
! especially if parts >> 1 and only part of the matrix is wanted.
! In the following sPart/ePart denote the first/last requested part.
32
33
! After the inverted matrix has been calculated the input matrix
! loses the values in the diagonal.
! The rest of the information is still retained.
37! If one only requests a part of the matrix, we retain some of the values
38!   1. the lower tri-parts of the inverted tri-matrix contains
39!       Xn/Cn+1 from [ePart+1:Parts]
40!   2. the upper tri-parts of the inverted tri-matrix contains
41!       Yn/Bn-1 from [1:sPart-1]
42!   3. the lower tri-parts of the tri-matrix to be inverted contains
43!       Xn/Cn+1 from [sPart:ePart]
44!   4. the upper tri-parts of the tri-matrix to be inverted contains
45!       Yn/Bn-1 from [sPart:ePart]
46!   5. the A(1:sPart-1,1:sPart-1)         is retained IFF sPart > 1
47!   6. the A(ePart+1:Parts,ePart+1:Parts) is retained IFF ePart < Parts
48
module m_trimat_invert

  use class_zTriMat
  use precision, only: dp
  use m_pivot_array, only : Npiv, ipiv, init_pivot
  use m_pivot_array, only : clear_TriMat_inversion => clear_pivot

  implicit none

  private
  private :: dp

  ! Scalar factors handed to the BLAS calls.
  ! (cmplx with an explicit kind is the standard-conforming spelling
  !  of the non-standard dcmplx intrinsic.)
  complex(dp), private, parameter :: z0  = cmplx( 0._dp, 0._dp, kind=dp)
  complex(dp), private, parameter :: z1  = cmplx( 1._dp, 0._dp, kind=dp)
  complex(dp), private, parameter :: zm1 = cmplx(-1._dp, 0._dp, kind=dp)

  public :: invert_TriMat
  public :: init_TriMat_inversion
  public :: clear_TriMat_inversion

  ! For those inclined to do other things
  ! than simply inverting the matrix... :)
  public :: calc_Xn_div_Cn_p1
  public :: calc_Yn_div_Bn_m1
  public :: Xn_div_Cn_p1
  public :: Yn_div_Bn_m1
  public :: calc_Mnn_inv
  public :: calc_Mnm1n_inv
  public :: calc_Mnp1n_inv

  ! NOTE on storage: the BLAS/LAPACK calls in this module treat the
  ! pointer returned by val(T,i,j) as a dense column-major block with
  ! nrows_g(T,i) rows and nrows_g(T,j) columns.

contains

  ! Invert the block tri-diagonal matrix M into Minv.
  !
  !   M          : matrix to invert. Its blocks are overwritten
  !                (diagonal blocks end up holding LU factors, the
  !                off-diagonal blocks copies of Xn/Cn+1 and Yn/Bn-1).
  !   Minv       : on exit holds the requested blocks of the inverse.
  !   calc_parts : optional, one logical per part; only columns with
  !                a .true. entry are calculated. If absent all parts
  !                are calculated. If no entry is .true. nothing is
  !                calculated.
  !
  ! The pivoting array must have been set up beforehand, see
  ! init_TriMat_inversion.
  subroutine invert_TriMat(M,Minv,calc_parts)
    type(zTriMat), intent(inout) :: M, Minv
    logical, intent(in), optional :: calc_parts(:)
    complex(dp), pointer :: Mpinv(:)
    integer :: lsPart, lePart
    integer :: sNm1, sNp1, n
    logical, allocatable :: lc_parts(:)

#ifndef TS_NOCHECKS
    if ( parts(M) /= parts(Minv) ) then
       call die('Could not calculate the inverse on non equal sized matrices')
    end if
    if ( parts(M) == 1 ) then
       call die('This matrix is not tri-diagonal')
    end if
    ! The shared pivoting array has to fit the largest block
    do n = 1 , parts(M)
       if ( Npiv < nrows_g(M,n) ) then
          call die('Pivoting array for inverting matrix not set.')
       end if
    end do
#endif

    ! Local copy of the requested parts and the
    ! first/last requested part
    allocate(lc_parts(parts(M)))
    if ( present(calc_parts) ) then
#ifndef TS_NOCHECKS
       if ( size(calc_parts) /= parts(M) ) then
          call die('Wrong size of calculation parts. Please correct code')
       end if
#endif
       lc_parts(:) = calc_parts(:)

       ! Start from an empty range. If no part is requested
       ! every loop below has zero trips (previously lsPart/lePart
       ! were left undefined in that case).
       lsPart = parts(M) + 1
       lePart = 0
       do n = 1 , parts(M)
          if ( lc_parts(n) ) then
             lsPart = n
             exit
          end if
       end do
       do n = parts(M) , 1 , -1
          if ( lc_parts(n) ) then
             lePart = n
             exit
          end if
       end do
    else
       lc_parts(:) = .true.
       lsPart = 1
       lePart = parts(M)
    end if

    call timer('TM_inv',1)

    ! Calculate all Xn/Cn+1, this is a downward recursion
    ! starting from the last part.
    ! The not yet used diagonal block n+1 of Minv is the work-space.
    do n = parts(M) - 1 , lsPart , -1
       Mpinv => val(Minv,n+1,n+1)
       sNp1 = nrows_g(M,n+1)
       call calc_Xn_div_Cn_p1(M,Minv, n, Mpinv, sNp1**2 )
    end do
    ! Calculate all Yn/Bn-1, this is an upward recursion
    ! starting from the first part.
    ! The not yet used diagonal block n-1 of Minv is the work-space.
    do n = 2 , lePart
       Mpinv => val(Minv,n-1,n-1)
       sNm1 = nrows_g(M,n-1)
       call calc_Yn_div_Bn_m1(M,Minv, n, Mpinv, sNm1**2 )
    end do

    ! We calculate all Mnn
    ! Here it is permissible to overwrite the old A.
    ! This must come after both recursions as they use the
    ! diagonal blocks of Minv as work-space.
    do n = lsPart , lePart
       if ( lc_parts(n) ) then
          call calc_Mnn_inv(M,Minv,n)
       end if
    end do

    ! ************ We have now calculated all diagonal parts of the
    ! tri-diagonal matrix... **************************************

    ! The off-diagonal blocks overwrite Yn/Bn-1 and Xn/Cn+1 in Minv,
    ! hence they are calculated last.
    do n = lsPart + 1 , lePart
       if ( lc_parts(n) ) then
          call calc_Mnm1n_inv(M,Minv,n)
       end if
    end do
    do n = lePart - 1 , lsPart , -1
       if ( lc_parts(n) ) then
          call calc_Mnp1n_inv(M,Minv,n)
       end if
    end do

    ! De-allocate variable to track calculated parts
    deallocate(lc_parts)

    call timer('TM_inv',2)

  end subroutine invert_TriMat

  ! Calculate the diagonal block n of the inverse:
  !   Minv(n,n) = [ A(n,n) - M(n,n+1) * Xn/Cn+1 - M(n,n-1) * Yn/Bn-1 ]^-1
  ! where the Xn term is absent for the last part and the Yn term
  ! is absent for the first part.
  ! Xn/Cn+1 and Yn/Bn-1 are read from Minv and must already be there.
  ! The diagonal block n of M is overwritten (first by the sum above,
  ! then by its LU factorisation from zgesv).
  subroutine calc_Mnn_inv(M,Minv,n)
    use intrinsic_missing, only: EYE
    type(zTriMat), intent(inout) :: M, Minv
    integer, intent(in) :: n
    ! Local variables
    complex(dp), pointer :: Mp(:), Mpinv(:)
    complex(dp), pointer :: Xn(:), Yn(:), Cn(:), Bn(:)
    integer :: sNm1, sN, sNp1, ierr

    sN = nrows_g(M,n)

    ! Retrieve Ann
    Mp => val(M,n,n)

    ! The two updates are guarded individually so that a
    ! non-existing neighbouring block is never referenced.
    if ( n < parts(M) ) then
       sNp1 = nrows_g(M,n+1)
       ! Retrieve the Xn/Cn+1 array
       Xn => Xn_div_Cn_p1(Minv,n)
       ! The Cn+1 array
       Cn => val(M,n,n+1)
       ! Calculate: An - Xn
#ifdef USE_GEMM3M
       call zgemm3m( &
#else
       call zgemm( &
#endif
            'N','N',sN,sN,sNp1, &
            zm1, Cn,sN, Xn,sNp1,z1, Mp,sN)
    end if

    if ( 1 < n ) then
       sNm1 = nrows_g(M,n-1)
       ! Retrieve the Yn/Bn-1 array
       Yn => Yn_div_Bn_m1(Minv,n)
       ! The Bn-1 array
       Bn => val(M,n,n-1)
       ! Calculate: An [- Xn] - Yn
#ifdef USE_GEMM3M
       call zgemm3m( &
#else
       call zgemm( &
#endif
            'N','N',sN,sN,sNm1, &
            zm1, Bn,sN, Yn,sNm1,z1, Mp,sN)
    end if

    ! Retrieve the position in the inverted matrix and
    ! solve against the identity to obtain the inverse
    Mpinv => val(Minv,n,n)
    call EYE(sN,Mpinv)

    call zgesv(sN,sN,Mp,sN,ipiv,Mpinv,sN,ierr)
    if ( ierr /= 0 ) call die('Error on inverting Mnn')

  end subroutine calc_Mnn_inv

  ! Calculate the block below the diagonal in column n:
  !   Minv(n+1,n) = - Xn/Cn+1 * Minv(n,n)
  ! Minv(n,n) must already be calculated.
  ! Xn/Cn+1 is moved from Minv(n+1,n) to M(n+1,n) before the
  ! former is overwritten.
  ! Returns immediately for the last part.
  subroutine calc_Mnp1n_inv(M,Minv,n)
    type(zTriMat), intent(inout) :: M, Minv
    integer, intent(in) :: n
    ! Local variables
    complex(dp), pointer :: Mnn(:), Mpinv(:)
    complex(dp), pointer :: Xn(:)
    integer :: sN, sNp1

    ! We can/shall not calculate this for the last part
    if ( parts(M) <= n ) return

    sN   = nrows_g(M,n)
    sNp1 = nrows_g(M,n+1)

    ! *** we will now calculate Mn+1,n
    ! Copy over Xn/Cn+1 (the target in Minv is also the output)
    Xn    => Xn_div_Cn_p1(M   ,n)
    Mpinv => Xn_div_Cn_p1(Minv,n)

    call zcopy(sN*sNp1,Mpinv,1,Xn,1)

    ! Do matrix-multiplication
    Mnn => val(Minv,n,n)
    ! Calculate: - Xn/Cn+1 * Mnn
#ifdef USE_GEMM3M
    call zgemm3m( &
#else
    call zgemm( &
#endif
         'N','N',sNp1,sN,sN, &
         zm1, Xn,sNp1, Mnn,sN,z0, Mpinv,sNp1)

  end subroutine calc_Mnp1n_inv

  ! Calculate the block above the diagonal in column n:
  !   Minv(n-1,n) = - Yn/Bn-1 * Minv(n,n)
  ! Minv(n,n) must already be calculated.
  ! Yn/Bn-1 is moved from Minv(n-1,n) to M(n-1,n) before the
  ! former is overwritten.
  ! Returns immediately for the first part.
  subroutine calc_Mnm1n_inv(M,Minv,n)
    type(zTriMat), intent(inout) :: M, Minv
    integer, intent(in) :: n
    ! Local variables
    complex(dp), pointer :: Mnn(:), Mpinv(:)
    complex(dp), pointer :: Yn(:)
    integer :: sN, sNm1

    ! We can/shall not calculate this for the first part
    if ( n <= 1 ) return

    sN   = nrows_g(M,n)
    sNm1 = nrows_g(M,n-1)

    ! Copy over Yn/Bn-1 (the target in Minv is also the output)
    Yn    => Yn_div_Bn_m1(M   ,n)
    Mpinv => Yn_div_Bn_m1(Minv,n)

    call zcopy(sN*sNm1,Mpinv,1,Yn,1)

    ! Do matrix-multiplication
    Mnn => val(Minv,n,n)
    ! Calculate: - Yn/Bn-1 * Mnn
#ifdef USE_GEMM3M
    call zgemm3m( &
#else
    call zgemm( &
#endif
         'N','N',sNm1,sN,sN, &
         zm1, Yn,sNm1, Mnn,sN,z0, Mpinv,sNm1)

  end subroutine calc_Mnm1n_inv



  ! We will calculate the Xn/Cn+1 component of the
  ! tri-diagonal inversion algorithm:
  !   Xn/Cn+1 = [ A(n+1,n+1) - M(n+1,n+2) * Xn+1/Cn+2 ]^-1 * M(n+1,n)
  ! where the Xn+1 term is absent for n == parts - 1.
  ! The Xn/Cn+1 will be saved in the Minv n+1,n (as that has
  ! the same size), see Xn_div_Cn_p1.
  !
  !   n     : 1 <= n < parts
  !   zwork : work-space of at least nrows_g(M,n+1)**2 elements,
  !           overwritten (holds LU factors on exit)
  !   nz    : size of zwork
  !
  ! For n < parts - 1 the block Xn+1/Cn+2 must already be in Minv.
  subroutine calc_Xn_div_Cn_p1(M,Minv,n,zwork,nz)
    type(zTriMat), intent(inout) :: M, Minv
    integer, intent(in) :: n, nz
    complex(dp), intent(inout) :: zwork(nz)
    ! Local variables
    complex(dp), pointer :: Anp1(:), rhs(:), Xn(:)
    complex(dp), pointer :: Xnp1(:), Cnp2(:)
    integer :: sN, sNp1, sNp1SQ, sNp2, ierr
    ! Long enough for the message with three full-width integers
    character(len=80) :: cerr

#ifndef TS_NOCHECKS
    if ( n < 1 .or. parts(M) <= n .or. parts(M) /= parts(Minv) ) then
       call die('Could not calculate Xn on these matrices')
    end if
#endif
    ! Collect all matrix sizes for this step...
    sN     = nrows_g(M,n)
    sNp1   = nrows_g(M,n+1)
    sNp1SQ = sNp1 ** 2
#ifndef TS_NOCHECKS
    if ( nz < sNp1SQ ) then
       call die('Work array in Xn calculation not sufficiently big.')
    end if
#endif

    ! The right-hand side of the linear system: block (n+1,n) of M
    rhs  => val(M, n+1, n)
    ! This is where the solution (Xn/Cn+1) will be located
    Xn   => Xn_div_Cn_p1(Minv, n)
    ! The An+1 array
    Anp1 => val(M, n+1, n+1)

    ! zgesv works in-place, hence we work on copies
    call zcopy(sN*sNp1, rhs(1), 1, Xn(1), 1)
    call zcopy(sNp1SQ, Anp1(1), 1, zwork(1), 1)

    ! If we should calculate X_N-1 then X_N == 0
    if ( n < parts(M) - 1 ) then
       ! Size...
       sNp2 =  nrows_g(M,n+2)
       ! Retrieve the Xn+1/Cn+2 array
       Xnp1 => Xn_div_Cn_p1(Minv,n+1)
       ! Retrieve the Cn+2 array
       Cnp2 => val(M,n+1,n+2)
       ! Calculate: An+1 - Xn+1
#ifdef USE_GEMM3M
       call zgemm3m( &
#else
       call zgemm( &
#endif
            'N','N',sNp1,sNp1,sNp2, &
            zm1, Cnp2,sNp1, Xnp1,sNp2,z1, zwork,sNp1)
    end if

    ! Calculate Xn/Cn+1
    call zgesv(sNp1,sN,zwork,sNp1,ipiv,Xn,sNp1,ierr)
    if ( ierr /= 0 ) then
       write(cerr,'(3(a,i0))') &
            'Error on inverting X',n,'/C',n+1,' with error: ',ierr
       call die(trim(cerr))
    end if

  end subroutine calc_Xn_div_Cn_p1

  ! Return a pointer to where Xn/Cn+1 is stored in M,
  ! i.e. the block (n+1,n).
  function Xn_div_Cn_p1(M,n) result(Xn)
    type(zTriMat), intent(in) :: M
    integer, intent(in) :: n
    complex(dp), pointer :: Xn(:)
    Xn => val(M,n+1,n)
  end function Xn_div_Cn_p1

  ! We will calculate the Yn/Bn-1 component of the
  ! tri-diagonal inversion algorithm:
  !   Yn/Bn-1 = [ A(n-1,n-1) - M(n-1,n-2) * Yn-1/Bn-2 ]^-1 * M(n-1,n)
  ! where the Yn-1 term is absent for n == 2.
  ! The Yn/Bn-1 will be saved in the Minv n-1,n (as that has
  ! the same size), see Yn_div_Bn_m1.
  !
  !   n     : 2 <= n <= parts
  !   zwork : work-space of at least nrows_g(M,n-1)**2 elements,
  !           overwritten (holds LU factors on exit)
  !   nz    : size of zwork
  !
  ! For 2 < n the block Yn-1/Bn-2 must already be in Minv.
  subroutine calc_Yn_div_Bn_m1(M,Minv,n,zwork,nz)
    type(zTriMat), intent(inout) :: M, Minv
    integer, intent(in) :: n, nz
    complex(dp), intent(inout) :: zwork(nz)
    ! Local variables
    complex(dp), pointer :: Anm1(:), rhs(:), Yn(:)
    complex(dp), pointer :: Ynm1(:), Bnm2(:)
    integer :: sN, sNm1, sNm1SQ, sNm2, ierr
    ! Long enough for the message with three full-width integers
    character(len=80) :: cerr

#ifndef TS_NOCHECKS
    if ( n < 2 .or. parts(M) < n .or. parts(M) /= parts(Minv) ) then
       call die('Could not calculate Yn on these matrices')
    end if
#endif
    ! Collect all matrix sizes for this step...
    sN     = nrows_g(M,n)
    sNm1   = nrows_g(M,n-1)
    sNm1SQ = sNm1 ** 2
#ifndef TS_NOCHECKS
    if ( nz < sNm1SQ ) then
       call die('Work array in Yn calculation not sufficiently big.')
    end if
#endif

    ! The right-hand side of the linear system: block (n-1,n) of M
    rhs  => val(M   ,n-1,n)
    ! This is where the solution (Yn/Bn-1) will be located
    Yn   => Yn_div_Bn_m1(Minv,n)
    ! The An-1 array
    Anm1 => val(M,n-1,n-1)

    ! zgesv works in-place, hence we work on copies
    call zcopy(sN*sNm1, rhs(1), 1, Yn(1), 1)
    call zcopy(sNm1SQ, Anm1(1), 1, zwork(1), 1)

    ! If we should calculate Y_2 then Y_1 == 0
    if ( 2 < n ) then
       ! Size...
       sNm2 =  nrows_g(M,n-2)
       ! Retrieve the Yn-1/Bn-2 array
       Ynm1 => Yn_div_Bn_m1(Minv,n-1)
       ! Retrieve the Bn-2 array
       Bnm2 => val(M,n-1,n-2)
       ! Calculate: An-1 - Yn-1
#ifdef USE_GEMM3M
       call zgemm3m( &
#else
       call zgemm( &
#endif
            'N','N',sNm1,sNm1,sNm2, &
            zm1, Bnm2,sNm1, Ynm1,sNm2,z1, zwork,sNm1)
    end if

    ! Calculate Yn/Bn-1
    call zgesv(sNm1,sN,zwork,sNm1,ipiv,Yn,sNm1,ierr)
    if ( ierr /= 0 ) then
       write(cerr,'(3(a,i0))') &
            'Error on inverting Y',n,'/B',n-1,' with error: ',ierr
       call die(trim(cerr))
    end if

  end subroutine calc_Yn_div_Bn_m1

  ! Return a pointer to where Yn/Bn-1 is stored in M,
  ! i.e. the block (n-1,n).
  function Yn_div_Bn_m1(M,n) result(Yn)
    type(zTriMat), intent(in) :: M
    integer, intent(in) :: n
    complex(dp), pointer :: Yn(:)
    Yn => val(M,n-1,n)
  end function Yn_div_Bn_m1

  ! We initialize the pivoting array used by the zgesv calls.
  ! It is requested with the size of the largest part of M
  ! so that the same array can be used for every block.
  subroutine init_TriMat_inversion(M)
    type(zTriMat), intent(in) :: M
    integer :: i, nmax

    nmax = 0
    do i = 1 , parts(M)
       nmax = max(nmax, nrows_g(M,i))
    end do

    call init_pivot(nmax)

  end subroutine init_TriMat_inversion

end module m_trimat_invert
500
501