!
! Copyright (C) 1996-2016 The SIESTA group
!  This file is distributed under the terms of the
!  GNU General Public License: see COPYING in the top directory
!  or http://www.gnu.org/copyleft/gpl.txt.
! See Docs/Contributors.txt for a list of contributors.
!
! General routine for inverting a matrix of arbitrary sizes.
! The algorithm follows: 10.1088/1749-4699/5/1/014009
! which contains an apparently general algorithm (not developed by the authors)
!
! The idea is that we calculate in the following way (the order is important):
!   1. calculate Xn/Cn+1 and Yn/Bn-1.
!   2. calculate all Mnn
!   3. calculate all Mn-1n Mn+1n

! The algorithm has been developed so that it does not need any
! additional memory.

! It has fully been developed by Nick Papior Andersen, 2013
! Please contact the author of the code: nickpapior@gmail.com
! before use elsewhere!!!

! The routine can easily be converted to use another precision.

! When called you can specify an optional logical array (calc_parts,
! one entry per part) which tells the module which parts you wish
! to calculate the inverse of.
! I.e.
!   call invert_TriMat(M,Minv,calc_parts=(/.false.,.true.,.true./))
! where parts == 3. This will save a lot of computations,
! especially if parts >> 1 and only part of the matrix is wanted.
! Below, sPart and ePart denote the first and the last requested part.


! After the inverted matrix has been calculated the input matrix
! loses the values in the diagonal.
! The rest of the information is still retained.
! If one only requests a part of the matrix, we retain some of the values
!    1. the lower tri-parts of the inverted tri-matrix contains
!       Xn/Cn+1 from [ePart+1:Parts]
!    2. the upper tri-parts of the inverted tri-matrix contains
!       Yn/Bn-1 from [1:sPart-1]
!    3. the lower tri-parts of the tri-matrix to be inverted contains
!       Xn/Cn+1 from [sPart:ePart]
!    4. the upper tri-parts of the tri-matrix to be inverted contains
!       Yn/Bn-1 from [sPart:ePart]
!    5. the A(1:sPart-1,1:sPart-1) is retained IFF sPart > 1
!    6.
!       the A(ePart+1:Parts,ePart+1:Parts) is retained IFF ePart < Parts

! Module for inverting a block tri-diagonal matrix stored as a zTriMat.
!
! Storage conventions used throughout this module (see the two accessor
! functions further down):
!   Xn/Cn+1 of part n lives in the block val(.,n+1,n)
!   Yn/Bn-1 of part n lives in the block val(.,n-1,n)
! The pivoting array (ipiv) is shared through m_pivot_array and has to be
! set up with init_TriMat_inversion before any inversion is attempted.
module m_trimat_invert

  use class_zTriMat
  use precision, only: dp
  use m_pivot_array, only : Npiv, ipiv, init_pivot
  use m_pivot_array, only : clear_TriMat_inversion => clear_pivot

  implicit none

  private
  private :: dp

  ! Used for BLAS calls (local variables)
  ! cmplx(...,kind=dp) is the standard-conforming way of building these
  ! constants (dcmplx is a vendor extension).
  complex(dp), private, parameter :: z0  = cmplx( 0._dp, 0._dp, kind=dp)
  complex(dp), private, parameter :: z1  = cmplx( 1._dp, 0._dp, kind=dp)
  complex(dp), private, parameter :: zm1 = cmplx(-1._dp, 0._dp, kind=dp)

  public :: invert_TriMat
  public :: init_TriMat_inversion
  public :: clear_TriMat_inversion

  ! For those inclined to do other things
  ! than simply inverting the matrix... :)
  public :: calc_Xn_div_Cn_p1
  public :: calc_Yn_div_Bn_m1
  public :: Xn_div_Cn_p1
  public :: Yn_div_Bn_m1
  public :: calc_Mnn_inv
  public :: calc_Mnm1n_inv
  public :: calc_Mnp1n_inv

contains

  ! Invert the block tri-diagonal matrix M and store the result in Minv.
  !
  !   M          : matrix to invert; its blocks are overwritten during
  !                the calculation
  !   Minv       : on exit holds the requested blocks of the inverse
  !   calc_parts : optional mask with one entry per part; only parts with
  !                a .true. entry get their inverse blocks calculated.
  !                When absent every part is calculated.
  !
  ! If calc_parts is present but holds no .true. entry there is nothing
  ! to calculate and the routine returns without touching M or Minv.
  subroutine invert_TriMat(M,Minv,calc_parts)
    type(zTriMat), intent(inout) :: M, Minv
    logical, intent(in), optional :: calc_parts(:)
    complex(dp), pointer :: Mpinv(:)
    integer :: lsPart, lePart
    integer :: sNm1, sNp1, n
    logical, allocatable :: lc_parts(:)

