1 ////////////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (c) 2008 The Regents of the University of California
4 //
5 // This file is part of Qbox
6 //
7 // Qbox is distributed under the terms of the GNU General Public License
8 // as published by the Free Software Foundation, either version 2 of
9 // the License, or (at your option) any later version.
10 // See the file COPYING in the root directory of this distribution
11 // or <http://www.gnu.org/licenses/>.
12 //
13 ////////////////////////////////////////////////////////////////////////////////
14 //
15 // Matrix.cpp
16 //
17 ////////////////////////////////////////////////////////////////////////////////
18
19 #include <cassert>
20 #include <vector>
21 #include <complex>
22 #include <limits>
23 #include <iostream>
24 using namespace std;
25 #ifdef USE_MPI
26 #include <mpi.h>
27 #endif
28 #include "MPIdata.h"
29
30 #include "Context.h"
31 #ifdef SCALAPACK
32 #include "blacs.h"
33 #endif
34
35 #include "Matrix.h"
36
37 #ifdef ADD_
38 #define numroc numroc_
39 #define pdtran pdtran_
40 #define pztranc pztranc_
41 #define pdsymm pdsymm_
42 #define pzsymm pzsymm_
43 #define pzhemm pzhemm_
44 #define pdgemm pdgemm_
45 #define pzgemm pzgemm_
46 #define pdsyrk pdsyrk_
47 #define pzherk pzherk_
48 #define pdsyr pdsyr_
49 #define pdger pdger_
50 #define pzgerc pzgerc_
51 #define pzgeru pzgeru_
52 #define pigemr2d pigemr2d_
53 #define pdgemr2d pdgemr2d_
54 #define pzgemr2d pzgemr2d_
55 #define pdtrmm pdtrmm_
56 #define pdtrsm pdtrsm_
57 #define pztrsm pztrsm_
58 #define pdtrtrs pdtrtrs_
59 #define pztrtrs pztrtrs_
60 #define pdpotrf pdpotrf_
61 #define pzpotrf pzpotrf_
62 #define pdpotri pdpotri_
63 #define pdpocon pdpocon_
64 #define pdsygst pdsygst_
65 #define pdsyev pdsyev_
66 #define pdsyevd pdsyevd_
67 #define pdsyevx pdsyevx_
68 #define pzheev pzheev_
69 #define pzheevd pzheevd_
70 #define pdtrtri pdtrtri_
71 #define pztrtri pztrtri_
72 #define pdlatra pdlatra_
73 #define pdlacp2 pdlacp2_
74 #define pdlacp3 pdlacp3_
75 #define pdgetrf pdgetrf_
76 #define pzgetrf pzgetrf_
77 #define pdgetri pdgetri_
78 #define pzgetri pzgetri_
79 #define pdlapiv pdlapiv_
80 #define pzlapiv pzlapiv_
81 #define pdlapv2 pdlapv2_
82 #define pzlapv2 pzlapv2_
83
84 #define dscal dscal_
85 #define zscal zscal_
86 #define zdscal zdscal_
87 #define dcopy dcopy_
88 #define ddot ddot_
89 #define dnrm2 dnrm2_
90 #define dznrm2 dznrm2_
91 #define zdotu zdotu_
92 #define zdotc zdotc_
93 #define daxpy daxpy_
94 #define zaxpy zaxpy_
95 #define dsymm dsymm_
96 #define zsymm zsymm_
97 #define zhemm zhemm_
98 #define dgemm dgemm_
99 #define zgemm zgemm_
100 #define dsyr dsyr_
101 #define dger dger_
102 #define zgerc zgerc_
103 #define zgeru zgeru_
104 #define dsyrk dsyrk_
105 #define zherk zherk_
106 #define dtrmm dtrmm_
107 #define dtrsm dtrsm_
108 #define dtrtri dtrtri_
109 #define ztrtri ztrtri_
110 #define ztrsm ztrsm_
111 #define dtrtrs dtrtrs_
112 #define ztrtrs ztrtrs_
113 #define dpotrf dpotrf_
114 #define zpotrf zpotrf_
115 #define dpotri dpotri_
116 #define dpocon dpocon_
117 #define dsygst dsygst_
118 #define dsyev dsyev_
119 #define zheev zheev_
120 #define idamax idamax_
121 #define dgetrf dgetrf_
122 #define zgetrf zgetrf_
123 #define dgetri dgetri_
124 #define zgetri zgetri_
125 #endif
126
127 extern "C"
128 {
129 int numroc(const int*, const int*, const int*, const int*, const int*);
130 #ifdef SCALAPACK
131 // PBLAS
132 void pdsymm(const char*, const char*, const int*, const int*, const double*,
133 const double*, const int*, const int*, const int*,
134 const double*, const int*, const int*, const int*,
135 const double*, double*, const int*, const int*, const int*);
136 void pzsymm(const char*, const char*, const int*, const int*,
137 const complex<double>*,
138 const complex<double>*, const int*, const int*, const int*,
139 const complex<double>*, const int*, const int*, const int*,
140 const complex<double>*, complex<double>*, const int*, const int*,
141 const int*);
142 void pzhemm(const char*, const char*, const int*, const int*,
143 const complex<double>*,
144 const complex<double>*, const int*, const int*, const int*,
145 const complex<double>*, const int*, const int*, const int*,
146 const complex<double>*, complex<double>*, const int*, const int*,
147 const int*);
148 void pdgemm(const char*, const char*, const int*,
149 const int*, const int*, const double*,
150 const double*, const int*, const int*, const int*,
151 const double*, const int*, const int*, const int*,
152 const double*, double*, const int*, const int*, const int*);
153 void pzgemm(const char*, const char*, const int*,
154 const int*, const int*, const complex<double>*,
155 const complex<double>*, const int*, const int*, const int*,
156 const complex<double>*, const int*, const int*, const int*,
157 const complex<double>*, complex<double>*, const int*, const int*,
158 const int*);
159 void pdger(const int*, const int*, const double*,
160 const double*, const int*, const int*, const int*, const int*,
161 const double*, const int*, const int*, const int*, const int*,
162 double*, const int*, const int*, const int*);
163 void pzgerc(const int*, const int*, const complex<double>*,
164 const complex<double>*, const int*, const int*, const int*, const int*,
165 const complex<double>*, const int*, const int*, const int*, const int*,
166 complex<double>*, const int*, const int*, const int*);
167 void pzgeru(const int*, const int*, const complex<double>*,
168 const complex<double>*, const int*, const int*, const int*, const int*,
169 const complex<double>*, const int*, const int*, const int*, const int*,
170 complex<double>*, const int*, const int*, const int*);
171 void pdsyr(const char*, const int*,
172 const double*, const double*, const int*, const int*, const int*,
173 const int*, double*, const int*, const int*, const int*);
174 void pdsyrk(const char*, const char*, const int*, const int*,
175 const double*, const double*, const int*, const int*, const int*,
176 const double*, double*, const int*, const int*, const int*);
177 void pzherk(const char*, const char*, const int*, const int*,
178 const double*, const complex<double>*, const int*,
179 const int*, const int*,
180 const double*, complex<double>*, const int*,
181 const int*, const int*);
182 void pdtran(const int*,const int*, const double*,
183 const double*, const int*, const int*, const int*,
184 double*, const double*, const int*, const int*, const int*);
185 void pztranc(const int*, const int*, const complex<double>*,
186 const complex<double>*, const int*, const int*, const int*,
187 complex<double>*, const complex<double>*, const int*, const int*,
188 const int*);
189 void pdtrmm(const char*, const char*, const char*, const char*,
190 const int*, const int*, const double*,
191 const double*, const int*, const int*, const int*,
192 double*, const int*, const int*, const int*);
193 void pdtrsm(const char*, const char*, const char*, const char*,
194 const int*, const int*, const double*,
195 const double*, const int*, const int*, const int*,
196 double*, const int*, const int*, const int*);
197 void pztrsm(const char*, const char*, const char*, const char*,
198 const int*, const int*, const complex<double>*,
199 const complex<double>*, const int*, const int*, const int*,
200 complex<double>*, const int*, const int*, const int*);
201 double pdlatra(const int*,const double*,const int*,const int*,const int*);
202 // SCALAPACK
203 void pdtrtrs(const char*, const char*, const char*, const int*, const int*,
204 const double*, const int*, const int*, const int*,
205 double*, const int*, const int*, const int*, int*);
206 void pztrtrs(const char*, const char*, const char*, const int*, const int*,
207 const complex<double>*, const int*, const int*, const int*,
208 complex<double>*, const int*, const int*, const int*, int*);
209 void pigemr2d(const int*,const int*,
210 const int*,const int*,const int*, const int*,
211 int*,const int*,const int*,const int*,const int*);
212 void pdgemr2d(const int*,const int*,
213 const double*,const int*,const int*, const int*,
214 double*,const int*,const int*,const int*,const int*);
215 void pzgemr2d(const int*,const int*,
216 const complex<double>*,const int*,const int*, const int*,
217 complex<double>*,const int*,const int*,const int*,const int*);
218 void pdpotrf(const char*, const int*, double*, const int*,
219 const int*, const int*, const int*);
220 void pzpotrf(const char*, const int*, complex<double>*, const int*,
221 const int*, const int*, const int*);
222 void pdpotri(const char*, const int*, double*, const int*,
223 const int*, const int*, const int*);
224 void pdpocon(const char*, const int*, const double*,
225 const int*, const int*, const int*, const double*, double*,
226 double*, const int*, int*, const int*, int*);
227 void pdsygst(const int*, const char*, const int*, double*,
228 const int*, const int*, const int*, const double*, const int*,
229 const int*, const int*, double*, int*);
230 void pdsyev(const char*, const char*, const int*,
231 double*, const int*, const int*, const int*, double*, double*,
232 const int*, const int*, const int*, double*, const int*, int*);
233 void pdsyevd(const char*, const char*, const int*,
234 double*, const int*, const int*, const int*, double*, double*,
235 const int*, const int*, const int*, double*, const int*, int*,
236 int*, int*);
237 void pdsyevx(const char* jobz, const char* range, const char* uplo,
238 const int* n, double* a, const int* ia, const int* ja,
239 const int* desca, double* vl, double* vu,
240 const int* il, const int* iu, double* abstol,
241 int* nfound, int* nz, double* w,
242 const double* orfac, double* z, const int* iz, const int* jz,
243 const int* descz, double* work, const int* lwork,
244 int* iwork, int* liwork, int* ifail,
245 int* icluster, double* gap, int* info);
246 void pzheev(const char* jobz, const char* uplo, const int* n,
247 complex<double>* a, const int* ia, const int* ja,
248 const int* desca, double* w, complex<double> *z,
249 const int* iz, const int* jz, const int* descz,
250 complex<double>* work, int* lwork,
251 double* rwork, int* lrwork, int* info);
252 void pzheevd(const char* jobz, const char* uplo, const int* n,
253 complex<double>* a, const int* ia, const int* ja,
254 const int* desca, double* w, complex<double> *z,
255 const int* iz, const int* jz, const int* descz,
256 complex<double>* work, int* lwork,
257 double* rwork, const int* lrwork,
258 int* iwork, int* liwork, int* info);
259 void pdtrtri(const char*, const char*, const int*, double*,
260 const int*, const int*, const int*, int*);
261 void pztrtri(const char*, const char*, const int*, complex<double>*,
262 const int*, const int*, const int*, int*);
263 void pdgetrf(const int* m, const int* n, double* val,
264 int* ia, const int* ja, const int* desca, int* ipiv, int* info);
265 void pzgetrf(const int* m, const int* n, complex<double>* val,
266 int* ia, const int* ja, const int* desca, int* ipiv, int* info);
267 void pdgetri(const int* n, double* val,
268 const int* ia, const int* ja, int* desca, int* ipiv,
269 double* work, int* lwork, int* iwork, int* liwork, int* info);
270 void pzgetri(const int* n, complex<double>* val, const int* ia,
271 const int* ja, int* desca, int* ipiv, complex<double>* work,
272 int* lwork, int* iwork, int* liwork, int* info);
273
274 void pdlapiv(const char* direc, const char* rowcol, const char* pivroc,
275 const int* m, const int* n, double *a, const int* ia,
276 const int* ja, const int* desca, int* ipiv, const int* ip,
277 const int* jp, const int* descp, int* iwork);
278 void pzlapiv(const char* direc, const char* rowcol, const char* pivroc,
279 const int* m, const int* n, complex<double> *a, const int* ia,
280 const int* ja, const int* desca, int* ipiv, const int* ip,
281 const int* jp, const int* descp, int* iwork);
282 void pdlapv2(const char* direc, const char *rowcol,
283 const int* m, const int *n, double *val,
284 const int *ia, const int *ja, const int* desca,
285 int *ipiv, const int *ip, const int *jp, const int *descp);
286 void pzlapv2(const char* direc, const char *rowcol,
287 const int* m, const int *n, complex<double> *val,
288 const int *ia, const int *ja, const int* desca,
289 int *ipiv, const int *ip, const int *jp, const int *descp);
290
291 #endif
292 // BLAS1
293 void dscal(const int*, const double*, double*, const int*);
294 void zscal(const int*, const complex<double>*, complex<double>*, const int*);
295 void zdscal(const int*, const double*, complex<double>*, const int*);
296 void daxpy(const int *, const double *, const double *, const int *,
297 double *, const int *);
298 void zaxpy(const int *, const complex<double> *, const complex<double> *,
299 const int *, complex<double> *, const int *);
300 void dcopy(const int *, const double*, const int *, double*, const int*);
301 double ddot(const int *, const double *, const int *,
302 const double *, const int *);
303 double dnrm2(const int *, const double *, const int *);
304 double dznrm2(const int *, const complex<double> *, const int *);
305 complex<double> zdotc(const int *, const complex<double>*, const int *,
306 const complex<double>*, const int *);
307 complex<double> zdotu(const int *, const complex<double>*, const int *,
308 const complex<double>*, const int *);
309 int idamax(const int *, const double*, const int*);
310 // BLAS3
311 void dsymm(const char*, const char*, const int*, const int *,
312 const double*, const double*, const int*,
313 const double*, const int*,
314 const double*, double*, const int*);
315 void zsymm(const char*, const char*, const int*, const int *,
316 const complex<double>*, const complex<double>*, const int*,
317 const complex<double>*, const int*,
318 const complex<double>*, complex<double>*, const int*);
319 void zhemm(const char*, const char*, const int*, const int *,
320 const complex<double>*, const complex<double>*, const int*,
321 const complex<double>*, const int*,
322 const complex<double>*, complex<double>*, const int*);
323 void dgemm(const char*, const char*, const int*, const int *, const int*,
324 const double*, const double*, const int*,
325 const double*, const int*,
326 const double*, double*, const int*);
327 void zgemm(const char*, const char*, const int*, const int *, const int*,
328 const complex<double>*, const complex<double>*, const int*,
329 const complex<double>*, const int*,
330 const complex<double>*, complex<double>*, const int*);
331 void zgerc(const int*, const int *, const complex<double>*,
332 const complex<double>*, const int*,
333 const complex<double>*, const int*,
334 const complex<double>*, const int*);
335 void zgeru(const int*, const int *, const complex<double>*,
336 const complex<double>*, const int*,
337 const complex<double>*, const int*,
338 const complex<double>*, const int*);
339 void dger(const int *, const int*, const double *,
340 const double *, const int *, const double *, const int *,
341 double*, const int*);
342 void dsyr(const char*, const int *, const double *,
343 const double *, const int *, double *, const int *);
344 void dsyrk(const char*, const char*, const int *, const int *,
345 const double *, const double *, const int *,
346 const double *, double *, const int *);
347 void zherk(const char* uplo, const char* trans, const int* n, const int* k,
348 const double* alpha, const complex<double>* a,
349 const int* lda,
350 const double* beta, complex<double>* c, const int* ldc);
351 void dtrmm(const char*, const char*, const char*, const char*,
352 const int*, const int *, const double*, const double*,
353 const int*, double*, const int*);
354 void dtrsm(const char*, const char*, const char*, const char*,
355 const int*, const int *, const double*, const double*,
356 const int*, double*, const int*);
357 void ztrsm(const char*, const char*, const char*, const char*,
358 const int*, const int *, const complex<double>*,
359 const complex<double>*, const int*, complex<double>*, const int*);
360 // LAPACK
361 void dtrtrs(const char*, const char*, const char*,
362 const int*, const int*, const double*, const int*,
363 double*, const int*, int*);
364 void dpotrf(const char*, const int*, double*, const int*, int*);
365 void zpotrf(const char*, const int*, complex<double>*, const int*, int*);
366 void dpotri(const char*, const int*, double*, const int*, int*);
367 void dpocon(const char*, const int *, const double *, const int *,
368 const double *, double *, double *, const int *, int *);
369 void dsygst(const int*, const char*, const int*,
370 double*, const int*, const double*, const int*, int*);
371 void dsyev(const char* jobz, const char* uplo, const int* n, double* a,
372 const int *lda, double *w, double*work,
373 int *lwork, int *info);
374 void zheev(const char* jobz, const char* uplo, const int *n,
375 complex<double>* a, const int *lda, double* w,
376 complex<double>* work, int *lwork, double* rwork, int *info);
377 void dtrtri(const char*, const char*, const int*, double*, const int*, int* );
378 void ztrtri(const char*, const char*, const int*, complex<double>*,
379 const int*, int* );
380 void ztrtrs(const char*, const char*, char*, const int*, const int*,
381 complex<double>*, const int*, complex<double>*, int*, int* );
382 void dgetrf(const int* m, const int* n, double* a, const int* lda,
383 int* ipiv, int*info);
384 void zgetrf(const int* m, const int* n, complex<double>* a, const int* lda,
385 int* ipiv, int*info);
386 void dgetri(const int* m, double* val, const int* lda, int* ipiv,
387 double* work, int* lwork, int* info);
388 void zgetri(const int* m, complex<double>* val, const int* lda, int* ipiv,
389 complex<double>* work, int* lwork, int* info);
390 }
391
////////////////////////////////////////////////////////////////////////////////
// numroc0: ScaLAPACK numroc specialized to isrcproc=0, i.e. the process
// holding the first row/col of the matrix is process 0.
//   n      number of rows (or cols) of the distributed matrix
//   nb     block size (must be > 0)
//   iproc  coordinate of the process whose local size is wanted,
//          in the range 0..nprocs-1
//   nprocs number of processes over which the dimension is distributed
// Returns the number of rows (or cols) stored on process iproc in a
// block-cyclic distribution.
int numroc0(int n, int nb, int iproc, int nprocs)
{
  const int whole = n / nb;                  // number of complete blocks
  const int spill = whole % nprocs;          // processes owning one extra block
  const int base  = ( whole / nprocs ) * nb; // share every process receives

  if ( iproc < spill )
    return base + nb;      // one additional complete block
  if ( iproc == spill )
    return base + n % nb;  // the trailing partial block (possibly empty)
  return base;
}
420
421 ////////////////////////////////////////////////////////////////////////////////
mloc(int irow) const422 int DoubleMatrix::mloc(int irow) const
423 {
424 return numroc0(m_,mb_,irow,nprow_);
425 }
426
427 ////////////////////////////////////////////////////////////////////////////////
nloc(int icol) const428 int DoubleMatrix::nloc(int icol) const
429 {
430 return numroc0(n_,nb_,icol,npcol_);
431 }
432
433 ////////////////////////////////////////////////////////////////////////////////
434 // reference constructor create a proxy for a ComplexMatrix rhs
DoubleMatrix(ComplexMatrix & rhs)435 DoubleMatrix::DoubleMatrix(ComplexMatrix& rhs) : ctxt_(rhs.context()),
436 reference_(true)
437 {
438 int new_m = 2 * rhs.m();
439 int new_mb = 2 * rhs.mb();
440 init_size(new_m,rhs.n(),new_mb,rhs.nb());
441 val = (double*) rhs.valptr();
442 }
443
444 ////////////////////////////////////////////////////////////////////////////////
445 // reference constructor create a proxy for a const ComplexMatrix rhs
DoubleMatrix(const ComplexMatrix & rhs)446 DoubleMatrix::DoubleMatrix(const ComplexMatrix& rhs) : ctxt_(rhs.context()),
447 reference_(true)
448 {
449 int new_m = 2 * rhs.m();
450 int new_mb = 2 * rhs.mb();
451 init_size(new_m,rhs.n(),new_mb,rhs.nb());
452 val = (double*) rhs.cvalptr();
453 }
454
455 ////////////////////////////////////////////////////////////////////////////////
456 // reference constructor create a proxy for a DoubleMatrix rhs
ComplexMatrix(DoubleMatrix & rhs)457 ComplexMatrix::ComplexMatrix(DoubleMatrix& rhs) : ctxt_(rhs.context()),
458 reference_(true)
459 {
460 assert(rhs.m()%2 == 0);
461 int new_m = rhs.m() / 2;
462 assert(rhs.mb()%2 == 0);
463 int new_mb = rhs.mb() / 2;
464 init_size(new_m,rhs.n(),new_mb,rhs.nb());
465 val = (complex<double>*) rhs.valptr();
466 }
467
468 ////////////////////////////////////////////////////////////////////////////////
469 // reference constructor create a proxy for a const DoubleMatrix rhs
ComplexMatrix(const DoubleMatrix & rhs)470 ComplexMatrix::ComplexMatrix(const DoubleMatrix& rhs) : ctxt_(rhs.context()),
471 reference_(true)
472 {
473 assert(rhs.m()%2 == 0);
474 int new_m = rhs.m() / 2;
475 assert(rhs.mb()%2 == 0);
476 int new_mb = rhs.mb() / 2;
477 init_size(new_m,rhs.n(),new_mb,rhs.nb());
478 val = (complex<double>*) rhs.cvalptr();
479 }
480
481 ////////////////////////////////////////////////////////////////////////////////
init_size(int m,int n,int mb,int nb)482 void DoubleMatrix::init_size(int m, int n, int mb, int nb)
483 {
484 assert(m>=0);
485 assert(n>=0);
486 assert(mb>=0);
487 assert(nb>=0);
488 m_ = m;
489 n_ = n;
490 #ifdef SCALAPACK
491 mb_ = mb;
492 nb_ = nb;
493 #else
494 mb_ = m;
495 nb_ = n;
496 #endif
497 if ( mb_ == 0 ) mb_ = 1;
498 if ( nb_ == 0 ) nb_ = 1;
499 ictxt_ = ctxt_.ictxt();
500 nprow_ = ctxt_.nprow();
501 npcol_ = ctxt_.npcol();
502 myrow_ = ctxt_.myrow();
503 mycol_ = ctxt_.mycol();
504 active_ = myrow_ >= 0;
505 int isrcproc=0;
506 mloc_ = 0;
507 nloc_ = 0;
508 if ( m_ != 0 )
509 mloc_ = numroc(&m_,&mb_,&myrow_,&isrcproc,&nprow_);
510 if ( n_ != 0 )
511 nloc_ = numroc(&n_,&nb_,&mycol_,&isrcproc,&npcol_);
512 size_ = mloc_ * nloc_;
513
514 // set leading dimension of val array to mloc_;
515 lld_ = mloc_;
516 if ( lld_ == 0 ) lld_ = 1;
517
518 // total and local number of blocks
519 mblocks_ = 0;
520 nblocks_ = 0;
521 m_incomplete_ = false;
522 n_incomplete_ = false;
523 if ( active_ && mb_ > 0 && nb_ > 0 )
524 {
525 mblocks_ = ( mloc_ + mb_ - 1 ) / mb_;
526 nblocks_ = ( nloc_ + nb_ - 1 ) / nb_;
527 m_incomplete_ = mloc_ % mb_ != 0;
528 n_incomplete_ = nloc_ % nb_ != 0;
529 }
530
531 if ( active_ )
532 {
533 desc_[0] = 1;
534 }
535 else
536 {
537 desc_[0] = -1;
538 }
539 desc_[1] = ictxt_;
540 desc_[2] = m_;
541 desc_[3] = n_;
542 desc_[4] = mb_;
543 desc_[5] = nb_;
544 desc_[6] = 0;
545 desc_[7] = 0;
546 desc_[8] = lld_;
547 }
548
549 ////////////////////////////////////////////////////////////////////////////////
init_size(int m,int n,int mb,int nb)550 void ComplexMatrix::init_size(int m, int n, int mb, int nb)
551 {
552 assert(m>=0);
553 assert(n>=0);
554 assert(mb>=0);
555 assert(nb>=0);
556 m_ = m;
557 n_ = n;
558 #ifdef SCALAPACK
559 mb_ = mb;
560 nb_ = nb;
561 #else
562 mb_ = m;
563 nb_ = n;
564 #endif
565 if ( mb_ == 0 ) mb_ = 1;
566 if ( nb_ == 0 ) nb_ = 1;
567 ictxt_ = ctxt_.ictxt();
568 nprow_ = ctxt_.nprow();
569 npcol_ = ctxt_.npcol();
570 myrow_ = ctxt_.myrow();
571 mycol_ = ctxt_.mycol();
572 active_ = myrow_ >= 0;
573 int isrcproc=0;
574 mloc_ = 0;
575 nloc_ = 0;
576
577 if ( m_ != 0 )
578 mloc_ = numroc(&m_,&mb_,&myrow_,&isrcproc,&nprow_);
579 if ( n_ != 0 )
580 nloc_ = numroc(&n_,&nb_,&mycol_,&isrcproc,&npcol_);
581 size_ = mloc_ * nloc_;
582
583 // set leading dimension of val array to mloc_;
584 lld_ = mloc_;
585 if ( lld_ == 0 ) lld_ = 1;
586
587 // total and local number of blocks
588 mblocks_ = 0;
589 nblocks_ = 0;
590 m_incomplete_ = false;
591 n_incomplete_ = false;
592 if ( active_ && mb_ > 0 && nb_ > 0 )
593 {
594 mblocks_ = ( mloc_ + mb_ - 1 ) / mb_;
595 nblocks_ = ( nloc_ + nb_ - 1 ) / nb_;
596 m_incomplete_ = mloc_ % mb_ != 0;
597 n_incomplete_ = nloc_ % nb_ != 0;
598 }
599
600 if ( active_ )
601 {
602 desc_[0] = 1;
603 }
604 else
605 {
606 desc_[0] = -1;
607 }
608 desc_[1] = ictxt_;
609 desc_[2] = m_;
610 desc_[3] = n_;
611 desc_[4] = mb_;
612 desc_[5] = nb_;
613 desc_[6] = 0;
614 desc_[7] = 0;
615 desc_[8] = lld_;
616 }
617
618 ////////////////////////////////////////////////////////////////////////////////
clear(void)619 void DoubleMatrix::clear(void)
620 {
621 assert(val!=0||size_==0);
622 memset(val,0,size_*sizeof(double));
623 }
624
625 ////////////////////////////////////////////////////////////////////////////////
clear(void)626 void ComplexMatrix::clear(void)
627 {
628 assert(val!=0||size_==0);
629 memset(val,0,size_*sizeof(complex<double>));
630 }
631
632 ////////////////////////////////////////////////////////////////////////////////
633 // real identity: initialize matrix to identity
634 ////////////////////////////////////////////////////////////////////////////////
identity(void)635 void DoubleMatrix::identity(void)
636 {
637 clear();
638 set('d',1.0);
639 }
640
641 ////////////////////////////////////////////////////////////////////////////////
642 // complex identity: initialize matrix to identity
643 ////////////////////////////////////////////////////////////////////////////////
identity(void)644 void ComplexMatrix::identity(void)
645 {
646 clear();
647 set('d',complex<double>(1.0,0.0));
648 }
649
650 ////////////////////////////////////////////////////////////////////////////////
651 // set value of diagonal or off-diagonal elements to a constant
652 // uplo=='u': set strictly upper part to x
653 // uplo=='l': set strictly lower part to x
654 // uplo=='d': set diagonal to x
655 ////////////////////////////////////////////////////////////////////////////////
set(char uplo,double xx)656 void DoubleMatrix::set(char uplo, double xx)
657 {
658 if ( active_ )
659 {
660 if ( uplo=='l' || uplo=='L' )
661 {
662 // initialize strictly lower part
663 for (int li=0; li < mblocks_;li++)
664 {
665 for (int lj=0; lj < nblocks_;lj++)
666 {
667 for (int ii=0; ii < mbs(li); ii++)
668 {
669 for (int jj=0; jj < nbs(lj);jj++)
670 {
671 if ( i(li,ii) > j(lj,jj) )
672 val[ (ii+li*mb_)+(jj+lj*nb_)*mloc_ ] = xx;
673 }
674 }
675 }
676 }
677 }
678 else if ( uplo=='u' || uplo=='U' )
679 {
680 // initialize strictly upper part
681 for ( int li=0; li < mblocks_; li++ )
682 {
683 for ( int lj=0; lj < nblocks_; lj++ )
684 {
685 for ( int ii=0; ii < mbs(li); ii++ )
686 {
687 for ( int jj=0; jj < nbs(lj); jj++ )
688 {
689 if ( i(li,ii) < j(lj,jj) )
690 val[ (ii+li*mb_)+(jj+lj*nb_)*mloc_ ] = xx;
691 }
692 }
693 }
694 }
695 }
696 else if ( uplo=='d' || uplo=='D' )
697 {
698 // initialize diagonal elements
699 if ( active() )
700 {
701 // loop through all local blocks (ll,mm)
702 for ( int ll = 0; ll < mblocks(); ll++)
703 {
704 for ( int mm = 0; mm < nblocks(); mm++)
705 {
706 // check if block (ll,mm) has diagonal elements
707 int imin = i(ll,0);
708 int imax = imin + mbs(ll)-1;
709 int jmin = j(mm,0);
710 int jmax = jmin + nbs(mm)-1;
711 // cout << " process (" << myrow_ << "," << mycol_ << ")"
712 // << " block (" << ll << "," << mm << ")"
713 // << " imin/imax=" << imin << "/" << imax
714 // << " jmin/jmax=" << jmin << "/" << jmax << endl;
715
716 if ((imin <= jmax) && (imax >= jmin))
717 {
718 // block (ll,mm) holds diagonal elements
719 int idiagmin = max(imin,jmin);
720 int idiagmax = min(imax,jmax);
721
722 // cout << " process (" << myrow_ << "," << mycol_ << ")"
723 // << " holds diagonal elements " << idiagmin << " to " <<
724 // idiagmax << " in block (" << ll << "," << mm << ")" << endl;
725
726 for ( int ii = idiagmin; ii <= idiagmax; ii++ )
727 {
728 // access element (ii,ii)
729 int jj = ii;
730 int iii = ll * mb_ + x(ii);
731 int jjj = mm * nb_ + y(jj);
732 val[iii+mloc_*jjj] = xx;
733 }
734 }
735 }
736 }
737 }
738 }
739 else
740 {
741 cout << " DoubleMatrix::set: invalid argument" << endl;
742 #ifdef USE_MPI
743 MPI_Abort(MPI_COMM_WORLD,2);
744 #else
745 exit(2);
746 #endif
747 }
748 }
749 }
750
751 ////////////////////////////////////////////////////////////////////////////////
set(char uplo,complex<double> xx)752 void ComplexMatrix::set(char uplo, complex<double> xx)
753 {
754 if ( active_ )
755 {
756 if ( uplo=='l' || uplo=='L' )
757 {
758 // initialize strictly lower part
759 for (int li=0; li < mblocks_;li++)
760 {
761 for (int lj=0; lj < nblocks_;lj++)
762 {
763 for (int ii=0; ii < mbs(li); ii++)
764 {
765 for (int jj=0; jj < nbs(lj);jj++)
766 {
767 if ( i(li,ii) > j(lj,jj) )
768 val[ (ii+li*mb_)+(jj+lj*nb_)*mloc_ ] = xx;
769 }
770 }
771 }
772 }
773 }
774 else if ( uplo=='u' || uplo=='U' )
775 {
776 // initialize strictly upper part
777 for ( int li=0; li < mblocks_; li++ )
778 {
779 for ( int lj=0; lj < nblocks_; lj++ )
780 {
781 for ( int ii=0; ii < mbs(li); ii++ )
782 {
783 for ( int jj=0; jj < nbs(lj); jj++ )
784 {
785 if ( i(li,ii) < j(lj,jj) )
786 val[ (ii+li*mb_)+(jj+lj*nb_)*mloc_ ] = xx;
787 }
788 }
789 }
790 }
791 }
792 else if ( uplo=='d' || uplo=='D' )
793 {
794 // initialize diagonal elements
795 if ( active() )
796 {
797 // loop through all local blocks (ll,mm)
798 for ( int ll = 0; ll < mblocks(); ll++)
799 {
800 for ( int mm = 0; mm < nblocks(); mm++)
801 {
802 // check if block (ll,mm) has diagonal elements
803 int imin = i(ll,0);
804 int imax = imin + mbs(ll)-1;
805 int jmin = j(mm,0);
806 int jmax = jmin + nbs(mm)-1;
807 // cout << " process (" << myrow_ << "," << mycol_ << ")"
808 // << " block (" << ll << "," << mm << ")"
809 // << " imin/imax=" << imin << "/" << imax
810 // << " jmin/jmax=" << jmin << "/" << jmax << endl;
811
812 if ((imin <= jmax) && (imax >= jmin))
813 {
814 // block (ll,mm) holds diagonal elements
815 int idiagmin = max(imin,jmin);
816 int idiagmax = min(imax,jmax);
817
818 // cout << " process (" << myrow_ << "," << mycol_ << ")"
819 // << " holds diagonal elements " << idiagmin << " to " <<
820 // idiagmax << " in block (" << ll << "," << mm << ")" << endl;
821
822 for ( int ii = idiagmin; ii <= idiagmax; ii++ )
823 {
824 // access element (ii,ii)
825 int jj = ii;
826 int iii = ll * mb_ + x(ii);
827 int jjj = mm * nb_ + y(jj);
828 val[iii+mloc_*jjj] = xx;
829 }
830 }
831 }
832 }
833 }
834 }
835 else
836 {
837 cout << " DoubleMatrix::set: invalid argument" << endl;
838 #ifdef USE_MPI
839 MPI_Abort(MPI_COMM_WORLD,2);
840 #else
841 exit(2);
842 #endif
843 }
844 }
845 }
846
847 ////////////////////////////////////////////////////////////////////////////////
848 // initialize *this using a replicated matrix a
init(const double * const a,int lda)849 void DoubleMatrix::init(const double* const a, int lda)
850 {
851 if ( active_ )
852 {
853 for ( int li=0; li < mblocks_; li++ )
854 {
855 for ( int lj=0; lj < nblocks_; lj++ )
856 {
857 for ( int ii=0; ii < mbs(li); ii++ )
858 {
859 for ( int jj=0; jj < nbs(lj); jj++ )
860 {
861 val[ (ii+li*mb_)+(jj+lj*nb_)*mloc_ ]
862 = a[ i(li,ii) + j(lj,jj)*lda ];
863 }
864 }
865 }
866 }
867 }
868 }
869
870 ////////////////////////////////////////////////////////////////////////////////
dot(const DoubleMatrix & x) const871 double DoubleMatrix::dot(const DoubleMatrix &x) const
872 {
873 assert( ictxt_ == x.ictxt() );
874 double sum=0.;
875 double tsum=0.;
876 if ( active_ )
877 {
878 assert( m_ == x.m() );
879 assert( n_ == x.n() );
880 assert( mb_ == x.mb() );
881 assert( nb_ == x.nb() );
882 assert( mloc_ == x.mloc() );
883 assert( nloc_ == x.nloc() );
884 assert(size_==x.size());
885 int ione=1;
886 tsum=ddot(&size_, val, &ione, x.val, &ione);
887 }
888 #ifdef SCALAPACK
889 if ( active_ )
890 MPI_Allreduce(&tsum, &sum, 1, MPI_DOUBLE, MPI_SUM, ctxt_.comm() );
891 #else
892 sum=tsum;
893 #endif
894 return sum;
895 }
896
897 ////////////////////////////////////////////////////////////////////////////////
dot(const ComplexMatrix & x) const898 complex<double> ComplexMatrix::dot(const ComplexMatrix &x) const
899 {
900 assert( ictxt_ == x.ictxt() );
901 complex<double> sum=0.0;
902 complex<double> tsum=0.0;
903 if ( active_ )
904 {
905 assert( m_ == x.m() );
906 assert( n_ == x.n() );
907 assert( mb_ == x.mb() );
908 assert( nb_ == x.nb() );
909 assert( mloc_ == x.mloc() );
910 assert( nloc_ == x.nloc() );
911 assert(size_==x.size());
912 //int ione=1;
913 //tsum=zdotc(&size_, val, &ione, x.val, &ione);
914 for ( int i = 0; i < size_; i++ )
915 tsum += conj(val[i]) * x.val[i];
916 }
917 #ifdef SCALAPACK
918 if ( active_ )
919 MPI_Allreduce((double*)&tsum, (double*)&sum, 2,
920 MPI_DOUBLE, MPI_SUM, ctxt_.comm() );
921 #else
922 sum=tsum;
923 #endif
924 return sum;
925 }
926
927 ////////////////////////////////////////////////////////////////////////////////
dotu(const ComplexMatrix & x) const928 complex<double> ComplexMatrix::dotu(const ComplexMatrix &x) const
929 {
930 assert( ictxt_ == x.ictxt() );
931 complex<double> sum=0.0;
932 complex<double> tsum=0.0;
933 if ( active_ )
934 {
935 assert( m_ == x.m() );
936 assert( n_ == x.n() );
937 assert( mb_ == x.mb() );
938 assert( nb_ == x.nb() );
939 assert( mloc_ == x.mloc() );
940 assert( nloc_ == x.nloc() );
941 assert(size_==x.size());
942 //int ione=1;
943 //tsum=zdotu(&size_, val, &ione, x.val, &ione);
944 for ( int i = 0; i < size_; i++ )
945 tsum += val[i] * x.val[i];
946 }
947 #ifdef SCALAPACK
948 if ( active_ )
949 MPI_Allreduce((double*)&tsum, (double*)&sum, 2,
950 MPI_DOUBLE, MPI_SUM, ctxt_.comm() );
951 #else
952 sum=tsum;
953 #endif
954 return sum;
955 }
956
957 ////////////////////////////////////////////////////////////////////////////////
amax(void) const958 double DoubleMatrix::amax(void) const
959 {
960 double am = 0.0, tam = 0.0;
961 if ( active_ )
962 {
963 int ione=1;
964 tam = val[idamax(&size_,val,&ione) - 1];
965 }
966 #ifdef SCALAPACK
967 if ( active_ )
968 MPI_Allreduce(&tam, &am, 1, MPI_DOUBLE, MPI_MAX, ctxt_.comm() );
969 #else
970 am=tam;
971 #endif
972 return am;
973 }
974
975 ////////////////////////////////////////////////////////////////////////////////
976 // axpy: *this = *this + alpha * x
axpy(double alpha,const DoubleMatrix & x)977 void DoubleMatrix::axpy(double alpha, const DoubleMatrix &x)
978 {
979 assert( ictxt_ == x.ictxt() );
980 int ione=1;
981 assert(m_==x.m());
982 assert(n_==x.n());
983 assert(mloc_==x.mloc());
984 assert(nloc_==x.nloc());
985 if( active_ )
986 daxpy(&size_, &alpha, x.val, &ione, val, &ione);
987 }
988
989 ////////////////////////////////////////////////////////////////////////////////
axpy(complex<double> alpha,const ComplexMatrix & x)990 void ComplexMatrix::axpy(complex<double> alpha, const ComplexMatrix &x)
991 {
992 assert( ictxt_ == x.ictxt() );
993 int ione=1;
994 assert(m_==x.m());
995 assert(n_==x.n());
996 assert(mloc_==x.mloc());
997 assert(nloc_==x.nloc());
998 if( active_ )
999 zaxpy(&size_, &alpha, x.val, &ione, val, &ione);
1000 }
1001
1002 ////////////////////////////////////////////////////////////////////////////////
axpy(double alpha,const ComplexMatrix & x)1003 void ComplexMatrix::axpy(double alpha, const ComplexMatrix &x)
1004 {
1005 assert( ictxt_ == x.ictxt() );
1006 int ione=1;
1007 assert(m_==x.m());
1008 assert(n_==x.n());
1009 assert(mloc_==x.mloc());
1010 assert(nloc_==x.nloc());
1011 int len = 2 * size_;
1012 if( active_ )
1013 daxpy(&len, &alpha, (double*) x.val, &ione, (double*) val, &ione);
1014 }
1015
1016 ////////////////////////////////////////////////////////////////////////////////
1017 // real getsub: *this = sub(A)
// copy submatrix A(ia:ia+m-1, ja:ja+n-1) (0-based, inclusive) into the leading m x n corner of *this;
1019 // *this and A may live in different contexts
getsub(const DoubleMatrix & a,int m,int n,int ia,int ja)1020 void DoubleMatrix::getsub(const DoubleMatrix &a,
1021 int m, int n, int ia, int ja)
1022 {
1023 #if SCALAPACK
1024 int iap=ia+1;
1025 int jap=ja+1;
1026 assert(n<=n_);
1027 assert(n<=a.n());
1028 assert(m<=m_);
1029 assert(m<=a.m());
1030 int ione = 1;
1031 int gictxt;
1032 Cblacs_get( 0, 0, &gictxt );
1033 pdgemr2d(&m,&n,a.val,&iap,&jap,a.desc_,val,&ione,&ione,desc_,&gictxt);
1034 #else
1035 for ( int j = 0; j < n; j++ )
1036 for ( int i = 0; i < m; i++ )
1037 val[i+j*m_] = a.val[(i+ia) + (j+ja)*a.m()];
1038 #endif
1039 }
1040
1041 ////////////////////////////////////////////////////////////////////////////////
1042 // real getsub: *this = sub(A)
// copy submatrix A(isrc:isrc+m-1, jsrc:jsrc+n-1) into *this(idest:idest+m-1, jdest:jdest+n-1) (0-based, inclusive)
1044 // *this and A may live in different contexts
getsub(const DoubleMatrix & a,int m,int n,int isrc,int jsrc,int idest,int jdest)1045 void DoubleMatrix::getsub(const DoubleMatrix &a,
1046 int m, int n, int isrc, int jsrc, int idest, int jdest)
1047 {
1048 #if SCALAPACK
1049 int iap=isrc+1;
1050 int jap=jsrc+1;
1051 int idp=idest+1;
1052 int jdp=jdest+1;
1053 assert(n<=n_);
1054 assert(n<=a.n());
1055 assert(m<=m_);
1056 assert(m<=a.m());
1057 int gictxt;
1058 Cblacs_get( 0, 0, &gictxt );
1059 pdgemr2d(&m,&n,a.val,&iap,&jap,a.desc_,val,&idp,&jdp,desc_,&gictxt);
1060 #else
1061 for ( int j = 0; j < n; j++ )
1062 for ( int i = 0; i < m; i++ )
1063 val[(idest+i)+(jdest+j)*m_] = a.val[(i+isrc) + (j+jsrc)*a.m()];
1064 #endif
1065 }
1066
1067 ////////////////////////////////////////////////////////////////////////////////
1068 // complex getsub: *this = sub(A)
// copy submatrix A(ia:ia+m-1, ja:ja+n-1) (0-based, inclusive) into the leading m x n corner of *this;
1070 // *this and A may live in different contexts
getsub(const ComplexMatrix & a,int m,int n,int ia,int ja)1071 void ComplexMatrix::getsub(const ComplexMatrix &a,
1072 int m, int n, int ia, int ja)
1073 {
1074 #if SCALAPACK
1075 int iap=ia+1;
1076 int jap=ja+1;
1077 assert(n<=n_);
1078 assert(n<=a.n());
1079 assert(m<=m_);
1080 assert(m<=a.m());
1081 int ione = 1;
1082 int gictxt;
1083 Cblacs_get( 0, 0, &gictxt );
1084 pzgemr2d(&m,&n,a.val,&iap,&jap,a.desc_,val,&ione,&ione,desc_,&gictxt);
1085 #else
1086 for ( int j = 0; j < n; j++ )
1087 for ( int i = 0; i < m; i++ )
1088 val[i+j*m_] = a.val[(i+ia) + (j+ja)*a.m()];
1089 #endif
1090 }
1091
1092 ////////////////////////////////////////////////////////////////////////////////
1093 // complex getsub: *this = sub(A)
// copy submatrix A(isrc:isrc+m-1, jsrc:jsrc+n-1) into *this(idest:idest+m-1, jdest:jdest+n-1) (0-based, inclusive)
1095 // *this and A may live in different contexts
getsub(const ComplexMatrix & a,int m,int n,int isrc,int jsrc,int idest,int jdest)1096 void ComplexMatrix::getsub(const ComplexMatrix &a,
1097 int m, int n, int isrc, int jsrc, int idest, int jdest)
1098 {
1099 #if SCALAPACK
1100 int iap=isrc+1;
1101 int jap=jsrc+1;
1102 int idp=idest+1;
1103 int jdp=jdest+1;
1104 assert(n<=n_);
1105 assert(n<=a.n());
1106 assert(m<=m_);
1107 assert(m<=a.m());
1108 int gictxt;
1109 Cblacs_get( 0, 0, &gictxt );
1110 pzgemr2d(&m,&n,a.val,&iap,&jap,a.desc_,val,&idp,&jdp,desc_,&gictxt);
1111 #else
1112 for ( int j = 0; j < n; j++ )
1113 for ( int i = 0; i < m; i++ )
1114 val[(idest+i)+(jdest+j)*m_] = a.val[(i+isrc) + (j+jsrc)*a.m()];
1115 #endif
1116 }
1117
1118 ////////////////////////////////////////////////////////////////////////////////
1119 // real matrix transpose
1120 // this = alpha * transpose(A) + beta * this
1121 ////////////////////////////////////////////////////////////////////////////////
transpose(double alpha,const DoubleMatrix & a,double beta)1122 void DoubleMatrix::transpose(double alpha, const DoubleMatrix& a, double beta)
1123 {
1124 assert(this != &a);
1125 assert( ictxt_ == a.ictxt() );
1126
1127 if ( active() )
1128 {
1129 assert(a.m() == n_);
1130 assert(a.n() == m_);
1131
1132 #ifdef SCALAPACK
1133 int ione = 1;
1134 pdtran(&m_, &n_, &alpha,
1135 a.val, &ione, &ione, a.desc_,
1136 &beta, val, &ione, &ione, desc_);
1137 #else
1138 scal(beta);
1139 for ( int i=0; i<m_; i++ )
1140 for ( int j=0; j<i; j++ )
1141 {
1142 val[i*m_+j] += alpha * a.val[j*m_+i];
1143 val[j*m_+i] += alpha * a.val[i*m_+j];
1144 }
1145 for ( int i=0; i<m_; i++ )
1146 val[i*m_+i] += alpha * a.val[i*m_+i];
1147 #endif
1148 }
1149 }
1150
1151 ////////////////////////////////////////////////////////////////////////////////
1152 // real matrix transpose
1153 // *this = transpose(a)
1154 ////////////////////////////////////////////////////////////////////////////////
transpose(const DoubleMatrix & a)1155 void DoubleMatrix::transpose(const DoubleMatrix& a)
1156 {
1157 assert(this != &a);
1158 transpose(1.0,a,0.0);
1159 }
1160
1161 ////////////////////////////////////////////////////////////////////////////////
1162 // complex hermitian transpose
1163 // this = alpha * A^H + beta * this
1164 ////////////////////////////////////////////////////////////////////////////////
transpose(complex<double> alpha,const ComplexMatrix & a,complex<double> beta)1165 void ComplexMatrix::transpose(complex<double> alpha, const ComplexMatrix& a,
1166 complex<double> beta)
1167 {
1168 assert(this != &a);
1169 assert( ictxt_ == a.ictxt() );
1170
1171 if ( active() )
1172 {
1173 assert(a.m() == n_);
1174 assert(a.n() == m_);
1175
1176 #ifdef SCALAPACK
1177 int ione = 1;
1178 pztranc(&m_, &n_, &alpha,
1179 a.val, &ione, &ione, a.desc_,
1180 &beta, val, &ione, &ione, desc_);
1181 #else
1182 scal(beta);
1183 for ( int i=0; i<m_; i++ )
1184 for ( int j=0; j<i; j++ )
1185 {
1186 val[i*m_+j] += alpha * conj(a.val[j*m_+i]);
1187 val[j*m_+i] += alpha * conj(a.val[i*m_+j]);
1188 }
1189 for ( int i=0; i<m_; i++ )
1190 val[i*m_+i] += alpha * a.val[i*m_+i];
1191 #endif
1192 }
1193 }
1194
1195 ////////////////////////////////////////////////////////////////////////////////
// complex hermitian transpose
// *this = a^H (conjugate transpose of a)
1198 ////////////////////////////////////////////////////////////////////////////////
transpose(const ComplexMatrix & a)1199 void ComplexMatrix::transpose(const ComplexMatrix& a)
1200 {
1201 assert(this != &a);
1202 transpose(complex<double>(1.0,0.0),a,complex<double>(0.0,0.0));
1203 }
1204
1205 ////////////////////////////////////////////////////////////////////////////////
symmetrize(char uplo)1206 void DoubleMatrix::symmetrize(char uplo)
1207 {
1208 // symmetrize
1209 // if uplo == 'l' : copy strictly lower triangle to strictly upper triangle
1210 // if uplo == 'u' : copy strictly upper triangle to strictly lower triangle
1211 // if uplo == 'n' : A = 0.5 * ( A^T + A )
1212
1213 if ( uplo == 'n' )
1214 {
1215 DoubleMatrix tmp(*this);
1216 transpose(0.5,tmp,0.5);
1217 }
1218 else if ( uplo == 'l' )
1219 {
1220 set('u',0.0);
1221 DoubleMatrix tmp(*this);
1222 tmp.set('d',0.0);
1223 transpose(1.0,tmp,1.0);
1224 }
1225 else if ( uplo == 'u' )
1226 {
1227 set('l',0.0);
1228 DoubleMatrix tmp(*this);
1229 tmp.set('d',0.0);
1230 transpose(1.0,tmp,1.0);
1231 }
1232 else
1233 {
1234 cout << " DoubleMatrix::symmetrize: invalid argument" << endl;
1235 #ifdef USE_MPI
1236 MPI_Abort(MPI_COMM_WORLD, 2);
1237 #else
1238 exit(2);
1239 #endif
1240 }
1241 }
1242
1243 ////////////////////////////////////////////////////////////////////////////////
symmetrize(char uplo)1244 void ComplexMatrix::symmetrize(char uplo)
1245 {
1246 // symmetrize
1247 // uplo == 'l' : copy conjugate of strictly lower triangle to strictly upper
1248 // uplo == 'u' : copy conjugate of strictly upper triangle to strictly lower
1249 // uplo == 'n' : A = 0.5 * ( A^H + A )
1250
1251 if ( uplo == 'n' )
1252 {
1253 ComplexMatrix tmp(*this);
1254 transpose(complex<double>(0.5,0.0),tmp,complex<double>(0.5,0.0));
1255 }
1256 else if ( uplo == 'l' )
1257 {
1258 set('u',complex<double>(0.0,0.0));
1259 ComplexMatrix tmp(*this);
1260 tmp.set('d',complex<double>(0.0,0.0));
1261 transpose(complex<double>(1.0,0.0),tmp,complex<double>(1.0,0.0));
1262 }
1263 else if ( uplo == 'u' )
1264 {
1265 set('l',complex<double>(0.0,0.0));
1266 ComplexMatrix tmp(*this);
1267 tmp.set('d',complex<double>(0.0,0.0));
1268 transpose(complex<double>(1.0,0.0),tmp,complex<double>(1.0,0.0));
1269 }
1270 else
1271 {
1272 cout << " ComplexMatrix::symmetrize: invalid argument" << endl;
1273 #ifdef USE_MPI
1274 MPI_Abort(MPI_COMM_WORLD, 2);
1275 #else
1276 exit(2);
1277 #endif
1278 }
1279 }
1280
1281 ////////////////////////////////////////////////////////////////////////////////
nrm2(void) const1282 double DoubleMatrix::nrm2(void) const
1283 {
1284 double sum=0.;
1285 double tsum=0.;
1286 if ( active_ )
1287 {
1288 int ione=1;
1289 // dnrm2 returns sqrt(sum_i x[i]*x[i])
1290 tsum = dnrm2(&size_,val,&ione);
1291 tsum = tsum*tsum;
1292 }
1293 #ifdef SCALAPACK
1294 if ( active_ )
1295 MPI_Allreduce(&tsum, &sum, 1, MPI_DOUBLE, MPI_SUM, ctxt_.comm() );
1296 #else
1297 sum=tsum;
1298 #endif
1299 return sqrt(sum);
1300 }
1301
1302 ////////////////////////////////////////////////////////////////////////////////
nrm2(void) const1303 double ComplexMatrix::nrm2(void) const
1304 {
1305 double sum=0.;
1306 double tsum=0.;
1307 if ( active_ )
1308 {
1309 int ione=1;
1310 // dznrm2 returns sqrt(sum_i conjg(x[i])*x[i])
1311 tsum = dznrm2(&size_,val,&ione);
1312 tsum = tsum*tsum;
1313 }
1314 #ifdef SCALAPACK
1315 if ( active_ )
1316 MPI_Allreduce(&tsum, &sum, 1, MPI_DOUBLE, MPI_SUM, ctxt_.comm() );
1317 #else
1318 sum=tsum;
1319 #endif
1320 return sqrt(sum);
1321 }
1322
1323 ////////////////////////////////////////////////////////////////////////////////
1324 // rank-1 update using row kx of x and (row ky of y)^T
1325 // *this = *this + alpha * x(kx) * y(ky)^T
ger(double alpha,const DoubleMatrix & x,int kx,const DoubleMatrix & y,int ky)1326 void DoubleMatrix::ger(double alpha,
1327 const DoubleMatrix& x, int kx, const DoubleMatrix& y, int ky)
1328 {
1329 assert(x.n()==m_);
1330 assert(y.n()==n_);
1331 #if SCALAPACK
1332 int ione=1;
1333
1334 int ix = kx+1;
1335 int jx = 1;
1336 int incx = x.m();
1337
1338 int iy = ky+1;
1339 int jy = 1;
1340 int incy = y.m();
1341 pdger(&m_,&n_,&alpha,x.val,&ix,&jx,x.desc_,&incx,
1342 y.val,&iy,&jy,y.desc_,&incy,
1343 val,&ione,&ione,desc_);
1344 #else
1345 int incx = x.m();
1346 int incy = y.m();
1347 dger(&m_,&n_,&alpha,&x.val[kx*x.m()],&incx,
1348 &y.val[ky*y.m()],&incy,val,&m_);
1349 #endif
1350 }
1351
1352 ////////////////////////////////////////////////////////////////////////////////
1353 // rank-1 update using row kx of x and conj(row ky of y)^T
1354 // *this = *this + alpha * x(kx) * conj(y(ky))^T
gerc(complex<double> alpha,const ComplexMatrix & x,int kx,const ComplexMatrix & y,int ky)1355 void ComplexMatrix::gerc(complex<double> alpha,
1356 const ComplexMatrix& x, int kx, const ComplexMatrix& y, int ky)
1357 {
1358 assert(x.n()==m_);
1359 assert(y.n()==n_);
1360 #if SCALAPACK
1361 int ione=1;
1362
1363 int ix = kx+1;
1364 int jx = 1;
1365 int incx = x.m();
1366
1367 int iy = ky+1;
1368 int jy = 1;
1369 int incy = y.m();
1370 pzgerc(&m_,&n_,&alpha,x.val,&ix,&jx,x.desc_,&incx,
1371 y.val,&iy,&jy,y.desc_,&incy,
1372 val,&ione,&ione,desc_);
1373 #else
1374 int incx = x.m();
1375 int incy = y.m();
1376 zgerc(&m_,&n_,&alpha,&x.val[kx*x.m()],&incx,
1377 &y.val[ky*y.m()],&incy,val,&m_);
1378 #endif
1379 }
1380
1381 ////////////////////////////////////////////////////////////////////////////////
// rank-1 update using row kx of x and (row ky of y)^T, without conjugation
1383 // *this = *this + alpha * x(kx) * y(ky)^T
geru(complex<double> alpha,const ComplexMatrix & x,int kx,const ComplexMatrix & y,int ky)1384 void ComplexMatrix::geru(complex<double> alpha,
1385 const ComplexMatrix& x, int kx, const ComplexMatrix& y, int ky)
1386 {
1387 assert(x.n()==m_);
1388 assert(y.n()==n_);
1389 #if SCALAPACK
1390 int ione=1;
1391
1392 int ix = kx+1;
1393 int jx = 1;
1394 int incx = x.m();
1395
1396 int iy = ky+1;
1397 int jy = 1;
1398 int incy = y.m();
1399 pzgeru(&m_,&n_,&alpha,x.val,&ix,&jx,x.desc_,&incx,
1400 y.val,&iy,&jy,y.desc_,&incy,
1401 val,&ione,&ione,desc_);
1402 #else
1403 int incx = x.m();
1404 int incy = y.m();
1405 zgeru(&m_,&n_,&alpha,&x.val[kx*x.m()],&incx,
1406 &y.val[ky*y.m()],&incy,val,&m_);
1407 #endif
1408 }
1409
1410 ////////////////////////////////////////////////////////////////////////////////
1411 // symmetric rank-1 update using a row or a column of a Matrix x
syr(char uplo,double alpha,const DoubleMatrix & x,int k,char rowcol)1412 void DoubleMatrix::syr(char uplo, double alpha,
1413 const DoubleMatrix& x, int k, char rowcol)
1414 {
1415 assert(n_==m_);
1416 #if SCALAPACK
1417 int ix,jx,incx,ione=1;
1418 if ( rowcol == 'c' )
1419 {
1420 // use column k of matrix x
1421 assert(x.m()==n_);
1422 ix = 1;
1423 jx = k+1;
1424 incx = 1;
1425 }
1426 else if ( rowcol == 'r' )
1427 {
1428 // use row k of matrix x
1429 assert(x.n()==n_);
1430 ix = k+1;
1431 jx = 1;
1432 incx = x.m();
1433 }
1434 else
1435 {
1436 cout << " DoubleMatrix::syr: invalid argument rowcol" << endl;
1437 MPI_Abort(MPI_COMM_WORLD,2);
1438 }
1439 pdsyr(&uplo,&n_,&alpha,x.val,&ix,&jx,x.desc_,&incx,
1440 val,&ione,&ione,desc_);
1441 #else
1442 if ( rowcol == 'c' )
1443 {
1444 // use column k of matrix x
1445 assert(x.m()==n_);
1446 int incx = 1;
1447 dsyr(&uplo,&n_,&alpha,&x.val[k*x.m()],&incx,val,&m_);
1448 }
1449 else if ( rowcol == 'r' )
1450 {
1451 // use row k of matrix x
1452 assert(x.n()==n_);
1453 int incx = x.m();
1454 dsyr(&uplo,&n_,&alpha,&x.val[k],&incx,val,&m_);
1455 }
1456 else
1457 {
1458 cout << " DoubleMatrix::syr: invalid argument rowcol" << endl;
1459 exit(2);
1460 }
1461 #endif
1462 }
1463
1464 ////////////////////////////////////////////////////////////////////////////////
DoubleMatrix& DoubleMatrix::operator=(const DoubleMatrix& a)
{
  // Copy the element values of a into *this. No reallocation is done; both
  // matrices must already have the same distribution on the same context.
  if ( &a == this )
    return *this;

  assert( a.ictxt() == ictxt_ && a.m() == m_ && a.mb() == mb_ &&
          a.n() == n_ && a.nb() == nb_ );
  if ( !active() )
    return *this;

  // the 9-entry ScaLAPACK descriptors must agree
  for ( int k = 0; k < 9; ++k )
    assert( desc_[k] == a.desc_[k] );
  memcpy(val, a.val, mloc_*nloc_*sizeof(double));
  return *this;
}
1482
1483 ////////////////////////////////////////////////////////////////////////////////
ComplexMatrix& ComplexMatrix::operator=(const ComplexMatrix& a)
{
  // Copy the element values of a into *this. No reallocation is done; both
  // matrices must already have the same distribution on the same context.
  if ( &a == this )
    return *this;

  assert( a.ictxt() == ictxt_ && a.m() == m_ && a.mb() == mb_ &&
          a.n() == n_ && a.nb() == nb_ );
  if ( !active() )
    return *this;

  // the 9-entry ScaLAPACK descriptors must agree
  for ( int k = 0; k < 9; ++k )
    assert( desc_[k] == a.desc_[k] );
  memcpy(val, a.val, mloc_*nloc_*sizeof(complex<double>));
  return *this;
}
1500
1501 ////////////////////////////////////////////////////////////////////////////////
1502 // operator+=
DoubleMatrix& DoubleMatrix::operator+=(const DoubleMatrix &x)
{
  // Element-wise *this += x on the local arrays; x must match *this in
  // context and local dimensions.
  assert( ictxt_ == x.ictxt() );
  assert(m_==x.m());
  assert(n_==x.n());
  assert(mloc_==x.mloc());
  assert(nloc_==x.nloc());
  if ( active_ )
  {
    double one = 1.0;
    int inc = 1;
    daxpy(&size_, &one, x.val, &inc, val, &inc);
  }
  return *this;
}
1516
1517 ////////////////////////////////////////////////////////////////////////////////
1518 // operator-=
DoubleMatrix& DoubleMatrix::operator-=(const DoubleMatrix &x)
{
  // Element-wise *this -= x on the local arrays; x must match *this in
  // context and local dimensions.
  assert( ictxt_ == x.ictxt() );
  assert(m_==x.m());
  assert(n_==x.n());
  assert(mloc_==x.mloc());
  assert(nloc_==x.nloc());
  if ( active_ )
  {
    double minus_one = -1.0;
    int inc = 1;
    daxpy(&size_, &minus_one, x.val, &inc, val, &inc);
  }
  return *this;
}
1532
1533 ////////////////////////////////////////////////////////////////////////////////
1534 // operator+=
ComplexMatrix& ComplexMatrix::operator+=(const ComplexMatrix& x)
{
  // Element-wise *this += x. The complex arrays are treated as 2*size_
  // interleaved doubles, so one real daxpy updates both components.
  assert( ictxt_ == x.ictxt() );
  assert(m_==x.m());
  assert(n_==x.n());
  assert(mloc_==x.mloc());
  assert(nloc_==x.nloc());
  if ( active_ )
  {
    double one = 1.0;
    int inc = 1;
    int ndoubles = 2 * size_;
    daxpy(&ndoubles, &one, reinterpret_cast<double*>(x.val), &inc,
          reinterpret_cast<double*>(val), &inc);
  }
  return *this;
}
1549
1550 ////////////////////////////////////////////////////////////////////////////////
1551 // operator-=
ComplexMatrix& ComplexMatrix::operator-=(const ComplexMatrix& x)
{
  // Element-wise *this -= x. The complex arrays are treated as 2*size_
  // interleaved doubles, so one real daxpy updates both components.
  assert( ictxt_ == x.ictxt() );
  assert(m_==x.m());
  assert(n_==x.n());
  assert(mloc_==x.mloc());
  assert(nloc_==x.nloc());
  if ( active_ )
  {
    double minus_one = -1.0;
    int inc = 1;
    int ndoubles = 2 * size_;
    daxpy(&ndoubles, &minus_one, reinterpret_cast<double*>(x.val), &inc,
          reinterpret_cast<double*>(val), &inc);
  }
  return *this;
}
1566
1567 ////////////////////////////////////////////////////////////////////////////////
1568 // operator*=
operator *=(double alpha)1569 DoubleMatrix& DoubleMatrix::operator*=(double alpha)
1570 {
1571 int ione=1;
1572 if( active_ )
1573 dscal(&size_, &alpha, val, &ione);
1574 return *this;
1575 }
1576
1577 ////////////////////////////////////////////////////////////////////////////////
operator *=(double alpha)1578 ComplexMatrix& ComplexMatrix::operator*=(double alpha)
1579 {
1580 int ione=1;
1581 if( active_ )
1582 zdscal(&size_, &alpha, val, &ione);
1583 return *this;
1584 }
1585
1586 ////////////////////////////////////////////////////////////////////////////////
1587 // operator*=
operator *=(complex<double> alpha)1588 ComplexMatrix& ComplexMatrix::operator*=(complex<double> alpha)
1589 {
1590 int ione=1;
1591 if( active_ )
1592 zscal(&size_, &alpha, val, &ione);
1593 return *this;
1594 }
1595
1596 ////////////////////////////////////////////////////////////////////////////////
1597 // scal
scal(double alpha)1598 void DoubleMatrix::scal(double alpha)
1599 {
1600 *this *= alpha;
1601 }
1602
1603 ////////////////////////////////////////////////////////////////////////////////
1604 // scal
scal(double alpha)1605 void ComplexMatrix::scal(double alpha)
1606 {
1607 *this *= alpha;
1608 }
1609
1610 ////////////////////////////////////////////////////////////////////////////////
1611 // scal
scal(complex<double> alpha)1612 void ComplexMatrix::scal(complex<double> alpha)
1613 {
1614 *this *= alpha;
1615 }
1616
1617 ////////////////////////////////////////////////////////////////////////////////
1618 // matrix multiplication
1619 // this = alpha*op(A)*op(B)+beta*this
1620 ////////////////////////////////////////////////////////////////////////////////
gemm(char transa,char transb,double alpha,const DoubleMatrix & a,const DoubleMatrix & b,double beta)1621 void DoubleMatrix::gemm(char transa, char transb,
1622 double alpha, const DoubleMatrix& a,
1623 const DoubleMatrix& b, double beta)
1624 {
1625 assert( ictxt_ == a.ictxt() );
1626 assert( ictxt_ == b.ictxt() );
1627
1628 if ( active() )
1629 {
1630 int m, n, k;
1631 if ( transa == 'N' || transa == 'n' )
1632 {
1633 m = a.m();
1634 k = a.n();
1635 assert(a.m()==m_);
1636 }
1637 else
1638 {
1639 m = a.n();
1640 k = a.m();
1641 assert(a.n()==m_);
1642 }
1643 if ( transb == 'N' || transb == 'n' )
1644 {
1645 n = b.n();
1646 assert(k==b.m());
1647 }
1648 else
1649 {
1650 n = b.m();
1651 assert(k==b.n());
1652 }
1653
1654 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
1655 {
1656 dgemm(&transa, &transb, &m, &n, &k, &alpha, a.val, &a.lld_,
1657 b.val, &b.lld_, &beta, val, &lld_);
1658 }
1659 else
1660 {
1661 int ione=1;
1662 pdgemm(&transa, &transb, &m, &n, &k, &alpha,
1663 a.val, &ione, &ione, a.desc_,
1664 b.val, &ione, &ione, b.desc_,
1665 &beta, val, &ione, &ione, desc_);
1666 }
1667 }
1668 }
1669
1670 ////////////////////////////////////////////////////////////////////////////////
1671 // complex matrix multiplication
1672 // this = alpha*op(A)*op(B)+beta*this
1673 ////////////////////////////////////////////////////////////////////////////////
gemm(char transa,char transb,complex<double> alpha,const ComplexMatrix & a,const ComplexMatrix & b,complex<double> beta)1674 void ComplexMatrix::gemm(char transa, char transb,
1675 complex<double> alpha, const ComplexMatrix& a,
1676 const ComplexMatrix& b, complex<double> beta)
1677 {
1678 assert( ictxt_ == a.ictxt() );
1679 assert( ictxt_ == b.ictxt() );
1680
1681 if ( active() )
1682 {
1683 int m, n, k;
1684 if ( transa == 'N' || transa == 'n' )
1685 {
1686 m = a.m();
1687 k = a.n();
1688 assert(a.m()==m_);
1689 }
1690 else
1691 {
1692 m = a.n();
1693 k = a.m();
1694 assert(a.n()==m_);
1695 }
1696 if ( transb == 'N' || transb == 'n' )
1697 {
1698 n = b.n();
1699 assert(k==b.m());
1700 }
1701 else
1702 {
1703 n = b.m();
1704 assert(k==b.n());
1705 }
1706
1707 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
1708 {
1709 zgemm(&transa, &transb, &m, &n, &k, &alpha, a.val, &a.lld_,
1710 b.val, &b.lld_, &beta, val, &lld_);
1711 }
1712 else
1713 {
1714 int ione=1;
1715 pzgemm(&transa, &transb, &m, &n, &k, &alpha,
1716 a.val, &ione, &ione, a.desc_,
1717 b.val, &ione, &ione, b.desc_,
1718 &beta, val, &ione, &ione, desc_);
1719 }
1720 }
1721 }
1722
1723 ////////////////////////////////////////////////////////////////////////////////
1724 // symmetric_matrix * matrix multiplication
1725 // this = beta * this + alpha * a * b
1726 // this = beta * this + alpha * b * a
1727 ////////////////////////////////////////////////////////////////////////////////
symm(char side,char uplo,double alpha,const DoubleMatrix & a,const DoubleMatrix & b,double beta)1728 void DoubleMatrix::symm(char side, char uplo,
1729 double alpha, const DoubleMatrix& a,
1730 const DoubleMatrix& b, double beta)
1731 {
1732 assert( ictxt_ == a.ictxt() );
1733 assert( ictxt_ == b.ictxt() );
1734
1735 if ( active() )
1736 {
1737 assert(a.n()==a.m());
1738 if ( side == 'L' || side == 'l' )
1739 {
1740 assert(a.n()==b.m());
1741 }
1742 else
1743 {
1744 assert(a.m()==b.n());
1745 }
1746
1747 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
1748 {
1749 dsymm(&side, &uplo, &m_, &n_, &alpha, a.val, &a.lld_,
1750 b.val, &b.lld_, &beta, val, &lld_);
1751 }
1752 else
1753 {
1754 int ione=1;
1755 pdsymm(&side, &uplo, &m_, &n_, &alpha,
1756 a.val, &ione, &ione, a.desc_,
1757 b.val, &ione, &ione, b.desc_,
1758 &beta, val, &ione, &ione, desc_);
1759 }
1760 }
1761 }
1762
1763 ////////////////////////////////////////////////////////////////////////////////
1764 // hermitian_matrix * matrix multiplication
1765 // this = beta * this + alpha * a * b
1766 // this = beta * this + alpha * b * a
1767 ////////////////////////////////////////////////////////////////////////////////
hemm(char side,char uplo,complex<double> alpha,const ComplexMatrix & a,const ComplexMatrix & b,complex<double> beta)1768 void ComplexMatrix::hemm(char side, char uplo,
1769 complex<double> alpha, const ComplexMatrix& a,
1770 const ComplexMatrix& b, complex<double> beta)
1771 {
1772 assert( ictxt_ == a.ictxt() );
1773 assert( ictxt_ == b.ictxt() );
1774
1775 if ( active() )
1776 {
1777 assert(a.n()==a.m());
1778 if ( side == 'L' || side == 'l' )
1779 {
1780 assert(a.n()==b.m());
1781 }
1782 else
1783 {
1784 assert(a.m()==b.n());
1785 }
1786
1787 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
1788 {
1789 zhemm(&side, &uplo, &m_, &n_, &alpha, a.val, &a.lld_,
1790 b.val, &b.lld_, &beta, val, &lld_);
1791 }
1792 else
1793 {
1794 int ione=1;
1795 pzhemm(&side, &uplo, &m_, &n_, &alpha,
1796 a.val, &ione, &ione, a.desc_,
1797 b.val, &ione, &ione, b.desc_,
1798 &beta, val, &ione, &ione, desc_);
1799 }
1800 }
1801 }
1802
1803 ////////////////////////////////////////////////////////////////////////////////
1804 // complex_symmetric_matrix * matrix multiplication
1805 // this = beta * this + alpha * a * b
1806 // this = beta * this + alpha * b * a
1807 ////////////////////////////////////////////////////////////////////////////////
symm(char side,char uplo,complex<double> alpha,const ComplexMatrix & a,const ComplexMatrix & b,complex<double> beta)1808 void ComplexMatrix::symm(char side, char uplo,
1809 complex<double> alpha, const ComplexMatrix& a,
1810 const ComplexMatrix& b, complex<double> beta)
1811 {
1812 assert( ictxt_ == a.ictxt() );
1813 assert( ictxt_ == b.ictxt() );
1814
1815 if ( active() )
1816 {
1817 assert(a.n()==a.m());
1818 if ( side == 'L' || side == 'l' )
1819 {
1820 assert(a.n()==b.m());
1821 }
1822 else
1823 {
1824 assert(a.m()==b.n());
1825 }
1826
1827 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
1828 {
1829 zsymm(&side, &uplo, &m_, &n_, &alpha, a.val, &a.lld_,
1830 b.val, &b.lld_, &beta, val, &lld_);
1831 }
1832 else
1833 {
1834 int ione=1;
1835 pzsymm(&side, &uplo, &m_, &n_, &alpha,
1836 a.val, &ione, &ione, a.desc_,
1837 b.val, &ione, &ione, b.desc_,
1838 &beta, val, &ione, &ione, desc_);
1839 }
1840 }
1841 }
1842
1843 ////////////////////////////////////////////////////////////////////////////////
1844 // Compute a matrix-matrix product for a real triangular
1845 // matrix or its transpose.
// *this = alpha * op(A) * *this if side=='l'
1847 // *this = alpha * *this * op(A) if side=='r'
1848 // where op(A) = A or trans(A)
1849 // alpha is a scalar, *this is an m by n matrix, and A is a unit or non-unit,
1850 // upper- or lower-triangular matrix.
1851 ////////////////////////////////////////////////////////////////////////////////
trmm(char side,char uplo,char trans,char diag,double alpha,const DoubleMatrix & a)1852 void DoubleMatrix::trmm(char side, char uplo, char trans, char diag,
1853 double alpha, const DoubleMatrix& a)
1854 {
1855 if ( active() )
1856 {
1857 assert(a.m_==a.n_);
1858 if ( side=='L' || side=='l' )
1859 {
1860 assert(a.n_==m_);
1861 }
1862 else
1863 {
1864 assert(a.n_==n_);
1865 }
1866 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
1867 {
1868 dtrmm(&side, &uplo, &trans, &diag,
1869 &m_, &n_, &alpha, a.val, &a.m_, val, &m_);
1870 }
1871 else
1872 {
1873 int ione=1;
1874 pdtrmm(&side, &uplo, &trans, &diag, &m_, &n_,
1875 &alpha, a.val, &ione, &ione, a.desc_,
1876 val, &ione, &ione, desc_);
1877 }
1878 }
1879 }
1880
1881 ////////////////////////////////////////////////////////////////////////////////
1882 // Solve op(A) * X = alpha * *this (if side=='l')
1883 // or X * op(A) = alpha * *this (if side=='r')
1884 // where op(A) = A or trans(A)
1885 // alpha is a scalar, *this is an m by n matrix, and A is a unit or non-unit,
1886 // upper- or lower-triangular matrix.
1887 ////////////////////////////////////////////////////////////////////////////////
trsm(char side,char uplo,char trans,char diag,double alpha,const DoubleMatrix & a)1888 void DoubleMatrix::trsm(char side, char uplo, char trans, char diag,
1889 double alpha, const DoubleMatrix& a)
1890 {
1891 if ( active() )
1892 {
1893 assert(a.m_==a.n_);
1894 if ( side=='L' || side=='l' )
1895 {
1896 assert(a.n_==m_);
1897 }
1898 else
1899 {
1900 assert(a.n_==n_);
1901 }
1902 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
1903 {
1904 dtrsm(&side, &uplo, &trans, &diag,
1905 &m_, &n_, &alpha, a.val, &a.m_, val, &m_);
1906 }
1907 else
1908 {
1909 int ione=1;
1910 pdtrsm(&side, &uplo, &trans, &diag, &m_, &n_,
1911 &alpha, a.val, &ione, &ione, a.desc_,
1912 val, &ione, &ione, desc_);
1913 }
1914 }
1915 }
1916
1917 ////////////////////////////////////////////////////////////////////////////////
1918 // Solve op(A) * X = alpha * *this (if side=='l')
1919 // or X * op(A) = alpha * *this (if side=='r')
1920 // where op(A) = A or trans(A)
1921 // alpha is a scalar, *this is an m by n matrix, and A is a unit or non-unit,
1922 // upper- or lower-triangular matrix.
1923 ////////////////////////////////////////////////////////////////////////////////
trsm(char side,char uplo,char trans,char diag,complex<double> alpha,const ComplexMatrix & a)1924 void ComplexMatrix::trsm(char side, char uplo, char trans,
1925 char diag, complex<double> alpha, const ComplexMatrix& a)
1926 {
1927 if ( active() )
1928 {
1929 assert(a.m_==a.n_);
1930 if ( side=='L' || side=='l' )
1931 {
1932 assert(a.n_==m_);
1933 }
1934 else
1935 {
1936 assert(a.n_==n_);
1937 }
1938 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
1939 {
1940 ztrsm(&side, &uplo, &trans, &diag,
1941 &m_, &n_, &alpha, a.val, &a.m_, val, &m_);
1942 }
1943 else
1944 {
1945 int ione=1;
1946 pztrsm(&side, &uplo, &trans, &diag, &m_, &n_,
1947 &alpha, a.val, &ione, &ione, a.desc_,
1948 val, &ione, &ione, desc_);
1949 }
1950 }
1951 }
1952
1953 ////////////////////////////////////////////////////////////////////////////////
1954 // Solves a triangular system of the form A * X = B or
1955 // A**T * X = B, where A is a triangular matrix of order N,
1956 // and B is an N-by-NRHS matrix.
1957 // Output in B.
1958 ////////////////////////////////////////////////////////////////////////////////
trtrs(char uplo,char trans,char diag,DoubleMatrix & b) const1959 void DoubleMatrix::trtrs(char uplo, char trans, char diag,
1960 DoubleMatrix& b) const
1961 {
1962 int info;
1963 if ( active() )
1964 {
1965 assert(m_==n_);
1966
1967 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
1968 {
1969 dtrtrs(&uplo, &trans, &diag, &m_, &b.n_, val, &m_,
1970 b.val, &b.m_, &info);
1971 }
1972 else
1973 {
1974 int ione=1;
1975 pdtrtrs(&uplo, &trans, &diag, &m_, &b.n_,
1976 val, &ione, &ione, desc_,
1977 b.val, &ione, &ione, b.desc_, &info);
1978 }
1979 if(info!=0)
1980 {
1981 cout <<" Matrix::trtrs, info=" << info << endl;
1982 #ifdef USE_MPI
1983 MPI_Abort(MPI_COMM_WORLD, 2);
1984 #else
1985 exit(2);
1986 #endif
1987 }
1988 }
1989 }
1990
1991 ////////////////////////////////////////////////////////////////////////////////
1992 // Solves a triangular system of the form A * X = B or
// A**T * X = B or A**H * X = B, where A is a triangular matrix of order N,
1994 // and B is an N-by-NRHS matrix.
1995 // Output in B.
1996 ////////////////////////////////////////////////////////////////////////////////
trtrs(char uplo,char trans,char diag,ComplexMatrix & b) const1997 void ComplexMatrix::trtrs(char uplo, char trans, char diag,
1998 ComplexMatrix& b) const
1999 {
2000 int info;
2001 if ( active() )
2002 {
2003 assert(m_==n_);
2004
2005 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2006 {
2007 ztrtrs(&uplo, &trans, &diag, &m_, &b.n_, val, &m_,
2008 b.val, &b.m_, &info);
2009 }
2010 else
2011 {
2012 int ione=1;
2013 pztrtrs(&uplo, &trans, &diag, &m_, &b.n_,
2014 val, &ione, &ione, desc_,
2015 b.val, &ione, &ione, b.desc_, &info);
2016 }
2017 if(info!=0)
2018 {
2019 cout <<" ComplexMatrix::trtrs, info=" << info << endl;
2020 #ifdef USE_MPI
2021 MPI_Abort(MPI_COMM_WORLD, 2);
2022 #else
2023 exit(2);
2024 #endif
2025 }
2026 }
2027 }
2028
2029 ////////////////////////////////////////////////////////////////////////////////
2030 // LU decomposition of a double matrix
2031 ////////////////////////////////////////////////////////////////////////////////
lu(valarray<int> & ipiv)2032 void DoubleMatrix::lu(valarray<int>& ipiv)
2033 {
2034 int info;
2035 if ( active() )
2036 {
2037 assert(m_==n_);
2038 ipiv.resize(mloc_+mb_);
2039
2040 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2041 {
2042 dgetrf(&m_, &n_, val, &m_, &ipiv[0], &info);
2043 }
2044 else
2045 {
2046 int ione=1;
2047 pdgetrf(&m_, &n_, val, &ione, &ione, desc_, &ipiv[0], &info);
2048 }
2049 if(info!=0)
2050 {
2051 cout << " DoubleMatrix::lu, info=" << info << endl;
2052 MPI_Abort(MPI_COMM_WORLD, 2);
2053 }
2054 }
2055 }
2056
2057 ////////////////////////////////////////////////////////////////////////////////
2058 // LU decomposition of a complex matrix
2059 ////////////////////////////////////////////////////////////////////////////////
lu(valarray<int> & ipiv)2060 void ComplexMatrix::lu(valarray<int>& ipiv)
2061 {
2062 int info;
2063 if ( active() )
2064 {
2065 assert(m_==n_);
2066 ipiv.resize(mloc_+mb_);
2067
2068 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2069 {
2070 zgetrf(&m_, &n_, val, &m_, &ipiv[0], &info);
2071 }
2072 else
2073 {
2074 int ione=1;
2075 pzgetrf(&m_, &n_, val, &ione, &ione, desc_, &ipiv[0], &info);
2076 }
2077 if(info!=0)
2078 {
2079 cout << " ComplexMatrix::lu, info=" << info << endl;
2080 #ifdef USE_MPI
2081 MPI_Abort(MPI_COMM_WORLD, 2);
2082 #else
2083 exit(2);
2084 #endif
2085 }
2086 }
2087 }
2088
2089 ////////////////////////////////////////////////////////////////////////////////
2090 // inverse of a square double matrix
2091 ////////////////////////////////////////////////////////////////////////////////
inverse(void)2092 void DoubleMatrix::inverse(void)
2093 {
2094 if ( active() )
2095 {
2096 assert(m_==n_);
2097 valarray<int> ipiv(mloc_+mb_);
2098 // LU decomposition
2099 lu(ipiv);
2100 inverse_from_lu(ipiv);
2101 }
2102 }
2103
2104 ////////////////////////////////////////////////////////////////////////////////
2105 // determinant of a square double matrix in LU form
2106 ////////////////////////////////////////////////////////////////////////////////
det_from_lu(valarray<int> ipiv)2107 double DoubleMatrix::det_from_lu(valarray<int> ipiv)
2108 {
2109 if ( active() )
2110 {
2111 assert(m_==n_);
2112
2113 // compute determinant
2114 valarray<double> diag(n_);
2115 for ( int ii = 0; ii < n_; ii++ )
2116 {
2117 int iii = l(ii) * mb_ + x(ii);
2118 int jjj = m(ii) * nb_ + y(ii);
2119 if ( pr(ii) == ctxt_.myrow()
2120 && pc(ii) == ctxt_.mycol() )
2121 diag[ii] = val[iii+mloc_*jjj];
2122 }
2123 ctxt_.dsum(n_,1,(double*)&diag[0],n_);
2124
2125 double det = 1.0;
2126 for ( int ii = 0; ii < n_; ii++ )
2127 det *= diag[ii];
2128 det *= signature(ipiv);
2129
2130 return det;
2131 }
2132 return 0.0;
2133 }
2134
2135 ////////////////////////////////////////////////////////////////////////////////
2136 // inverse and determinant of a square double matrix
2137 ////////////////////////////////////////////////////////////////////////////////
inverse_det(void)2138 double DoubleMatrix::inverse_det(void)
2139 {
2140 if ( active() )
2141 {
2142 assert(m_==n_);
2143 valarray<int> ipiv(mloc_+mb_);
2144 lu(ipiv);
2145 double det = det_from_lu(ipiv);
2146 inverse_from_lu(ipiv);
2147 return det;
2148 }
2149 return 0.0;
2150 }
2151
2152 ////////////////////////////////////////////////////////////////////////////////
2153 // inverse from an LU decomposed square matrix
2154 ////////////////////////////////////////////////////////////////////////////////
inverse_from_lu(valarray<int> & ipiv)2155 void DoubleMatrix::inverse_from_lu(valarray<int>& ipiv)
2156 {
2157 int info;
2158 if ( active() )
2159 {
2160 assert(m_==n_);
2161
2162 // Compute inverse using LU decomposition and array ipiv
2163 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2164 {
2165 valarray<double> work(1);
2166 int lwork = -1;
2167 // First call to compute optimal size of work array, returned in work[0]
2168 dgetri(&m_, val, &m_, &ipiv[0], &work[0], &lwork, &info);
2169 lwork = (int) work[0] + 1;
2170 work.resize(lwork);
2171 dgetri(&m_, val, &m_, &ipiv[0], &work[0], &lwork, &info);
2172 }
2173 else
2174 {
2175 valarray<double> work(1);
2176 valarray<int> iwork(1);
2177 int lwork = -1;
2178 int liwork = -1;
2179 int ione = 1;
2180 // First call to compute dimensions of work arrays lwork and liwork
2181 // dimensions are returned in work[0] and iwork[0];
2182 pdgetri(&n_, val, &ione, &ione, desc_, &ipiv[0],
2183 &work[0], &lwork, &iwork[0], &liwork, &info);
2184 lwork = (int) work[0] + 1;
2185 liwork = iwork[0];
2186 work.resize(lwork);
2187 iwork.resize(liwork);
2188
2189 // Compute inverse
2190 pdgetri(&n_, val, &ione, &ione, desc_, &ipiv[0],
2191 &work[0], &lwork, &iwork[0], &liwork, &info);
2192 }
2193 if(info!=0)
2194 {
2195 cout << " DoubleMatrix::inverse_from_lu, info(getri)=" << info << endl;
2196 MPI_Abort(MPI_COMM_WORLD, 2);
2197 }
2198 }
2199 }
2200
2201 ////////////////////////////////////////////////////////////////////////////////
2202 // inverse of a complex square matrix
2203 ////////////////////////////////////////////////////////////////////////////////
inverse(void)2204 void ComplexMatrix::inverse(void)
2205 {
2206 valarray<int> ipiv;
2207 lu(ipiv);
2208 inverse_from_lu(ipiv);
2209 }
2210
2211 ////////////////////////////////////////////////////////////////////////////////
2212 // determinant of a complex square matrix in LU form
2213 ////////////////////////////////////////////////////////////////////////////////
det_from_lu(valarray<int> ipiv)2214 complex<double> ComplexMatrix::det_from_lu(valarray<int> ipiv)
2215 {
2216 if ( active() )
2217 {
2218 assert(m_==n_);
2219
2220 // compute determinant
2221 valarray<complex<double> > diag(n_);
2222 for ( int ii = 0; ii < n_; ii++ )
2223 {
2224 int iii = l(ii) * mb_ + x(ii);
2225 int jjj = m(ii) * nb_ + y(ii);
2226 if ( pr(ii) == ctxt_.myrow() && pc(ii) == ctxt_.mycol() )
2227 diag[ii] = val[iii+mloc_*jjj];
2228 }
2229 ctxt_.dsum(n_*2,1,(double*)&diag[0],n_*2);
2230
2231 complex<double> det = 1.0;
2232 for ( int ii = 0; ii < n_; ii++ )
2233 det *= diag[ii];
2234 det *= signature(ipiv);
2235
2236 return det;
2237 }
2238 return complex<double>(0.0,0.0);
2239 }
2240
2241 ////////////////////////////////////////////////////////////////////////////////
2242 // inverse and determinant of a complex square matrix
2243 ////////////////////////////////////////////////////////////////////////////////
inverse_det(void)2244 complex<double> ComplexMatrix::inverse_det(void)
2245 {
2246 if ( active() )
2247 {
2248 assert(m_==n_);
2249 valarray<int> ipiv(mloc_+mb_);
2250 lu(ipiv);
2251 complex<double> det = det_from_lu(ipiv);
2252 inverse_from_lu(ipiv);
2253 return det;
2254 }
2255 return complex<double>(0.0,0.0);
2256 }
2257
2258 ////////////////////////////////////////////////////////////////////////////////
2259 // compute inverse of an LU decomposed matrix
inverse_from_lu(valarray<int> & ipiv)2260 void ComplexMatrix::inverse_from_lu(valarray<int>& ipiv)
2261 {
2262 // it is assumed that the current matrix is LU decomposed
2263 int info;
2264 if ( active() )
2265 {
2266 assert(m_==n_);
2267 // Compute inverse using LU decomposition and array ipiv computed in lu()
2268 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2269 {
2270 valarray< complex<double> > work(1);
2271 int lwork = -1;
2272 // First call to compute optimal size of work array, returned in work[0]
2273 zgetri(&m_, val, &m_, &ipiv[0], &work[0], &lwork, &info);
2274 lwork = (int) work[0].real() + 1;
2275 work.resize(lwork);
2276 zgetri(&m_, val, &m_, &ipiv[0], &work[0], &lwork, &info);
2277 }
2278 else
2279 {
2280 valarray< complex<double> > work(1);
2281 valarray<int> iwork(1);
2282 int lwork = -1;
2283 int liwork = -1;
2284 int ione = 1;
2285 pzgetri(&n_, val, &ione, &ione, desc_, &ipiv[0],
2286 &work[0], &lwork, &iwork[0], &liwork, &info);
2287 lwork = (int) work[0].real() + 1;
2288 liwork = iwork[0];
2289 work.resize(lwork);
2290 iwork.resize(liwork);
2291
2292 // Compute inverse
2293 pzgetri(&n_, val, &ione, &ione, desc_, &ipiv[0],
2294 &work[0], &lwork, &iwork[0], &liwork, &info);
2295 }
2296 if(info!=0)
2297 {
2298 cout << " ComplexMatrix::inverse, info(getri)=" << info << endl;
2299 MPI_Abort(MPI_COMM_WORLD, 2);
2300 }
2301 }
2302 }
2303
2304 ////////////////////////////////////////////////////////////////////////////////
2305 // Real Cholesky factorization of a
2306 // symmetric positive definite distributed matrix
2307 ////////////////////////////////////////////////////////////////////////////////
potrf(char uplo)2308 void DoubleMatrix::potrf(char uplo)
2309 {
2310 int info;
2311 if ( active() )
2312 {
2313 assert(m_==n_);
2314
2315 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2316 {
2317 dpotrf(&uplo, &m_, val, &m_, &info);
2318 }
2319 else
2320 {
2321 int ione=1;
2322 pdpotrf(&uplo, &m_, val, &ione, &ione, desc_, &info);
2323 }
2324 if(info!=0)
2325 {
2326 cout << " DoubleMatrix::potrf, info=" << info << endl;
2327 MPI_Abort(MPI_COMM_WORLD, 2);
2328 }
2329 }
2330 }
2331
2332 ////////////////////////////////////////////////////////////////////////////////
2333 // Complex Cholesky factorization of a
2334 // hermitian positive definite distributed matrix
2335 ////////////////////////////////////////////////////////////////////////////////
potrf(char uplo)2336 void ComplexMatrix::potrf(char uplo)
2337 {
2338 int info;
2339 if ( active() )
2340 {
2341 assert(m_==n_);
2342
2343 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2344 {
2345 zpotrf(&uplo, &m_, val, &m_, &info);
2346 }
2347 else
2348 {
2349 int ione=1;
2350 pzpotrf(&uplo, &m_, val, &ione, &ione, desc_, &info);
2351 }
2352 if(info!=0)
2353 {
2354 cout << " ComplexMatrix::potrf, info=" << info << endl;
2355 MPI_Abort(MPI_COMM_WORLD, 2);
2356 }
2357 }
2358 }
2359
2360 ////////////////////////////////////////////////////////////////////////////////
2361 // Compute the inverse of a real symmetric positive definite matrix
2362 // using the Cholesky factorization A = U**T*U or A = L*L**T computed
2363 // by DoubleMatrix::potrf
2364 ////////////////////////////////////////////////////////////////////////////////
potri(char uplo)2365 void DoubleMatrix::potri(char uplo)
2366 {
2367 int info;
2368 if ( active() )
2369 {
2370 assert(m_==n_);
2371
2372 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2373 {
2374 dpotri(&uplo, &m_, val, &m_, &info);
2375 }
2376 else
2377 {
2378 int ione=1;
2379 pdpotri(&uplo, &m_, val, &ione, &ione, desc_, &info);
2380 }
2381 if(info!=0)
2382 {
2383 cout << " Matrix::potri, info=" << info << endl;
2384 MPI_Abort(MPI_COMM_WORLD, 2);
2385 }
2386 }
2387 }
2388
2389 ////////////////////////////////////////////////////////////////////////////////
2390 // Inverse of a triangular matrix
2391 ////////////////////////////////////////////////////////////////////////////////
trtri(char uplo,char diag)2392 void DoubleMatrix::trtri(char uplo, char diag)
2393 {
2394 int info;
2395 if ( active() )
2396 {
2397 assert(m_==n_);
2398
2399 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2400 {
2401 dtrtri(&uplo, &diag, &m_, val, &m_, &info);
2402 }
2403 else
2404 {
2405 int ione=1;
2406 pdtrtri(&uplo, &diag, &m_, val, &ione, &ione, desc_, &info);
2407 }
2408 if(info!=0)
2409 {
2410 cout << " Matrix::trtri, info=" << info << endl;
2411 MPI_Abort(MPI_COMM_WORLD, 2);
2412 }
2413 }
2414 }
2415
2416 ////////////////////////////////////////////////////////////////////////////////
trtri(char uplo,char diag)2417 void ComplexMatrix::trtri(char uplo, char diag)
2418 {
2419 int info;
2420 if ( active() )
2421 {
2422 assert(m_==n_);
2423
2424 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2425 {
2426 ztrtri(&uplo, &diag, &m_, val, &m_, &info);
2427 }
2428 else
2429 {
2430 int ione=1;
2431 pztrtri(&uplo, &diag, &m_, val, &ione, &ione, desc_, &info);
2432 }
2433 if(info!=0)
2434 {
2435 cout << " Matrix::trtri, info=" << info << endl;
2436 MPI_Abort(MPI_COMM_WORLD, 2);
2437 }
2438 }
2439 }
2440
2441 ////////////////////////////////////////////////////////////////////////////////
2442 // Polar decomposition A = UH
2443 // Replace *this with its orthogonal polar factor U
// return when iter >= maxiter or ||I - X^T*X|| <= tol
2445 ////////////////////////////////////////////////////////////////////////////////
polar(double tol,int maxiter)2446 void DoubleMatrix::polar(double tol, int maxiter)
2447 {
2448 DoubleMatrix x(ctxt_,m_,n_,mb_,nb_);
2449 DoubleMatrix xp(ctxt_,m_,n_,mb_,nb_);
2450
2451 DoubleMatrix q(ctxt_,n_,n_,nb_,nb_);
2452 DoubleMatrix qt(ctxt_,n_,n_,nb_,nb_);
2453 DoubleMatrix t(ctxt_,n_,n_,nb_,nb_);
2454
2455 double qnrm2 = numeric_limits<double>::max();
2456 int iter = 0;
2457 x = *this;
2458 while ( iter < maxiter && qnrm2 > tol )
2459 {
2460 // q = I - x^T * x
2461 q.identity();
2462 q.syrk('l','t',-1.0,x,1.0);
2463 q.symmetrize('l');
2464
2465 double qnrm2 = q.nrm2();
2466 #ifdef DEBUG
2467 if ( ctxt_.onpe0() )
2468 cout << " DoubleMatrix::polar: qnrm2 = " << qnrm2 << endl;
2469 #endif
2470
2471 // choose Bjork-Bowie or Higham iteration depending on q.nrm2
2472
2473 // threshold value
2474 // see A. Bjork and C. Bowie, SIAM J. Num. Anal. 8, 358 (1971) p.363
2475 if ( qnrm2 < 1.0 )
2476 {
2477 // Bjork-Bowie iteration
2478 // compute xp = x * ( I + 0.5*q * ( I + 0.75 * q ) )
2479
2480 // t = ( I + 0.75 * q )
2481 t.identity();
2482 t.axpy(0.75,q);
2483
2484 // compute q*t
2485 qt.gemm('n','n',1.0,q,t,0.0);
2486
2487 // xp = x * ( I + 0.5*q * ( I + 0.75 * q ) )
2488 // = x * ( I + 0.5 * qt )
2489 // Use t to store (I + 0.5 * qt)
2490 t.identity();
2491 t.axpy(0.5,qt);
2492
2493 // t now contains (I + 0.5 * qt)
2494 // xp = x * t
2495 xp.gemm('n','n',1.0,x,t,0.0);
2496
2497 // update x
2498 x = xp;
2499 }
2500 else
2501 {
2502 // Higham iteration
2503 assert(m_==n_);
2504 //if ( ctxt_.onpe0() )
2505 // cout << " DoubleMatrix::polar: using Higham algorithm" << endl;
2506 // t = X^T
2507 t.transpose(1.0,x,0.0);
2508 t.inverse();
2509 // t now contains X^-T
2510 // xp = 0.5 * ( x + x^-T );
2511 for ( int i = 0; i < x.size(); i++ )
2512 x[i] = 0.5 * ( x[i] + t[i] );
2513 }
2514 iter++;
2515 }
2516 *this = x;
2517 }
2518
2519 ////////////////////////////////////////////////////////////////////////////////
2520 // Polar decomposition A = UH (complex case)
2521 // Replace *this with its unitary polar factor U
// return when iter >= maxiter or ||I - X^H*X|| <= tol
2523 ////////////////////////////////////////////////////////////////////////////////
polar(double tol,int maxiter)2524 void ComplexMatrix::polar(double tol, int maxiter)
2525 {
2526 ComplexMatrix x(ctxt_,m_,n_,mb_,nb_);
2527 ComplexMatrix xp(ctxt_,m_,n_,mb_,nb_);
2528
2529 ComplexMatrix q(ctxt_,n_,n_,nb_,nb_);
2530 ComplexMatrix qt(ctxt_,n_,n_,nb_,nb_);
2531 ComplexMatrix t(ctxt_,n_,n_,nb_,nb_);
2532
2533 double qnrm2 = numeric_limits<double>::max();
2534 int iter = 0;
2535 x = *this;
2536 while ( iter < maxiter && qnrm2 > tol )
2537 {
2538 // q = I - x^T * x
2539 q.identity();
2540 q.herk('l','c',-1.0,x,1.0);
2541 q.symmetrize('l');
2542
2543 double qnrm2 = q.nrm2();
2544 #ifdef DEBUG
2545 if ( ctxt_.onpe0() )
2546 cout << " ComplexMatrix::polar: qnrm2 = " << qnrm2 << endl;
2547 #endif
2548
2549 // choose Bjork-Bowie or Higham iteration depending on q.nrm2
2550
2551 // threshold value
2552 // see A. Bjork and C. Bowie, SIAM J. Num. Anal. 8, 358 (1971) p.363
2553 if ( qnrm2 < 1.0 )
2554 {
2555 // Bjork-Bowie iteration
2556 // compute xp = x * ( I + 0.5*q * ( I + 0.75 * q ) )
2557
2558 // t = ( I + 0.75 * q )
2559 t.identity();
2560 t.axpy(0.75,q);
2561
2562 // compute q*t
2563 qt.gemm('n','n',1.0,q,t,0.0);
2564
2565 // xp = x * ( I + 0.5*q * ( I + 0.75 * q ) )
2566 // = x * ( I + 0.5 * qt )
2567 // Use t to store (I + 0.5 * qt)
2568 t.identity();
2569 t.axpy(0.5,qt);
2570
2571 // t now contains (I + 0.5 * qt)
2572 // xp = x * t
2573 xp.gemm('n','n',1.0,x,t,0.0);
2574
2575 // update x
2576 x = xp;
2577 }
2578 else
2579 {
2580 // Higham iteration
2581 assert(m_==n_);
2582 //if ( ctxt_.onpe0() )
2583 // cout << " ComplexMatrix::polar: using Higham algorithm" << endl;
2584 // t = X^H
2585 t.transpose(1.0,x,0.0);
2586 t.inverse();
2587 // t now contains X^-H
2588 // xp = 0.5 * ( x + x^-H );
2589 for ( int i = 0; i < x.size(); i++ )
2590 x[i] = 0.5 * ( x[i] + t[i] );
2591 }
2592 iter++;
2593 }
2594 *this = x;
2595 }
2596
2597 ////////////////////////////////////////////////////////////////////////////////
2598 // estimate the reciprocal of the condition number (in the 1-norm) of a
2599 // real symmetric positive definite matrix using the Cholesky factorization
2600 // A = U**T*U or A = L*L**T computed by DoubleMatrix::potrf
2601 ////////////////////////////////////////////////////////////////////////////////
pocon(char uplo) const2602 double DoubleMatrix::pocon(char uplo) const
2603 {
2604 int info;
2605 double rcond=1.;
2606 double anorm=1.;
2607 if ( active() )
2608 {
2609 assert(m_==n_);
2610
2611 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2612 {
2613 double* work=new double[3*m_];
2614 int* iwork=new int[m_];
2615 dpocon(&uplo, &m_, val, &m_, &anorm, &rcond, work, iwork, &info);
2616 delete[] iwork;
2617 delete[] work;
2618 }
2619 else
2620 {
2621 int ione=1;
2622 int lwork=2*mloc_+3*nloc_+nb_;
2623 int liwork=mloc_;
2624 double* work=new double[lwork];
2625 int* iwork=new int[liwork];
2626 pdpocon(&uplo, &m_, val, &ione, &ione, desc_,
2627 &anorm, &rcond, work, &lwork, iwork, &liwork, &info);
2628 if (info!=0)
2629 {
2630 cout << "DoubleMatrix::pocon: lwork=" << lwork
2631 << ", but should be at least " << work[0] << endl;
2632 cout << "DoubleMatrix::pocon: liwork=" << liwork
2633 << ", but should be at least " << iwork[0] << endl;
2634 }
2635 delete[] iwork;
2636 delete[] work;
2637 }
2638 if(info!=0)
2639 {
2640 cout << " Matrix::pocon, info=" << info << endl;
2641 MPI_Abort(MPI_COMM_WORLD, 2);
2642 }
2643 }
2644 return rcond;
2645 }
2646
2647 ////////////////////////////////////////////////////////////////////////////////
2648 // symmetric rank k update
2649 // this = beta * this + alpha * A * A^T (trans=='n')
2650 // this = beta * this + alpha * A^T * A (trans=='t')
2651 ////////////////////////////////////////////////////////////////////////////////
syrk(char uplo,char trans,double alpha,const DoubleMatrix & a,double beta)2652 void DoubleMatrix::syrk(char uplo, char trans,
2653 double alpha, const DoubleMatrix& a, double beta)
2654 {
2655 assert( ictxt_ == a.ictxt() );
2656 assert( n_ == m_ ); // *this must be a square matrix
2657
2658 if ( active() )
2659 {
2660 int n, k;
2661 if ( trans == 'N' || trans == 'n' )
2662 {
2663 n = m_;
2664 k = a.n();
2665 }
2666 else
2667 {
2668 n = m_;
2669 k = a.m();
2670 }
2671
2672 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2673 {
2674 dsyrk(&uplo, &trans, &n, &k, &alpha, a.val, &a.m_, &beta, val, &m_);
2675 }
2676 else
2677 {
2678 int ione = 1;
2679 pdsyrk(&uplo, &trans, &n, &k, &alpha,
2680 a.val, &ione, &ione, a.desc_,
2681 &beta, val, &ione, &ione, desc_);
2682 }
2683 }
2684 }
2685
2686 ////////////////////////////////////////////////////////////////////////////////
2687 // hermitian rank k update
2688 // this = beta * this + alpha * A * A^H (trans=='n')
2689 // this = beta * this + alpha * A^H * A (trans=='c')
2690 ////////////////////////////////////////////////////////////////////////////////
herk(char uplo,char trans,double alpha,const ComplexMatrix & a,double beta)2691 void ComplexMatrix::herk(char uplo, char trans,
2692 double alpha, const ComplexMatrix& a, double beta)
2693 {
2694 assert( ictxt_ == a.ictxt() );
2695 assert( n_ == m_ ); // *this must be a square matrix
2696
2697 if ( active() )
2698 {
2699 int n, k;
2700 if ( trans == 'N' || trans == 'n' )
2701 {
2702 n = m_;
2703 k = a.n();
2704 }
2705 else if ( trans == 'C' || trans == 'c' )
2706 {
2707 n = m_;
2708 k = a.m();
2709 }
2710 else
2711 {
2712 cout << " Matrix::herk: invalid parameter trans" << endl;
2713 MPI_Abort(MPI_COMM_WORLD,2);
2714 }
2715
2716 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2717 {
2718 zherk(&uplo, &trans, &n, &k, &alpha, a.val, &a.m_,
2719 &beta, val, &m_);
2720 }
2721 else
2722 {
2723 int ione = 1;
2724 pzherk(&uplo, &trans, &n, &k, &alpha,
2725 a.val, &ione, &ione, a.desc_,
2726 &beta, val, &ione, &ione, desc_);
2727 }
2728 }
2729 }
2730
2731 ////////////////////////////////////////////////////////////////////////////////
2732 //
2733 // Generate a duplicated matrix from a distributed matrix
2734 //
2735 ////////////////////////////////////////////////////////////////////////////////
matgather(double * a,int lda) const2736 void DoubleMatrix::matgather(double *a, int lda) const
2737 {
2738 if ( active_ )
2739 {
2740 memset(a,0,lda*n_*sizeof(double));
2741
2742 if ( active_ )
2743 {
2744 for ( int li=0; li<mblocks(); li++)
2745 {
2746 for ( int lj=0; lj<nblocks(); lj++)
2747 {
2748 for ( int ii=0; ii<mbs(li); ii++)
2749 {
2750 for ( int jj=0; jj<nbs(lj); jj++)
2751 {
2752 assert(i(li,ii)<lda);
2753 assert((ii+li*mb_)<mloc_);
2754 assert((jj+lj*nb_)<nloc_);
2755 a[ i(li,ii) + j(lj,jj)*lda ]
2756 = val[ (ii+li*mb_)+(jj+lj*nb_)*mloc_ ];
2757 }
2758 }
2759 }
2760 }
2761 }
2762
2763 int max_size=200000;
2764 int size = lda*n_;
2765 int ione=1;
2766 int nblocks = size / max_size;
2767 int sizer = (size % max_size);
2768 double* work = new double[max_size];
2769
2770 double *ptra = a;
2771 for ( int steps = 0; steps < nblocks; steps++ )
2772 {
2773 MPI_Allreduce(ptra, work, max_size, MPI_DOUBLE, MPI_SUM, ctxt_.comm() );
2774 dcopy(&max_size, work, &ione, ptra, &ione);
2775 ptra += max_size;
2776 }
2777 if ( sizer != 0 )
2778 {
2779 MPI_Allreduce(ptra, work, sizer, MPI_DOUBLE, MPI_SUM, ctxt_.comm() );
2780 dcopy(&sizer, work, &ione, ptra, &ione);
2781 }
2782
2783 delete[] work;
2784 }
2785 }
2786
2787
2788 ////////////////////////////////////////////////////////////////////////////////
2789 // initdiag: initialize diagonal elements using a replicated array a[i]
initdiag(const double * const dmat)2790 void DoubleMatrix::initdiag(const double* const dmat)
2791 {
2792 if ( active() )
2793 {
2794 // initialize diagonal elements
2795 if ( active() )
2796 {
2797 // loop through all local blocks (ll,mm)
2798 for ( int ll = 0; ll < mblocks(); ll++)
2799 {
2800 for ( int mm = 0; mm < nblocks(); mm++)
2801 {
2802 // check if block (ll,mm) has diagonal elements
2803 int imin = i(ll,0);
2804 int imax = imin + mbs(ll)-1;
2805 int jmin = j(mm,0);
2806 int jmax = jmin + nbs(mm)-1;
2807 // cout << " process (" << myrow_ << "," << mycol_ << ")"
2808 // << " block (" << ll << "," << mm << ")"
2809 // << " imin/imax=" << imin << "/" << imax
2810 // << " jmin/jmax=" << jmin << "/" << jmax << endl;
2811
2812 if ((imin <= jmax) && (imax >= jmin))
2813 {
2814 // block (ll,mm) holds diagonal elements
2815 int idiagmin = max(imin,jmin);
2816 int idiagmax = min(imax,jmax);
2817
2818 // cout << " process (" << myrow_ << "," << mycol_ << ")"
2819 // << " holds diagonal elements " << idiagmin << " to " <<
2820 // idiagmax << " in block (" << ll << "," << mm << ")" << endl;
2821
2822 for ( int ii = idiagmin; ii <= idiagmax; ii++ )
2823 {
2824 // access element (ii,ii)
2825 int jj = ii;
2826 int iii = ll * mb_ + x(ii);
2827 int jjj = mm * nb_ + y(jj);
2828 val[iii+mloc_*jjj] = dmat[ii];
2829 }
2830 }
2831 }
2832 }
2833 }
2834 }
2835 }
2836
2837 ////////////////////////////////////////////////////////////////////////////////
2838 // trace
trace(void) const2839 double DoubleMatrix::trace(void) const
2840 {
2841 assert(m_==n_);
2842
2843 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2844 {
2845 double trace = 0.0;
2846 for ( int i = 0; i < n_; i++ )
2847 trace += val[i*m_];
2848 return trace;
2849 }
2850 else
2851 {
2852 int ione=1;
2853 return pdlatra(&n_,val,&ione,&ione,desc_);
2854 }
2855 }
2856
2857 ////////////////////////////////////////////////////////////////////////////////
2858 // Reduces a real symmetric-definite generalized eigenproblem to standard
2859 // form. If itype = 1, the problem is A*x = lambda*B*x,
2860 // and A (=*this) is overwritten by inv(U**T)*A*inv(U) or inv(L)*A*inv(L**T)
2861 // If itype = 2 or 3, the problem is A*B*x = lambda*x or
2862 // B*A*x = lambda*x, and *this is overwritten by U*A*U**T or L**T*A*L.
2863 // B must have been previously factorized as U**T*U or L*L**T by
// DoubleMatrix::potrf.
2865 ////////////////////////////////////////////////////////////////////////////////
sygst(int itype,char uplo,const DoubleMatrix & b)2866 void DoubleMatrix::sygst(int itype, char uplo, const DoubleMatrix& b)
2867 {
2868 int info;
2869 if ( active_ )
2870 {
2871 assert(m_==n_);
2872
2873 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2874 {
2875 dsygst(&itype, &uplo, &m_, val, &m_, b.val, &b.m_, &info);
2876 }
2877 else
2878 {
2879 int ione=1;
2880 double scale;
2881 pdsygst(&itype, &uplo, &m_, val, &ione, &ione, desc_,
2882 b.val, &ione, &ione, b.desc_, &scale, &info);
2883 }
2884 if ( info != 0 )
2885 {
2886 cout << " Matrix::sygst, info=" << info << endl;
2887 MPI_Abort(MPI_COMM_WORLD, 2);
2888 }
2889 }
2890 }
2891
2892 ////////////////////////////////////////////////////////////////////////////////
2893 // compute eigenvalues and eigenvectors of *this
2894 // store eigenvalues in w, eigenvectors in z
syev(char uplo,valarray<double> & w,DoubleMatrix & z)2895 void DoubleMatrix::syev(char uplo, valarray<double>& w, DoubleMatrix& z)
2896 {
2897 int info;
2898 if ( active_ )
2899 {
2900 assert(m_==n_);
2901 char jobz = 'V';
2902 double* work;
2903 int lwork;
2904 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2905 {
2906 lwork=-1;
2907 double tmplwork;
2908 dsyev(&jobz, &uplo, &m_, z.val, &m_, &w[0], &tmplwork, &lwork, &info);
2909 lwork = (int) tmplwork + 1;
2910 work = new double[lwork];
2911 z = *this;
2912 dsyev(&jobz, &uplo, &m_, z.val, &m_, &w[0], work, &lwork, &info);
2913 }
2914 else
2915 {
2916 int ione=1;
2917 lwork=-1;
2918 double tmpwork;
2919 pdsyev(&jobz, &uplo, &m_, val, &ione, &ione, desc_, &w[0],
2920 z.val, &ione, &ione, z.desc_, &tmpwork, &lwork,
2921 &info);
2922 lwork = (int) (tmpwork + 1);
2923 // set lwork to max value among all tasks
2924 ctxt_.imax(1,1,&lwork,1);
2925 work=new double[lwork];
2926 pdsyev(&jobz, &uplo, &m_, val, &ione, &ione, desc_, &w[0],
2927 z.val, &ione, &ione, z.desc_, work, &lwork,
2928 &info);
2929 MPI_Bcast(&w[0], m_, MPI_DOUBLE, 0, ctxt_.comm());
2930 }
2931 if ( info != 0 )
2932 {
2933 cout << " Matrix::syev requires lwork>=" << work[0] << endl;
2934 cout << " Matrix::syev, lwork>=" << lwork << endl;
2935 cout << " Matrix::syev, info=" << info<< endl;
2936 MPI_Abort(MPI_COMM_WORLD, 2);
2937 }
2938 delete[] work;
2939 }
2940 }
2941
2942 ////////////////////////////////////////////////////////////////////////////////
2943 // compute eigenvalues and eigenvectors of *this
2944 // store eigenvalues in w, eigenvectors in z
2945 // using the divide and conquer algorithm of Tisseur and Dongarra
syevd(char uplo,valarray<double> & w,DoubleMatrix & z)2946 void DoubleMatrix::syevd(char uplo, valarray<double>& w, DoubleMatrix& z)
2947 {
2948 int info;
2949 if ( active_ )
2950 {
2951 assert(m_==n_);
2952 char jobz = 'V';
2953 int lwork;
2954 double* work;
2955 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
2956 {
2957 lwork=-1;
2958 double tmplwork;
2959 dsyev(&jobz, &uplo, &m_, z.val, &m_, &w[0], &tmplwork, &lwork, &info);
2960 lwork = (int) tmplwork + 1;
2961 work = new double[lwork];
2962 z = *this;
2963 dsyev(&jobz, &uplo, &m_, z.val, &m_, &w[0], work, &lwork, &info);
2964 }
2965 else
2966 {
2967 int ione=1;
2968 lwork=-1;
2969 double tmpwork;
2970 int liwork=-1;
2971 int tmpiwork;
2972 pdsyevd(&jobz, &uplo, &m_, val, &ione, &ione, desc_, &w[0],
2973 z.val, &ione, &ione, z.desc_, &tmpwork, &lwork,
2974 &tmpiwork, &liwork, &info);
2975 lwork = (int) (tmpwork + 1);
2976 // set lwork to max value among all tasks
2977 ctxt_.imax(1,1,&lwork,1);
2978 work=new double[lwork];
2979 liwork = tmpiwork;
2980 int* iwork = new int[liwork];
2981 pdsyevd(&jobz, &uplo, &m_, val, &ione, &ione, desc_, &w[0],
2982 z.val, &ione, &ione, z.desc_, work, &lwork, iwork, &liwork, &info);
2983 MPI_Bcast(&w[0], m_, MPI_DOUBLE, 0, ctxt_.comm());
2984 delete[] iwork;
2985 }
2986 if ( info != 0 )
2987 {
2988 cout << " Matrix::syev requires lwork>=" << work[0] << endl;
2989 cout << " Matrix::syev, lwork>=" << lwork << endl;
2990 cout << " Matrix::syev, info=" << info<< endl;
2991 MPI_Abort(MPI_COMM_WORLD, 2);
2992 }
2993 delete[] work;
2994 }
2995 }
2996
2997 ////////////////////////////////////////////////////////////////////////////////
2998 // compute eigenvalues and eigenvectors of *this
2999 // store eigenvalues in w, eigenvectors in z
3000 // using the expert driver
syevx(char uplo,valarray<double> & w,DoubleMatrix & z,double abstol)3001 void DoubleMatrix::syevx(char uplo, valarray<double>& w, DoubleMatrix& z,
3002 double abstol)
3003 {
3004 int info;
3005 if ( active_ )
3006 {
3007 assert(m_==n_);
3008 char jobz = 'V';
3009 int lwork;
3010 double* work;
3011 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
3012 {
3013 lwork=-1;
3014 double tmplwork;
3015 dsyev(&jobz, &uplo, &m_, z.val, &m_, &w[0], &tmplwork, &lwork, &info);
3016 lwork = (int) tmplwork + 1;
3017 work = new double[lwork];
3018 z = *this;
3019 dsyev(&jobz, &uplo, &m_, z.val, &m_, &w[0], work, &lwork, &info);
3020 }
3021 else
3022 {
3023 char range = 'A';
3024 int ione=1;
3025 lwork=-1;
3026 double tmpwork;
3027 int liwork=-1;
3028 int tmpiwork;
3029 valarray<int> ifail(n_);
3030 int nfound=-1;
3031 int nz=-1;
3032 int il=1, iu=n_;
3033 double vl=0, vu=0;
3034 double orfac=-1.0;
3035 valarray<int> icluster(2*ctxt_.size());
3036 valarray<double> gap(ctxt_.size());
3037 pdsyevx(&jobz, &range, &uplo, &m_, val, &ione, &ione, desc_,
3038 &vl, &vu, &il, &iu, &abstol, &nfound, &nz, &w[0],
3039 &orfac, z.val, &ione, &ione, z.desc_, &tmpwork, &lwork,
3040 &tmpiwork, &liwork, &ifail[0], &icluster[0], &gap[0], &info);
3041 assert(info==0);
3042 lwork = (int) (tmpwork + 1);
3043 work=new double[lwork];
3044 liwork = tmpiwork;
3045 int* iwork = new int[liwork];
3046 pdsyevx(&jobz, &range, &uplo, &m_, val, &ione, &ione, desc_,
3047 &vl, &vu, &il, &iu, &abstol, &nfound, &nz, &w[0],
3048 &orfac, z.val, &ione, &ione, z.desc_, work, &lwork,
3049 iwork, &liwork, &ifail[0], &icluster[0], &gap[0], &info);
3050 MPI_Bcast(&w[0], m_, MPI_DOUBLE, 0, ctxt_.comm());
3051 delete[] iwork;
3052 }
3053 if ( info != 0 )
3054 {
3055 cout << " Matrix::syev requires lwork>=" << work[0] << endl;
3056 cout << " Matrix::syev, lwork>=" << lwork << endl;
3057 cout << " Matrix::syev, info=" << info<< endl;
3058 MPI_Abort(MPI_COMM_WORLD, 2);
3059 }
3060 delete[] work;
3061 }
3062 }
3063
3064 ////////////////////////////////////////////////////////////////////////////////
3065 // compute eigenvalues (only) of *this
3066 // store eigenvalues in w
syev(char uplo,valarray<double> & w)3067 void DoubleMatrix::syev(char uplo, valarray<double>& w)
3068 {
3069 int info;
3070 if ( active_ )
3071 {
3072 assert(m_==n_);
3073 char jobz = 'N';
3074 int lwork;
3075 double* work;
3076
3077 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
3078 {
3079 lwork=-1;
3080 double tmplwork;
3081 dsyev(&jobz, &uplo, &m_, val, &m_, &w[0], &tmplwork, &lwork, &info);
3082 lwork = (int) tmplwork + 1;
3083 work = new double[lwork];
3084 dsyev(&jobz, &uplo, &m_, val, &m_, &w[0], work, &lwork, &info);
3085 }
3086 else
3087 {
3088 int ione=1;
3089 lwork=-1;
3090 double tmplwork;
3091 double *zval = 0; // zval is not referenced since jobz == 'N'
3092 int * descz = 0;
3093 pdsyev(&jobz, &uplo, &m_, val, &ione, &ione, desc_, &w[0],
3094 zval, &ione, &ione, descz, &tmplwork, &lwork, &info);
3095 lwork = (int) tmplwork + 1;
3096 work=new double[lwork];
3097 pdsyev(&jobz, &uplo, &m_, val, &ione, &ione, desc_, &w[0],
3098 zval, &ione, &ione, descz, work, &lwork, &info);
3099 MPI_Bcast(&w[0], m_, MPI_DOUBLE, 0, ctxt_.comm());
3100 }
3101 if ( info != 0 )
3102 {
3103 cout << " Matrix::syev requires lwork>=" << work[0] << endl;
3104 cout << " Matrix::syev, lwork>=" << lwork << endl;
3105 cout << " Matrix::syev, info=" << info<< endl;
3106 MPI_Abort(MPI_COMM_WORLD, 2);
3107 }
3108 delete[] work;
3109 }
3110 }
3111
3112 ////////////////////////////////////////////////////////////////////////////////
3113 // compute eigenvalues (only) of *this
3114 // store eigenvalues in w
3115 // using the divide and conquer method of Tisseur and Dongarra
syevd(char uplo,valarray<double> & w)3116 void DoubleMatrix::syevd(char uplo, valarray<double>& w)
3117 {
3118 int info;
3119 if ( active_ )
3120 {
3121 assert(m_==n_);
3122 char jobz = 'N';
3123 int lwork,liwork;
3124 double* work;
3125
3126 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
3127 {
3128 lwork=-1;
3129 double tmplwork;
3130 dsyev(&jobz, &uplo, &m_, val, &m_, &w[0], &tmplwork, &lwork, &info);
3131 lwork = (int) tmplwork + 1;
3132 work = new double[lwork];
3133 dsyev(&jobz, &uplo, &m_, val, &m_, &w[0], work, &lwork, &info);
3134 }
3135 else
3136 {
3137 int ione=1;
3138 lwork=-1;
3139 double tmpwork;
3140 liwork=-1;
3141 int tmpiwork;
3142 double *zval = 0; // zval is not referenced since jobz == 'N'
3143 int * descz = 0;
3144 pdsyevd(&jobz, &uplo, &m_, val, &ione, &ione, desc_, &w[0],
3145 zval, &ione, &ione, descz, &tmpwork, &lwork,
3146 &tmpiwork, &liwork, &info);
3147 lwork = (int) (tmpwork + 1);
3148 work=new double[lwork];
3149 liwork = tmpiwork;
3150 int* iwork = new int[liwork];
3151 pdsyevd(&jobz, &uplo, &m_, val, &ione, &ione, desc_, &w[0],
3152 zval, &ione, &ione, descz, work, &lwork, iwork, &liwork, &info);
3153 MPI_Bcast(&w[0], m_, MPI_DOUBLE, 0, ctxt_.comm());
3154 delete[] iwork;
3155 }
3156 if ( info != 0 )
3157 {
3158 cout << " Matrix::syev requires lwork>=" << work[0] << endl;
3159 cout << " Matrix::syev, lwork>=" << lwork << endl;
3160 cout << " Matrix::syev, info=" << info<< endl;
3161 MPI_Abort(MPI_COMM_WORLD, 2);
3162 }
3163 delete[] work;
3164 }
3165 }
3166
3167 ////////////////////////////////////////////////////////////////////////////////
heev(char uplo,valarray<double> & w,ComplexMatrix & z)3168 void ComplexMatrix::heev(char uplo, valarray<double>& w, ComplexMatrix& z)
3169 {
3170 int info;
3171 if ( active_ )
3172 {
3173 assert(m_==n_);
3174 char jobz = 'V';
3175 int lwork;
3176 complex<double>* work;
3177 double* rwork;
3178
3179 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
3180 {
3181 // request optimal lwork size
3182 int lwork=-1;
3183 complex<double> tmplwork;
3184 int lrwork = max(1,3*n_-2);
3185 rwork = new double[lrwork];
3186 zheev(&jobz, &uplo, &m_, z.val, &m_, &w[0], &tmplwork, &lwork,
3187 rwork, &info);
3188 lwork = (int) real(tmplwork) + 1;
3189 work = new complex<double>[lwork];
3190 z=*this;
3191 zheev(&jobz, &uplo, &m_, z.val, &m_, &w[0], work, &lwork,
3192 rwork, &info);
3193 }
3194 else
3195 {
3196 int ione=1;
3197 lwork=-1;
3198 int lrwork=-1;
3199 complex<double> tmplwork;
3200 double tmplrwork;
3201 // first call to get optimal lwork and lrwork sizes
3202 pzheev(&jobz, &uplo, &n_, val, &ione, &ione, desc_, &w[0],
3203 z.val, &ione, &ione, z.desc_, &tmplwork, &lwork,
3204 &tmplrwork, &lrwork, &info);
3205 lwork = (int) real(tmplwork) + 1;
3206 work = new complex<double>[lwork];
3207 lrwork = (int) tmplrwork + 1;
3208 // direct calculation of lrwork to avoid bug in pzheev
3209 lrwork = 1 + 9*n_ + 3*mloc_*nloc_;
3210 rwork = new double[lrwork];
3211 pzheev(&jobz, &uplo, &n_, val, &ione, &ione, desc_, &w[0],
3212 z.val, &ione, &ione, z.desc_, work, &lwork,
3213 rwork, &lrwork, &info);
3214 MPI_Bcast(&w[0], m_, MPI_DOUBLE, 0, ctxt_.comm());
3215 }
3216 if ( info != 0 )
3217 {
3218 cout << " Matrix::heev requires lwork>=" << work[0] << endl;
3219 cout << " Matrix::heev, lwork>=" << lwork << endl;
3220 cout << " Matrix::heev, info=" << info<< endl;
3221 MPI_Abort(MPI_COMM_WORLD, 2);
3222 }
3223 delete[] work;
3224 delete[] rwork;
3225 }
3226 }
3227
3228 ////////////////////////////////////////////////////////////////////////////////
heevd(char uplo,valarray<double> & w,ComplexMatrix & z)3229 void ComplexMatrix::heevd(char uplo, valarray<double>& w, ComplexMatrix& z)
3230 {
3231 int info;
3232 if ( active_ )
3233 {
3234 assert(m_==n_);
3235 char jobz = 'V';
3236 int lwork,liwork;
3237 complex<double>* work;
3238 double* rwork;
3239
3240 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
3241 {
3242 // request optimal lwork size
3243 lwork=-1;
3244 complex<double> tmplwork;
3245 int lrwork = max(1,3*n_-2);
3246 rwork = new double[lrwork];
3247 zheev(&jobz, &uplo, &m_, z.val, &m_, &w[0], &tmplwork, &lwork,
3248 rwork, &info);
3249 lwork = (int) real(tmplwork) + 1;
3250 work = new complex<double>[lwork];
3251 z=*this;
3252 zheev(&jobz, &uplo, &m_, z.val, &m_, &w[0], work, &lwork,
3253 rwork, &info);
3254 if ( info != 0 )
3255 {
3256 cout << " Matrix::heevd requires lwork>=" << work[0] << endl;
3257 cout << " Matrix::heevd, lwork>=" << lwork << endl;
3258 cout << " Matrix::heevd, info=" << info<< endl;
3259 MPI_Abort(MPI_COMM_WORLD, 2);
3260 }
3261 }
3262 else
3263 {
3264 int ione=1;
3265 lwork=-1;
3266 int lrwork=-1;
3267 liwork=-1;
3268 complex<double> tmplwork;
3269 double tmplrwork;
3270 int tmpliwork;
3271 // first call to get optimal lwork and lrwork sizes
3272 pzheevd(&jobz, &uplo, &n_, val, &ione, &ione, desc_, &w[0],
3273 z.val, &ione, &ione, z.desc_, &tmplwork, &lwork,
3274 &tmplrwork, &lrwork, &tmpliwork, &liwork, &info);
3275 lwork = (int) real(tmplwork) + 1;
3276 work = new complex<double>[lwork];
3277 lrwork = (int) tmplrwork + 1;
3278 rwork = new double[lrwork];
3279 liwork = tmpliwork;
3280 int* iwork = new int[liwork];
3281 pzheevd(&jobz, &uplo, &n_, val, &ione, &ione, desc_, &w[0],
3282 z.val, &ione, &ione, z.desc_, work, &lwork,
3283 rwork, &lrwork, iwork, &liwork, &info);
3284 //MPI_Bcast(&w[0], m_, MPI_DOUBLE, 0, ctxt_.comm());
3285 if ( info != 0 )
3286 {
3287 cout << " Matrix::heevd requires lwork>=" << work[0] << endl;
3288 cout << " Matrix::heevd, lwork>=" << lwork << endl;
3289 cout << " Matrix::heevd, liwork>=" << liwork << endl;
3290 cout << " Matrix::heevd, info=" << info<< endl;
3291 MPI_Abort(MPI_COMM_WORLD, 2);
3292 delete[] work;
3293 delete[] rwork;
3294 delete[] iwork;
3295 }
3296 }
3297 delete[] work;
3298 delete[] rwork;
3299 }
3300 }
3301
3302 ////////////////////////////////////////////////////////////////////////////////
3303 // compute eigenvalues (only) of hermitian matrix *this
heev(char uplo,valarray<double> & w)3304 void ComplexMatrix::heev(char uplo, valarray<double>& w)
3305 {
3306 int info;
3307 if ( active_ )
3308 {
3309 assert(m_==n_);
3310 char jobz = 'N';
3311 int lwork;
3312 complex<double>* work;
3313
3314 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
3315 {
3316 // request optimal lwork size
3317 lwork=-1;
3318 complex<double> tmplwork;
3319 int lrwork = max(1,3*n_-2);
3320 double* rwork = new double[lrwork];
3321 zheev(&jobz, &uplo, &m_, val, &m_, &w[0], &tmplwork, &lwork,
3322 rwork, &info);
3323 lwork = (int) real(tmplwork);
3324 work = new complex<double>[lwork];
3325 zheev(&jobz, &uplo, &m_, val, &m_, &w[0], work, &lwork,
3326 rwork, &info);
3327 delete[] rwork;
3328 }
3329 else
3330 {
3331 int ione=1;
3332 lwork=-1;
3333 int lrwork=-1;
3334 complex<double> tmplwork;
3335 double tmplrwork;
3336 complex<double> *zval = 0;
3337 int *descz = 0;
3338 // first call to get optimal lwork and lrwork sizes
3339 pzheev(&jobz, &uplo, &n_, val, &ione, &ione, desc_, &w[0],
3340 zval, &ione, &ione, descz, &tmplwork, &lwork,
3341 &tmplrwork, &lrwork, &info);
3342 lwork = (int) real(tmplwork) + 1;
3343 work = new complex<double>[lwork];
3344 lrwork = (int) tmplrwork + 1;
3345 double* rwork = new double[lrwork];
3346 pzheev(&jobz, &uplo, &n_, val, &ione, &ione, desc_, &w[0],
3347 zval, &ione, &ione, descz, work, &lwork,
3348 rwork, &lrwork, &info);
3349 MPI_Bcast(&w[0], m_, MPI_DOUBLE, 0, ctxt_.comm());
3350 delete[] rwork;
3351 }
3352 if ( info != 0 )
3353 {
3354 cout << " Matrix::heev requires lwork>=" << work[0] << endl;
3355 cout << " Matrix::heev, lwork>=" << lwork << endl;
3356 cout << " Matrix::heev, info=" << info << endl;
3357 MPI_Abort(MPI_COMM_WORLD, 2);
3358 }
3359 delete[] work;
3360 }
3361 }
3362
3363 ////////////////////////////////////////////////////////////////////////////////
3364 // compute eigenvalues (only) of hermitian matrix *this
heevd(char uplo,valarray<double> & w)3365 void ComplexMatrix::heevd(char uplo, valarray<double>& w)
3366 {
3367 int info;
3368 if ( active_ )
3369 {
3370 assert(m_==n_);
3371 char jobz = 'N';
3372
3373 if ( ( nprow_ == 1 ) && ( npcol_ == 1 ) )
3374 {
3375 // request optimal lwork size
3376 int lwork=-1;
3377 complex<double> tmplwork;
3378 int lrwork = max(1,3*n_-2);
3379 double* rwork = new double[lrwork];
3380 zheev(&jobz, &uplo, &m_, val, &m_, &w[0], &tmplwork, &lwork,
3381 rwork, &info);
3382 lwork = (int) real(tmplwork);
3383 complex<double>* work = new complex<double>[lwork];
3384 zheev(&jobz, &uplo, &m_, val, &m_, &w[0], work, &lwork,
3385 rwork, &info);
3386 if ( info != 0 )
3387 {
3388 cout << " Matrix::heevd requires lwork>=" << work[0] << endl;
3389 cout << " Matrix::heevd, lwork>=" << lwork << endl;
3390 cout << " Matrix::heevd, lrwork>=" << lrwork << endl;
3391 cout << " Matrix::heevd, info=" << info << endl;
3392 MPI_Abort(MPI_COMM_WORLD, 2);
3393 }
3394 delete[] work;
3395 delete[] rwork;
3396 }
3397 else
3398 {
3399 int ione=1;
3400 int lwork=-1;
3401 int lrwork=-1;
3402 int liwork=-1;
3403 complex<double> tmplwork;
3404 double tmplrwork;
3405 int tmpliwork;
3406 complex<double> *zval = 0;
3407 int *descz = 0;
3408 // first call to get optimal lwork and lrwork sizes
3409 pzheevd(&jobz, &uplo, &n_, val, &ione, &ione, desc_, &w[0],
3410 zval, &ione, &ione, descz, &tmplwork, &lwork,
3411 &tmplrwork, &lrwork, &tmpliwork, &liwork, &info);
3412 lwork = (int) real(tmplwork) + 1;
3413 complex<double>* work = new complex<double>[lwork];
3414 lrwork = (int) tmplrwork + 1;
3415 liwork = tmpliwork;
3416 double* rwork = new double[lrwork];
3417 int* iwork = new int[liwork];
3418 pzheevd(&jobz, &uplo, &n_, val, &ione, &ione, desc_, &w[0],
3419 zval, &ione, &ione, descz, work, &lwork,
3420 rwork, &lrwork, iwork, &liwork, &info);
3421 MPI_Bcast(&w[0], m_, MPI_DOUBLE, 0, ctxt_.comm());
3422 if ( info != 0 )
3423 {
3424 cout << " Matrix::heevd requires lwork>=" << work[0] << endl;
3425 cout << " Matrix::heevd, lwork>=" << lwork << endl;
3426 cout << " Matrix::heevd, lrwork>=" << lrwork << endl;
3427 cout << " Matrix::heevd, liwork>=" << liwork << endl;
3428 cout << " Matrix::heevd, info=" << info << endl;
3429 MPI_Abort(MPI_COMM_WORLD, 2);
3430 }
3431 delete[] work;
3432 delete[] rwork;
3433 delete[] iwork;
3434 }
3435 }
3436 }
3437
3438 ////////////////////////////////////////////////////////////////////////////////
lapiv(char direc,char rowcol,const int * ipiv)3439 void DoubleMatrix::lapiv(char direc, char rowcol, const int *ipiv)
3440 {
3441 // Perform a permutation of rows or columns of *this
3442 //
3443 // Permutation of rows: (rowcol='R' or 'r')
3444 // the array ipiv is distributed over a process column
3445 // and is replicated on all process columns
3446 // ipiv has size mloc and contains the local values of the permutation
3447 //
3448 // Permutation of columns: (rowcol='C' or 'c')
3449 // the array ipiv is distributed over a process row
3450 // and is replicated on all process rows
3451 // ipiv has size nloc and contains the local values of the permutation
3452
3453 const bool rowcol_r = ( rowcol=='R' || rowcol=='r' );
3454 const bool rowcol_c = ( rowcol=='C' || rowcol=='c' );
3455 assert(rowcol_r || rowcol_c);
3456
3457 // ipivtmp: extended permutation array for use in pdlapv2
3458 // (see scalapack documentation)
3459 vector<int> ipivtmp;
3460 // descriptor of the ipivtmp distributed vector
3461 int desc_ip[9];
3462 if ( rowcol_r )
3463 {
3464 // permuting rows
3465 ipivtmp.resize(mloc_ + mb_);
3466 for ( int i = 0; i < mloc_; i++)
3467 ipivtmp[i] = ipiv[i];
3468 // initialize descriptor: ipivtmp is (mx1)
3469 desc_ip[0] = 1; // dtype
3470 desc_ip[1] = ictxt_;// ctxt
3471 desc_ip[2] = m_+mb_*nprow_; // m (see details in pldapv2.f)
3472 desc_ip[3] = 1; // n
3473 desc_ip[4] = mb_; // mb
3474 desc_ip[5] = 1; // nb
3475 desc_ip[6] = 0; // rsrc
3476 desc_ip[7] = 0; // csrc
3477 desc_ip[8] = mloc_; // lld
3478 }
3479 else
3480 {
3481 // permuting columns
3482 ipivtmp.resize(nloc_ + nb_);
3483 for ( int i = 0; i < nloc_; i++)
3484 ipivtmp[i] = ipiv[i];
3485 // initialize descriptor: ipivtmp is (1xn)
3486 desc_ip[0] = 1; // dtype
3487 desc_ip[1] = ictxt_;// ctxt
3488 desc_ip[2] = 1; // m
3489 desc_ip[3] = n_+nb_*npcol_; // n (see details in pdlapv2.f)
3490 desc_ip[4] = 1; // mb
3491 desc_ip[5] = nb_; // nb
3492 desc_ip[6] = 0; // rsrc
3493 desc_ip[7] = 0; // csrc
3494 desc_ip[8] = 1; // lld
3495 }
3496
3497 int one = 1;
3498 pdlapv2(&direc, &rowcol, &m_, &n_, val, &one, &one, desc_,
3499 &ipivtmp[0], &one, &one, desc_ip);
3500 }
3501
3502 ////////////////////////////////////////////////////////////////////////////////
lapiv(char direc,char rowcol,const int * ipiv)3503 void ComplexMatrix::lapiv(char direc, char rowcol, const int *ipiv)
3504 {
3505 // Perform a permutation of rows or columns of *this
3506 //
3507 // Permutation of rows: (rowcol='R' or 'r')
3508 // the array ipiv is distributed over a process column
3509 // and is replicated on all process columns
3510 // ipiv has size mloc and contains the local values of the permutation
3511 //
3512 // Permutation of columns: (rowcol='C' or 'c')
3513 // the array ipiv is distributed over a process row
3514 // and is replicated on all process rows
3515 // ipiv has size nloc and contains the local values of the permutation
3516
3517 const bool rowcol_r = ( rowcol=='R' || rowcol=='r' );
3518 const bool rowcol_c = ( rowcol=='C' || rowcol=='c' );
3519 assert(rowcol_r || rowcol_c);
3520
3521 // ipivtmp: extended permutation array for use in pdlapv2
3522 // (see scalapack documentation)
3523 vector<int> ipivtmp;
3524 // descriptor of the ipivtmp distributed vector
3525 int desc_ip[9];
3526 if ( rowcol_r )
3527 {
3528 // permuting rows
3529 ipivtmp.resize(mloc_ + mb_);
3530 for ( int i = 0; i < mloc_; i++)
3531 ipivtmp[i] = ipiv[i];
3532 // initialize descriptor: ipivtmp is (mx1)
3533 desc_ip[0] = 1; // dtype
3534 desc_ip[1] = ictxt_;// ctxt
3535 desc_ip[2] = m_+mb_*nprow_; // m (see details in pldapv2.f)
3536 desc_ip[3] = 1; // n
3537 desc_ip[4] = mb_; // mb
3538 desc_ip[5] = 1; // nb
3539 desc_ip[6] = 0; // rsrc
3540 desc_ip[7] = 0; // csrc
3541 desc_ip[8] = mloc_; // lld
3542 }
3543 else
3544 {
3545 // permuting columns
3546 ipivtmp.resize(nloc_ + nb_);
3547 for ( int i = 0; i < nloc_; i++)
3548 ipivtmp[i] = ipiv[i];
3549 // initialize descriptor: ipivtmp is (1x(n+nb))
3550 desc_ip[0] = 1; // dtype
3551 desc_ip[1] = ictxt_;// ctxt
3552 desc_ip[2] = 1; // m
3553 desc_ip[3] = n_+nb_*npcol_; // n (see details in pdlapv2.f)
3554 desc_ip[4] = 1; // mb
3555 desc_ip[5] = nb_; // nb
3556 desc_ip[6] = 0; // rsrc
3557 desc_ip[7] = 0; // csrc
3558 desc_ip[8] = 1; // lld
3559 }
3560
3561 int one = 1;
3562 pzlapv2(&direc, &rowcol, &m_, &n_, val, &one, &one, desc_,
3563 &ipivtmp[0], &one, &one, desc_ip);
3564 }
3565
3566 ////////////////////////////////////////////////////////////////////////////////
print(ostream & os) const3567 void DoubleMatrix::print(ostream& os) const
3568 {
3569 // Copy blocks of <blocksize> columns and print them on process (0,0)
3570 if ( m_ == 0 || n_ == 0 ) return;
3571 Context ctxtl(MPI_COMM_WORLD,1,1);
3572 const int blockmemsize = 32768; // maximum memory size of a block in bytes
3573 // compute maximum block size: must be at least 1
3574 int maxbs = max(1, (int) ((blockmemsize/sizeof(double))/m_));
3575 DoubleMatrix t(ctxtl,m_,maxbs);
3576 int nblocks = n_ / maxbs + ( (n_%maxbs == 0) ? 0 : 1 );
3577 int ia = 0;
3578 int ja = 0;
3579 for ( int jb = 0; jb < nblocks; jb++ )
3580 {
3581 int blocksize = ( (jb+1) * maxbs > n_ ) ? n_ % maxbs : maxbs;
3582 t.getsub(*this,t.m(),blocksize,ia,ja);
3583 ja += blocksize;
3584 if ( t.active() )
3585 {
3586 // this is done only on pe 0
3587 for ( int jj = 0; jj < blocksize; jj++ )
3588 {
3589 for ( int ii = 0; ii < m_; ii++ )
3590 {
3591 os << "(" << ii << "," << jj+jb*maxbs << ")="
3592 << t.val[ii+t.mloc()*jj] << endl;
3593 }
3594 }
3595 }
3596 }
3597 }
3598 ////////////////////////////////////////////////////////////////////////////////
print(ostream & os) const3599 void ComplexMatrix::print(ostream& os) const
3600 {
3601 // Copy blocks of <blocksize> columns and print them on process (0,0)
3602 if ( m_ == 0 || n_ == 0 ) return;
3603 Context ctxtl(MPI_COMM_WORLD,1,1);
3604 const int blockmemsize = 32768; // maximum memory size of a block in bytes
3605 // compute maximum block size: must be at least 1
3606 int maxbs = max(1, (int) ((blockmemsize/sizeof(complex<double>))/m_));
3607 ComplexMatrix t(ctxtl,m_,maxbs);
3608 int nblocks = n_ / maxbs + ( (n_%maxbs == 0) ? 0 : 1 );
3609 int ia = 0;
3610 int ja = 0;
3611 for ( int jb = 0; jb < nblocks; jb++ )
3612 {
3613 int blocksize = ( (jb+1) * maxbs > n_ ) ? n_ % maxbs : maxbs;
3614 t.getsub(*this,t.m(),blocksize,ia,ja);
3615 ja += blocksize;
3616 if ( t.active() )
3617 {
3618 // this is done only on pe 0
3619 for ( int jj = 0; jj < blocksize; jj++ )
3620 {
3621 for ( int ii = 0; ii < m_; ii++ )
3622 {
3623 os << "(" << ii << "," << jj+jb*maxbs << ")="
3624 << t.val[ii+t.mloc()*jj] << endl;
3625 }
3626 }
3627 }
3628 }
3629 }
3630
3631 ////////////////////////////////////////////////////////////////////////////////
operator <<(ostream & os,const DoubleMatrix & a)3632 ostream& operator<<(ostream& os, const DoubleMatrix& a)
3633 {
3634 a.print(os);
3635 return os;
3636 }
3637 ////////////////////////////////////////////////////////////////////////////////
operator <<(ostream & os,const ComplexMatrix & a)3638 ostream& operator<<(ostream& os, const ComplexMatrix& a)
3639 {
3640 a.print(os);
3641 return os;
3642 }
3643
3644 ////////////////////////////////////////////////////////////////////////////////
3645 //
3646 // signature: compute the signature of a row permutation
3647 // defined by a distributed pivot vector ipiv
3648 //
3649 // the vector ipiv is computed by the lu decomposition
3650 //
signature(valarray<int> ipiv)3651 int DoubleMatrix::signature(valarray<int> ipiv)
3652 {
3653 // count the number of non-trivial transpositions in the local ipiv vector
3654 int ntrans = 0;
3655 for ( int i = 0; i < mloc_; i++ )
3656 {
3657 if ( ipiv[i] != iglobal(i) )
3658 ntrans++;
3659 }
3660 // accumulate total number of transpositions
3661 ctxt_.isum('c',1,1,&ntrans,1);
3662 return 1 - 2 * ((m_ - ntrans)%2);
3663 }
3664
3665 ////////////////////////////////////////////////////////////////////////////////
3666 //
3667 // signature: compute the signature of a row permutation
3668 // defined by a distributed pivot vector ipiv
3669 //
3670 // the vector ipiv is computed by the lu decomposition
3671 //
signature(valarray<int> ipiv)3672 int ComplexMatrix::signature(valarray<int> ipiv)
3673 {
3674 // count the number of non-trivial transpositions in the local ipiv vector
3675 int ntrans = 0;
3676 for ( int i = 0; i < mloc_; i++ )
3677 {
3678 if ( ipiv[i] != iglobal(i) )
3679 ntrans++;
3680 }
3681 // accumulate total number of transpositions
3682 ctxt_.isum('c',1,1,&ntrans,1);
3683 return 1 - 2 * ((m_ - ntrans)%2);
3684 }
3685