#ifndef TS_NOCHECKS
    if ( parts(M) /= parts(Minv) ) then
       call die('Could not calculate the inverse on non equal sized &
            &matrices')
    end if
    if ( parts(M) == 1 ) then
       call die('This matrix is not tri-diagonal')
    end if
    ! The shared pivoting array must be able to hold the largest part
    do n = 1 , parts(M)
       if ( Npiv < nrows_g(M,n) ) then
          call die('Pivoting array for inverting matrix not set.')
       end if
    end do
#endif

    ! Local copy of the mask of parts to calculate
    allocate(lc_parts(parts(M)))
    if ( present(calc_parts) ) then
#ifndef TS_NOCHECKS
       if ( size(calc_parts) /= parts(M) ) then
          call die('Wrong size of calculation parts. Please correct code')
       end if
#endif
       lc_parts(:) = calc_parts(:)
    else
       lc_parts(:) = .true.
    end if

    ! Without this guard the first/last requested part below would be
    ! undefined when nothing is requested.
    if ( .not. any(lc_parts) ) then
       deallocate(lc_parts)
       return
    end if

    ! First requested part (guaranteed to exist by the guard above)
    lsPart = 1
    do while ( .not. lc_parts(lsPart) )
       lsPart = lsPart + 1
    end do
    ! Last requested part
    lePart = parts(M)
    do while ( .not. lc_parts(lePart) )
       lePart = lePart - 1
    end do

    call timer('TM_inv',1)

    ! Calculate all Xn/Cn+1, from the last part down to the first
    ! requested one. The (n+1,n+1) block of Minv is used as work-space.
    do n = parts(M) - 1 , lsPart , -1
       Mpinv => val(Minv,n+1,n+1)
       sNp1 = nrows_g(M,n+1)
       call calc_Xn_div_Cn_p1(M,Minv, n, Mpinv, sNp1**2 )
    end do
    ! Calculate all Yn/Bn-1, from the second part up to the last
    ! requested one. The (n-1,n-1) block of Minv is used as work-space.
    do n = 2 , lePart
       Mpinv => val(Minv,n-1,n-1)
       sNm1 = nrows_g(M,n-1)
       call calc_Yn_div_Bn_m1(M,Minv, n, Mpinv, sNm1**2 )
    end do

    ! We calculate all Mnn
    ! Here it is permissible to overwrite the old A
    do n = lsPart , lePart
       if ( lc_parts(n) ) call calc_Mnn_inv(M,Minv,n)
    end do

    ! ************ We have now calculated all diagonal parts of the
    ! tri-diagonal matrix... **************************************

    ! Off-diagonal blocks, both need the finished Mnn of the same part
    do n = lsPart + 1 , lePart
       if ( lc_parts(n) ) call calc_Mnm1n_inv(M,Minv,n)
    end do
    do n = lePart - 1 , lsPart , -1
       if ( lc_parts(n) ) call calc_Mnp1n_inv(M,Minv,n)
    end do

    ! De-allocate variable to track calculated parts
    deallocate(lc_parts)

    call timer('TM_inv',2)

  end subroutine invert_TriMat

  ! Calculate the diagonal block n of the inverse.
  !
  ! The diagonal block val(M,n,n) is updated in place to
  !    An - val(M,n,n+1) * Xn/Cn+1 - val(M,n,n-1) * Yn/Bn-1
  ! (each term only when the corresponding neighbour exists) and is then
  ! used as the coefficient matrix of a linear solve (zgesv) against the
  ! identity, the solution of which ends up in val(Minv,n,n).
  ! Xn/Cn+1 and Yn/Bn-1 are read from Minv, hence they must have been
  ! calculated beforehand.
  subroutine calc_Mnn_inv(M,Minv,n)
    use intrinsic_missing, only: EYE
    type(zTriMat), intent(inout) :: M, Minv
    integer, intent(in) :: n
    ! Local variables
    complex(dp), pointer :: Mp(:), Mpinv(:)
    complex(dp), pointer :: Xn(:), Yn(:), Cn(:), Bn(:)
    integer :: sNm1, sN, sNp1, ierr

    sN = nrows_g(M,n)

    ! Retrieve Ann
    Mp => val(M,n,n)

    if ( n < parts(M) ) then
       ! There is a following part: subtract its contribution
       sNp1 = nrows_g(M,n+1)
       ! Retrieve the Xn/Cn+1 array
       Xn => Xn_div_Cn_p1(Minv,n)
       ! The Cn+1 array
       Cn => val(M,n,n+1)
       ! Calculate: An - Xn
#ifdef USE_GEMM3M
       call zgemm3m( &
#else
       call zgemm( &
#endif
            'N','N',sN,sN,sNp1, &
            zm1, Cn,sN, Xn,sNp1,z1, Mp,sN)
    end if

    if ( 1 < n ) then
       ! There is a preceding part: subtract its contribution
       sNm1 = nrows_g(M,n-1)
       ! Retrieve the Yn/Bn-1 array
       Yn => Yn_div_Bn_m1(Minv,n)
       ! The Bn-1 array
       Bn => val(M,n,n-1)
       ! Calculate: An [- Xn] - Yn
#ifdef USE_GEMM3M
       call zgemm3m( &
#else
       call zgemm( &
#endif
            'N','N',sN,sN,sNm1, &
            zm1, Bn,sN, Yn,sNm1,z1, Mp,sN)
    end if

    ! Retrieve the position in the inverted matrix and
    ! initialize it to the identity (right-hand side of the solve)
    Mpinv => val(Minv,n,n)
    call EYE(sN,Mpinv)

    call zgesv(sN,sN,Mp,sN,ipiv,Mpinv,sN,ierr)
    if ( ierr /= 0 ) call die('Error on inverting Mnn')

  end subroutine calc_Mnn_inv

  ! Calculate the block (n+1,n) of the inverse:
  !    Mn+1,n = - Xn/Cn+1 * Mnn
  ! Xn/Cn+1 is first copied from Minv into the (n+1,n) block of M,
  ! after which the product overwrites the (n+1,n) block of Minv.
  ! Requires val(Minv,n,n) to hold the finished Mnn.
  ! Does nothing for the last part (there is no block n+1).
  subroutine calc_Mnp1n_inv(M,Minv,n)
    type(zTriMat), intent(inout) :: M, Minv
    integer, intent(in) :: n
    ! Local variables
    complex(dp), pointer :: Mp(:), Mpinv(:)
    complex(dp), pointer :: Xn(:)
    integer :: sN, sNp1

    ! We can/shall not calculate this for the last part
    if ( n >= parts(M) ) return

    sN   = nrows_g(M,n)
    sNp1 = nrows_g(M,n+1)

    ! *** we will now calculate Mn+1,n
    ! Copy over Xn/Cn+1 (Minv -> M), freeing the block in Minv
    Xn    => Xn_div_Cn_p1(M   ,n)
    Mpinv => Xn_div_Cn_p1(Minv,n)

    call zcopy(sN*sNp1,Mpinv,1,Xn,1)

    ! Do matrix-multiplication
    Mp => val(Minv,n,n)
    ! Calculate: - Xn/Cn+1 * Mnn
#ifdef USE_GEMM3M
    call zgemm3m( &
#else
    call zgemm( &
#endif
         'N','N',sNp1,sN,sN, &
         zm1, Xn,sNp1, Mp,sN,z0, Mpinv,sNp1)

  end subroutine calc_Mnp1n_inv

  ! Calculate the block (n-1,n) of the inverse:
  !    Mn-1,n = - Yn/Bn-1 * Mnn
  ! Yn/Bn-1 is first copied from Minv into the (n-1,n) block of M,
  ! after which the product overwrites the (n-1,n) block of Minv.
  ! Requires val(Minv,n,n) to hold the finished Mnn.
  ! Does nothing for the first part (there is no block n-1).
  subroutine calc_Mnm1n_inv(M,Minv,n)
    type(zTriMat), intent(inout) :: M, Minv
    integer, intent(in) :: n
    ! Local variables
    complex(dp), pointer :: Mp(:), Mpinv(:)
    complex(dp), pointer :: Yn(:)
    integer :: sN, sNm1

    ! We can/shall not calculate this for the first part
    if ( n <= 1 ) return

    sN   = nrows_g(M,n)
    sNm1 = nrows_g(M,n-1)

    ! Copy over Yn/Bn-1 (Minv -> M), freeing the block in Minv
    Yn    => Yn_div_Bn_m1(M   ,n)
    Mpinv => Yn_div_Bn_m1(Minv,n)

    call zcopy(sN*sNm1,Mpinv,1,Yn,1)

    ! Do matrix-multiplication
    Mp => val(Minv,n,n)
    ! Calculate: - Yn/Bn-1 * Mnn
#ifdef USE_GEMM3M
    call zgemm3m( &
#else
    call zgemm( &
#endif
         'N','N',sNm1,sN,sN, &
         zm1, Yn,sNm1, Mp,sN,z0, Mpinv,sNm1)

  end subroutine calc_Mnm1n_inv



  ! We will calculate the Xn/Cn+1 component of the
  ! tri-diagonal inversion algorithm.
  ! The Xn/Cn+1 will be saved in the Minv n+1,n block (as that has
  ! the same size).
  !
  !   M     : the tri-diagonal matrix (only read here)
  !   Minv  : receives Xn/Cn+1 in val(Minv,n+1,n); for n < parts-1 the
  !           previously calculated Xn+1/Cn+2 is read from it
  !   n     : part index, 1 <= n < parts(M)
  !   zwork : work-space, overwritten; at least nrows_g(M,n+1)**2 elements
  !   nz    : size of zwork
  !
  ! What is solved is
  !    (An+1 - val(M,n+1,n+2) * Xn+1/Cn+2) * X = val(M,n+1,n)
  ! where the subtraction is skipped for n == parts-1.
  subroutine calc_Xn_div_Cn_p1(M,Minv,n,zwork,nz)
    type(zTriMat), intent(inout) :: M, Minv
    integer, intent(in) :: n, nz
    complex(dp), intent(inout) :: zwork(nz)
    ! Local variables
    complex(dp), pointer :: ztmp(:), Xn(:), Cnp2(:)
    integer :: sN, sNp1, sNp1SQ, sNp2, ierr
    ! Large enough for the fixed text plus three full-width integers
    character(len=80) :: cerr

#ifndef TS_NOCHECKS
    if ( n < 1 .or. parts(M) <= n .or. parts(M) /= parts(Minv) ) then
       call die('Could not calculate Xn on these matrices')
    end if
#endif
    ! Collect all matrix sizes for this step...
    sN     = nrows_g(M,n)
    sNp1   = nrows_g(M,n+1)
    sNp1SQ = sNp1 ** 2
#ifndef TS_NOCHECKS
    if ( nz < sNp1SQ ) then
       call die('Work array in Xn calculation not sufficiently &
            &big.')
    end if
#endif

    ! The off-diagonal block (n+1,n) of M, the right-hand side
    Cnp2 => val(M, n+1, n)
    ! This is where the result will be located
    Xn => Xn_div_Cn_p1(Minv, n)
    ! The An+1 array
    ztmp => val(M, n+1, n+1)

    ! Right-hand side into the result block, An+1 into the work-space
    call zcopy(sN*sNp1, Cnp2(1), 1, Xn(1), 1)
    call zcopy(sNp1SQ, ztmp(1), 1, zwork(1), 1)

    ! If we should calculate X_N-1 then X_N == 0
    if ( n < parts(M) - 1 ) then
       ! Size...
       sNp2 = nrows_g(M,n+2)
       ! Retrieve the Xn+1/Cn+2 array
       ztmp => Xn_div_Cn_p1(Minv,n+1)
       ! Retrieve the Cn+2 array
       Cnp2 => val(M,n+1,n+2)
       ! Calculate: An+1 - Xn+1
#ifdef USE_GEMM3M
       call zgemm3m( &
#else
       call zgemm( &
#endif
            'N','N',sNp1,sNp1,sNp2, &
            zm1, Cnp2,sNp1, ztmp,sNp2,z1, zwork,sNp1)
    end if

    ! Calculate Xn/Cn+1
    call zgesv(sNp1,sN,zwork,sNp1,ipiv,Xn,sNp1,ierr)
    if ( ierr /= 0 ) then
       write(cerr,'(3(a,i0))') &
            'Error on inverting X',n,'/C',n+1,' with error: ',ierr
       call die(trim(cerr))
    end if

  end subroutine calc_Xn_div_Cn_p1

  ! Return a pointer to the block in which Xn/Cn+1 of part n is
  ! stored, i.e. val(M,n+1,n).
  function Xn_div_Cn_p1(M,n) result(Xn)
    type(zTriMat), intent(in) :: M
    integer, intent(in) :: n
    complex(dp), pointer :: Xn(:)
    Xn => val(M,n+1,n)
  end function Xn_div_Cn_p1

  ! We will calculate the Yn/Bn-1 component of the
  ! tri-diagonal inversion algorithm.
  ! The Yn/Bn-1 will be saved in the Minv n-1,n block (as that has
  ! the same size).
  !
  !   M     : the tri-diagonal matrix (only read here)
  !   Minv  : receives Yn/Bn-1 in val(Minv,n-1,n); for n > 2 the
  !           previously calculated Yn-1/Bn-2 is read from it
  !   n     : part index, 2 <= n <= parts(M)
  !   zwork : work-space, overwritten; at least nrows_g(M,n-1)**2 elements
  !   nz    : size of zwork
  !
  ! What is solved is
  !    (An-1 - val(M,n-1,n-2) * Yn-1/Bn-2) * Y = val(M,n-1,n)
  ! where the subtraction is skipped for n == 2.
  subroutine calc_Yn_div_Bn_m1(M,Minv,n,zwork,nz)
    type(zTriMat), intent(inout) :: M, Minv
    integer, intent(in) :: n, nz
    complex(dp), intent(inout) :: zwork(nz)
    ! Local variables
    complex(dp), pointer :: ztmp(:), Yn(:), Bnm2(:)
    integer :: sN, sNm1, sNm1SQ, sNm2, ierr
    ! Large enough for the fixed text plus three full-width integers
    character(len=80) :: cerr

#ifndef TS_NOCHECKS
    if ( n < 2 .or. parts(M) < n .or. parts(M) /= parts(Minv) ) then
       call die('Could not calculate Yn on these matrices')
    end if
#endif
    ! Collect all matrix sizes for this step...
    sN     = nrows_g(M,n)
    sNm1   = nrows_g(M,n-1)
    sNm1SQ = sNm1 ** 2
#ifndef TS_NOCHECKS
    if ( nz < sNm1SQ ) then
       call die('Work array in Yn calculation not sufficiently &
            &big.')
    end if
#endif

    ! The off-diagonal block (n-1,n) of M, the right-hand side
    Bnm2 => val(M ,n-1,n)
    ! This is where the result will be located
    Yn => Yn_div_Bn_m1(Minv,n)
    ! The An-1 array
    ztmp => val(M,n-1,n-1)

    ! Right-hand side into the result block, An-1 into the work-space
    call zcopy(sN*sNm1, Bnm2(1), 1, Yn(1), 1)
    call zcopy(sNm1SQ, ztmp(1), 1, zwork(1), 1)

    ! For n == 2 there is no Y1 to subtract
    if ( 2 < n ) then
       ! Size...
       sNm2 = nrows_g(M,n-2)
       ! Retrieve the Yn-1/Bn-2 array
       ztmp => Yn_div_Bn_m1(Minv,n-1)
       ! Retrieve the Bn-2 array
       Bnm2 => val(M,n-1,n-2)
       ! Calculate: An-1 - Yn-1
#ifdef USE_GEMM3M
       call zgemm3m( &
#else
       call zgemm( &
#endif
            'N','N',sNm1,sNm1,sNm2, &
            zm1, Bnm2,sNm1, ztmp,sNm2,z1, zwork,sNm1)
    end if

    ! Calculate Yn/Bn-1
    call zgesv(sNm1,sN,zwork,sNm1,ipiv,Yn,sNm1,ierr)
    if ( ierr /= 0 ) then
       write(cerr,'(3(a,i0))') &
            'Error on inverting Y',n,'/B',n-1,' with error: ',ierr
       call die(trim(cerr))
    end if

  end subroutine calc_Yn_div_Bn_m1

  ! Return a pointer to the block in which Yn/Bn-1 of part n is
  ! stored, i.e. val(M,n-1,n).
  function Yn_div_Bn_m1(M,n) result(Yn)
    type(zTriMat), intent(in) :: M
    integer, intent(in) :: n
    complex(dp), pointer :: Yn(:)
    Yn => val(M,n-1,n)
  end function Yn_div_Bn_m1

  ! We initialize the pivoting array used in the inversion.
  ! It is requested (init_pivot) with the largest nrows_g found among
  ! the parts of M, so that it can serve every part.
  subroutine init_TriMat_inversion(M)
    type(zTriMat), intent(in) :: M
    integer :: i, N

    N = 0
    do i = 1 , parts(M)
       N = max(N, nrows_g(M,i))
    end do

    call init_pivot(N)

  end subroutine init_TriMat_inversion

end module m_trimat_invert
