1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 1994-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29
30 #include "Array.h"
31 #include "EIG.h"
32 #include "dColVector.h"
33 #include "dMatrix.h"
34 #include "lo-error.h"
35 #include "lo-lapack-proto.h"
36
37 octave_idx_type
init(const Matrix & a,bool calc_rev,bool calc_lev,bool balance)38 EIG::init (const Matrix& a, bool calc_rev, bool calc_lev, bool balance)
39 {
40 if (a.any_element_is_inf_or_nan ())
41 (*current_liboctave_error_handler)
42 ("EIG: matrix contains Inf or NaN values");
43
44 if (a.issymmetric ())
45 return symmetric_init (a, calc_rev, calc_lev);
46
47 F77_INT n = octave::to_f77_int (a.rows ());
48 F77_INT a_nc = octave::to_f77_int (a.cols ());
49
50 if (n != a_nc)
51 (*current_liboctave_error_handler) ("EIG requires square matrix");
52
53 F77_INT info = 0;
54
55 Matrix atmp = a;
56 double *tmp_data = atmp.fortran_vec ();
57
58 Array<double> wr (dim_vector (n, 1));
59 double *pwr = wr.fortran_vec ();
60
61 Array<double> wi (dim_vector (n, 1));
62 double *pwi = wi.fortran_vec ();
63
64 F77_INT tnvr = (calc_rev ? n : 0);
65 Matrix vr (tnvr, tnvr);
66 double *pvr = vr.fortran_vec ();
67
68 F77_INT tnvl = (calc_lev ? n : 0);
69 Matrix vl (tnvl, tnvl);
70 double *pvl = vl.fortran_vec ();
71
72 F77_INT lwork = -1;
73 double dummy_work;
74
75 F77_INT ilo;
76 F77_INT ihi;
77
78 Array<double> scale (dim_vector (n, 1));
79 double *pscale = scale.fortran_vec ();
80
81 double abnrm;
82
83 Array<double> rconde (dim_vector (n, 1));
84 double *prconde = rconde.fortran_vec ();
85
86 Array<double> rcondv (dim_vector (n, 1));
87 double *prcondv = rcondv.fortran_vec ();
88
89 F77_INT dummy_iwork;
90
91 F77_XFCN (dgeevx, DGEEVX, (F77_CONST_CHAR_ARG2 (balance ? "B" : "N", 1),
92 F77_CONST_CHAR_ARG2 (calc_lev ? "V" : "N", 1),
93 F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
94 F77_CONST_CHAR_ARG2 ("N", 1),
95 n, tmp_data, n, pwr, pwi, pvl,
96 n, pvr, n, ilo, ihi, pscale,
97 abnrm, prconde, prcondv, &dummy_work,
98 lwork, &dummy_iwork, info
99 F77_CHAR_ARG_LEN (1)
100 F77_CHAR_ARG_LEN (1)
101 F77_CHAR_ARG_LEN (1)
102 F77_CHAR_ARG_LEN (1)));
103
104 if (info != 0)
105 (*current_liboctave_error_handler) ("dgeevx workspace query failed");
106
107 lwork = static_cast<F77_INT> (dummy_work);
108 Array<double> work (dim_vector (lwork, 1));
109 double *pwork = work.fortran_vec ();
110
111 F77_XFCN (dgeevx, DGEEVX, (F77_CONST_CHAR_ARG2 (balance ? "B" : "N", 1),
112 F77_CONST_CHAR_ARG2 (calc_lev ? "V" : "N", 1),
113 F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
114 F77_CONST_CHAR_ARG2 ("N", 1),
115 n, tmp_data, n, pwr, pwi, pvl,
116 n, pvr, n, ilo, ihi, pscale,
117 abnrm, prconde, prcondv, pwork,
118 lwork, &dummy_iwork, info
119 F77_CHAR_ARG_LEN (1)
120 F77_CHAR_ARG_LEN (1)
121 F77_CHAR_ARG_LEN (1)
122 F77_CHAR_ARG_LEN (1)));
123
124 if (info < 0)
125 (*current_liboctave_error_handler) ("unrecoverable error in dgeevx");
126
127 if (info > 0)
128 (*current_liboctave_error_handler) ("dgeevx failed to converge");
129
130 lambda.resize (n);
131 F77_INT nvr = (calc_rev ? n : 0);
132 v.resize (nvr, nvr);
133 F77_INT nvl = (calc_lev ? n : 0);
134 w.resize (nvl, nvl);
135
136 for (F77_INT j = 0; j < n; j++)
137 {
138 if (wi.elem (j) == 0.0)
139 {
140 lambda.elem (j) = Complex (wr.elem (j));
141 for (F77_INT i = 0; i < nvr; i++)
142 v.elem (i, j) = vr.elem (i, j);
143
144 for (F77_INT i = 0; i < nvl; i++)
145 w.elem (i, j) = vl.elem (i, j);
146 }
147 else
148 {
149 if (j+1 >= n)
150 (*current_liboctave_error_handler) ("EIG: internal error");
151
152 lambda.elem (j) = Complex (wr.elem (j), wi.elem (j));
153 lambda.elem (j+1) = Complex (wr.elem (j+1), wi.elem (j+1));
154
155 for (F77_INT i = 0; i < nvr; i++)
156 {
157 double real_part = vr.elem (i, j);
158 double imag_part = vr.elem (i, j+1);
159 v.elem (i, j) = Complex (real_part, imag_part);
160 v.elem (i, j+1) = Complex (real_part, -imag_part);
161 }
162
163 for (F77_INT i = 0; i < nvl; i++)
164 {
165 double real_part = vl.elem (i, j);
166 double imag_part = vl.elem (i, j+1);
167 w.elem (i, j) = Complex (real_part, imag_part);
168 w.elem (i, j+1) = Complex (real_part, -imag_part);
169 }
170 j++;
171 }
172 }
173
174 return info;
175 }
176
177 octave_idx_type
symmetric_init(const Matrix & a,bool calc_rev,bool calc_lev)178 EIG::symmetric_init (const Matrix& a, bool calc_rev, bool calc_lev)
179 {
180 F77_INT n = octave::to_f77_int (a.rows ());
181 F77_INT a_nc = octave::to_f77_int (a.cols ());
182
183 if (n != a_nc)
184 (*current_liboctave_error_handler) ("EIG requires square matrix");
185
186 F77_INT info = 0;
187
188 Matrix atmp = a;
189 double *tmp_data = atmp.fortran_vec ();
190
191 ColumnVector wr (n);
192 double *pwr = wr.fortran_vec ();
193
194 F77_INT lwork = -1;
195 double dummy_work;
196
197 F77_XFCN (dsyev, DSYEV, (F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
198 F77_CONST_CHAR_ARG2 ("U", 1),
199 n, tmp_data, n, pwr, &dummy_work, lwork, info
200 F77_CHAR_ARG_LEN (1)
201 F77_CHAR_ARG_LEN (1)));
202
203 if (info != 0)
204 (*current_liboctave_error_handler) ("dsyev workspace query failed");
205
206 lwork = static_cast<F77_INT> (dummy_work);
207 Array<double> work (dim_vector (lwork, 1));
208 double *pwork = work.fortran_vec ();
209
210 F77_XFCN (dsyev, DSYEV, (F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
211 F77_CONST_CHAR_ARG2 ("U", 1),
212 n, tmp_data, n, pwr, pwork, lwork, info
213 F77_CHAR_ARG_LEN (1)
214 F77_CHAR_ARG_LEN (1)));
215
216 if (info < 0)
217 (*current_liboctave_error_handler) ("unrecoverable error in dsyev");
218
219 if (info > 0)
220 (*current_liboctave_error_handler) ("dsyev failed to converge");
221
222 lambda = ComplexColumnVector (wr);
223 v = (calc_rev ? ComplexMatrix (atmp) : ComplexMatrix ());
224 w = (calc_lev ? ComplexMatrix (atmp) : ComplexMatrix ());
225
226 return info;
227 }
228
229 octave_idx_type
init(const ComplexMatrix & a,bool calc_rev,bool calc_lev,bool balance)230 EIG::init (const ComplexMatrix& a, bool calc_rev, bool calc_lev, bool balance)
231 {
232 if (a.any_element_is_inf_or_nan ())
233 (*current_liboctave_error_handler)
234 ("EIG: matrix contains Inf or NaN values");
235
236 if (a.ishermitian ())
237 return hermitian_init (a, calc_rev, calc_lev);
238
239 F77_INT n = octave::to_f77_int (a.rows ());
240 F77_INT a_nc = octave::to_f77_int (a.cols ());
241
242 if (n != a_nc)
243 (*current_liboctave_error_handler) ("EIG requires square matrix");
244
245 F77_INT info = 0;
246
247 ComplexMatrix atmp = a;
248 Complex *tmp_data = atmp.fortran_vec ();
249
250 ComplexColumnVector wr (n);
251 Complex *pw = wr.fortran_vec ();
252
253 F77_INT nvr = (calc_rev ? n : 0);
254 ComplexMatrix vrtmp (nvr, nvr);
255 Complex *pvr = vrtmp.fortran_vec ();
256
257 F77_INT nvl = (calc_lev ? n : 0);
258 ComplexMatrix vltmp (nvl, nvl);
259 Complex *pvl = vltmp.fortran_vec ();
260
261 F77_INT lwork = -1;
262 Complex dummy_work;
263
264 F77_INT lrwork = 2*n;
265 Array<double> rwork (dim_vector (lrwork, 1));
266 double *prwork = rwork.fortran_vec ();
267
268 F77_INT ilo;
269 F77_INT ihi;
270
271 Array<double> scale (dim_vector (n, 1));
272 double *pscale = scale.fortran_vec ();
273
274 double abnrm;
275
276 Array<double> rconde (dim_vector (n, 1));
277 double *prconde = rconde.fortran_vec ();
278
279 Array<double> rcondv (dim_vector (n, 1));
280 double *prcondv = rcondv.fortran_vec ();
281
282 F77_XFCN (zgeevx, ZGEEVX, (F77_CONST_CHAR_ARG2 (balance ? "B" : "N", 1),
283 F77_CONST_CHAR_ARG2 (calc_lev ? "V" : "N", 1),
284 F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
285 F77_CONST_CHAR_ARG2 ("N", 1),
286 n, F77_DBLE_CMPLX_ARG (tmp_data), n,
287 F77_DBLE_CMPLX_ARG (pw), F77_DBLE_CMPLX_ARG (pvl),
288 n, F77_DBLE_CMPLX_ARG (pvr), n, ilo, ihi,
289 pscale, abnrm, prconde, prcondv,
290 F77_DBLE_CMPLX_ARG (&dummy_work), lwork, prwork,
291 info
292 F77_CHAR_ARG_LEN (1)
293 F77_CHAR_ARG_LEN (1)
294 F77_CHAR_ARG_LEN (1)
295 F77_CHAR_ARG_LEN (1)));
296
297 if (info != 0)
298 (*current_liboctave_error_handler) ("zgeevx workspace query failed");
299
300 lwork = static_cast<F77_INT> (dummy_work.real ());
301 Array<Complex> work (dim_vector (lwork, 1));
302 Complex *pwork = work.fortran_vec ();
303
304 F77_XFCN (zgeevx, ZGEEVX, (F77_CONST_CHAR_ARG2 (balance ? "B" : "N", 1),
305 F77_CONST_CHAR_ARG2 (calc_lev ? "V" : "N", 1),
306 F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
307 F77_CONST_CHAR_ARG2 ("N", 1),
308 n, F77_DBLE_CMPLX_ARG (tmp_data), n,
309 F77_DBLE_CMPLX_ARG (pw), F77_DBLE_CMPLX_ARG (pvl),
310 n, F77_DBLE_CMPLX_ARG (pvr), n, ilo, ihi,
311 pscale, abnrm, prconde, prcondv,
312 F77_DBLE_CMPLX_ARG (pwork), lwork, prwork, info
313 F77_CHAR_ARG_LEN (1)
314 F77_CHAR_ARG_LEN (1)
315 F77_CHAR_ARG_LEN (1)
316 F77_CHAR_ARG_LEN (1)));
317
318 if (info < 0)
319 (*current_liboctave_error_handler) ("unrecoverable error in zgeevx");
320
321 if (info > 0)
322 (*current_liboctave_error_handler) ("zgeevx failed to converge");
323
324 lambda = wr;
325 v = vrtmp;
326 w = vltmp;
327
328 return info;
329 }
330
331 octave_idx_type
hermitian_init(const ComplexMatrix & a,bool calc_rev,bool calc_lev)332 EIG::hermitian_init (const ComplexMatrix& a, bool calc_rev, bool calc_lev)
333 {
334 F77_INT n = octave::to_f77_int (a.rows ());
335 F77_INT a_nc = octave::to_f77_int (a.cols ());
336
337 if (n != a_nc)
338 (*current_liboctave_error_handler) ("EIG requires square matrix");
339
340 F77_INT info = 0;
341
342 ComplexMatrix atmp = a;
343 Complex *tmp_data = atmp.fortran_vec ();
344
345 ColumnVector wr (n);
346 double *pwr = wr.fortran_vec ();
347
348 F77_INT lwork = -1;
349 Complex dummy_work;
350
351 F77_INT lrwork = 3*n;
352 Array<double> rwork (dim_vector (lrwork, 1));
353 double *prwork = rwork.fortran_vec ();
354
355 F77_XFCN (zheev, ZHEEV, (F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
356 F77_CONST_CHAR_ARG2 ("U", 1),
357 n, F77_DBLE_CMPLX_ARG (tmp_data), n, pwr,
358 F77_DBLE_CMPLX_ARG (&dummy_work), lwork,
359 prwork, info
360 F77_CHAR_ARG_LEN (1)
361 F77_CHAR_ARG_LEN (1)));
362
363 if (info != 0)
364 (*current_liboctave_error_handler) ("zheev workspace query failed");
365
366 lwork = static_cast<F77_INT> (dummy_work.real ());
367 Array<Complex> work (dim_vector (lwork, 1));
368 Complex *pwork = work.fortran_vec ();
369
370 F77_XFCN (zheev, ZHEEV, (F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
371 F77_CONST_CHAR_ARG2 ("U", 1),
372 n, F77_DBLE_CMPLX_ARG (tmp_data), n, pwr,
373 F77_DBLE_CMPLX_ARG (pwork), lwork, prwork, info
374 F77_CHAR_ARG_LEN (1)
375 F77_CHAR_ARG_LEN (1)));
376
377 if (info < 0)
378 (*current_liboctave_error_handler) ("unrecoverable error in zheev");
379
380 if (info > 0)
381 (*current_liboctave_error_handler) ("zheev failed to converge");
382
383 lambda = ComplexColumnVector (wr);
384 v = (calc_rev ? ComplexMatrix (atmp) : ComplexMatrix ());
385 w = (calc_lev ? ComplexMatrix (atmp) : ComplexMatrix ());
386
387 return info;
388 }
389
390 octave_idx_type
init(const Matrix & a,const Matrix & b,bool calc_rev,bool calc_lev,bool force_qz)391 EIG::init (const Matrix& a, const Matrix& b, bool calc_rev, bool calc_lev,
392 bool force_qz)
393 {
394 if (a.any_element_is_inf_or_nan () || b.any_element_is_inf_or_nan ())
395 (*current_liboctave_error_handler)
396 ("EIG: matrix contains Inf or NaN values");
397
398 F77_INT n = octave::to_f77_int (a.rows ());
399 F77_INT nb = octave::to_f77_int (b.rows ());
400
401 F77_INT a_nc = octave::to_f77_int (a.cols ());
402 F77_INT b_nc = octave::to_f77_int (b.cols ());
403
404 if (n != a_nc || nb != b_nc)
405 (*current_liboctave_error_handler) ("EIG requires square matrix");
406
407 if (n != nb)
408 (*current_liboctave_error_handler) ("EIG requires same size matrices");
409
410 F77_INT info = 0;
411
412 Matrix tmp = b;
413 double *tmp_data = tmp.fortran_vec ();
414
415 if (! force_qz)
416 {
417 F77_XFCN (dpotrf, DPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1),
418 n, tmp_data, n,
419 info
420 F77_CHAR_ARG_LEN (1)));
421
422 if (a.issymmetric () && b.issymmetric () && info == 0)
423 return symmetric_init (a, b, calc_rev, calc_lev);
424 }
425
426 Matrix atmp = a;
427 double *atmp_data = atmp.fortran_vec ();
428
429 Matrix btmp = b;
430 double *btmp_data = btmp.fortran_vec ();
431
432 Array<double> ar (dim_vector (n, 1));
433 double *par = ar.fortran_vec ();
434
435 Array<double> ai (dim_vector (n, 1));
436 double *pai = ai.fortran_vec ();
437
438 Array<double> beta (dim_vector (n, 1));
439 double *pbeta = beta.fortran_vec ();
440
441 F77_INT tnvr = (calc_rev ? n : 0);
442 Matrix vr (tnvr, tnvr);
443 double *pvr = vr.fortran_vec ();
444
445 F77_INT tnvl = (calc_lev ? n : 0);
446 Matrix vl (tnvl, tnvl);
447 double *pvl = vl.fortran_vec ();
448
449 F77_INT lwork = -1;
450 double dummy_work;
451
452 F77_XFCN (dggev, DGGEV, (F77_CONST_CHAR_ARG2 (calc_lev ? "V" : "N", 1),
453 F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
454 n, atmp_data, n, btmp_data, n,
455 par, pai, pbeta,
456 pvl, n, pvr, n,
457 &dummy_work, lwork, info
458 F77_CHAR_ARG_LEN (1)
459 F77_CHAR_ARG_LEN (1)));
460
461 if (info != 0)
462 (*current_liboctave_error_handler) ("dggev workspace query failed");
463
464 lwork = static_cast<F77_INT> (dummy_work);
465 Array<double> work (dim_vector (lwork, 1));
466 double *pwork = work.fortran_vec ();
467
468 F77_XFCN (dggev, DGGEV, (F77_CONST_CHAR_ARG2 (calc_lev ? "V" : "N", 1),
469 F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
470 n, atmp_data, n, btmp_data, n,
471 par, pai, pbeta,
472 pvl, n, pvr, n,
473 pwork, lwork, info
474 F77_CHAR_ARG_LEN (1)
475 F77_CHAR_ARG_LEN (1)));
476
477 if (info < 0)
478 (*current_liboctave_error_handler) ("unrecoverable error in dggev");
479
480 if (info > 0)
481 (*current_liboctave_error_handler) ("dggev failed to converge");
482
483 lambda.resize (n);
484 F77_INT nvr = (calc_rev ? n : 0);
485 v.resize (nvr, nvr);
486
487 F77_INT nvl = (calc_lev ? n : 0);
488 w.resize (nvl, nvl);
489
490 for (F77_INT j = 0; j < n; j++)
491 {
492 if (ai.elem (j) == 0.0)
493 {
494 lambda.elem (j) = Complex (ar.elem (j) / beta.elem (j));
495 for (F77_INT i = 0; i < nvr; i++)
496 v.elem (i, j) = vr.elem (i, j);
497 for (F77_INT i = 0; i < nvl; i++)
498 w.elem (i, j) = vl.elem (i, j);
499 }
500 else
501 {
502 if (j+1 >= n)
503 (*current_liboctave_error_handler) ("EIG: internal error");
504
505 lambda.elem (j) = Complex (ar.elem (j) / beta.elem (j),
506 ai.elem (j) / beta.elem (j));
507 lambda.elem (j+1) = Complex (ar.elem (j+1) / beta.elem (j+1),
508 ai.elem (j+1) / beta.elem (j+1));
509
510 for (F77_INT i = 0; i < nvr; i++)
511 {
512 double real_part = vr.elem (i, j);
513 double imag_part = vr.elem (i, j+1);
514 v.elem (i, j) = Complex (real_part, imag_part);
515 v.elem (i, j+1) = Complex (real_part, -imag_part);
516 }
517 for (F77_INT i = 0; i < nvl; i++)
518 {
519 double real_part = vl.elem (i, j);
520 double imag_part = vl.elem (i, j+1);
521 w.elem (i, j) = Complex (real_part, imag_part);
522 w.elem (i, j+1) = Complex (real_part, -imag_part);
523 }
524 j++;
525 }
526 }
527
528 return info;
529 }
530
531 octave_idx_type
symmetric_init(const Matrix & a,const Matrix & b,bool calc_rev,bool calc_lev)532 EIG::symmetric_init (const Matrix& a, const Matrix& b, bool calc_rev,
533 bool calc_lev)
534 {
535 F77_INT n = octave::to_f77_int (a.rows ());
536 F77_INT nb = octave::to_f77_int (b.rows ());
537
538 F77_INT a_nc = octave::to_f77_int (a.cols ());
539 F77_INT b_nc = octave::to_f77_int (b.cols ());
540
541 if (n != a_nc || nb != b_nc)
542 (*current_liboctave_error_handler) ("EIG requires square matrix");
543
544 if (n != nb)
545 (*current_liboctave_error_handler) ("EIG requires same size matrices");
546
547 F77_INT info = 0;
548
549 Matrix atmp = a;
550 double *atmp_data = atmp.fortran_vec ();
551
552 Matrix btmp = b;
553 double *btmp_data = btmp.fortran_vec ();
554
555 ColumnVector wr (n);
556 double *pwr = wr.fortran_vec ();
557
558 F77_INT lwork = -1;
559 double dummy_work;
560
561 F77_XFCN (dsygv, DSYGV, (1, F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
562 F77_CONST_CHAR_ARG2 ("U", 1),
563 n, atmp_data, n,
564 btmp_data, n,
565 pwr, &dummy_work, lwork, info
566 F77_CHAR_ARG_LEN (1)
567 F77_CHAR_ARG_LEN (1)));
568
569 if (info != 0)
570 (*current_liboctave_error_handler) ("dsygv workspace query failed");
571
572 lwork = static_cast<F77_INT> (dummy_work);
573 Array<double> work (dim_vector (lwork, 1));
574 double *pwork = work.fortran_vec ();
575
576 F77_XFCN (dsygv, DSYGV, (1, F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
577 F77_CONST_CHAR_ARG2 ("U", 1),
578 n, atmp_data, n,
579 btmp_data, n,
580 pwr, pwork, lwork, info
581 F77_CHAR_ARG_LEN (1)
582 F77_CHAR_ARG_LEN (1)));
583
584 if (info < 0)
585 (*current_liboctave_error_handler) ("unrecoverable error in dsygv");
586
587 if (info > 0)
588 (*current_liboctave_error_handler) ("dsygv failed to converge");
589
590 lambda = ComplexColumnVector (wr);
591 v = (calc_rev ? ComplexMatrix (atmp) : ComplexMatrix ());
592 w = (calc_lev ? ComplexMatrix (atmp) : ComplexMatrix ());
593
594 return info;
595 }
596
597 octave_idx_type
init(const ComplexMatrix & a,const ComplexMatrix & b,bool calc_rev,bool calc_lev,bool force_qz)598 EIG::init (const ComplexMatrix& a, const ComplexMatrix& b, bool calc_rev,
599 bool calc_lev, bool force_qz)
600 {
601 if (a.any_element_is_inf_or_nan () || b.any_element_is_inf_or_nan ())
602 (*current_liboctave_error_handler)
603 ("EIG: matrix contains Inf or NaN values");
604
605 F77_INT n = octave::to_f77_int (a.rows ());
606 F77_INT nb = octave::to_f77_int (b.rows ());
607
608 F77_INT a_nc = octave::to_f77_int (a.cols ());
609 F77_INT b_nc = octave::to_f77_int (b.cols ());
610
611 if (n != a_nc || nb != b_nc)
612 (*current_liboctave_error_handler) ("EIG requires square matrix");
613
614 if (n != nb)
615 (*current_liboctave_error_handler) ("EIG requires same size matrices");
616
617 F77_INT info = 0;
618
619 ComplexMatrix tmp = b;
620 Complex*tmp_data = tmp.fortran_vec ();
621
622 if (! force_qz)
623 {
624 F77_XFCN (zpotrf, ZPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1),
625 n, F77_DBLE_CMPLX_ARG (tmp_data), n,
626 info
627 F77_CHAR_ARG_LEN (1)));
628
629 if (a.ishermitian () && b.ishermitian () && info == 0)
630 return hermitian_init (a, b, calc_rev, calc_lev);
631 }
632
633 ComplexMatrix atmp = a;
634 Complex *atmp_data = atmp.fortran_vec ();
635
636 ComplexMatrix btmp = b;
637 Complex *btmp_data = btmp.fortran_vec ();
638
639 ComplexColumnVector alpha (n);
640 Complex *palpha = alpha.fortran_vec ();
641
642 ComplexColumnVector beta (n);
643 Complex *pbeta = beta.fortran_vec ();
644
645 F77_INT nvr = (calc_rev ? n : 0);
646 ComplexMatrix vrtmp (nvr, nvr);
647 Complex *pvr = vrtmp.fortran_vec ();
648
649 F77_INT nvl = (calc_lev ? n : 0);
650 ComplexMatrix vltmp (nvl, nvl);
651 Complex *pvl = vltmp.fortran_vec ();
652
653 F77_INT lwork = -1;
654 Complex dummy_work;
655
656 F77_INT lrwork = 8*n;
657 Array<double> rwork (dim_vector (lrwork, 1));
658 double *prwork = rwork.fortran_vec ();
659
660 F77_XFCN (zggev, ZGGEV, (F77_CONST_CHAR_ARG2 (calc_lev ? "V" : "N", 1),
661 F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
662 n, F77_DBLE_CMPLX_ARG (atmp_data), n,
663 F77_DBLE_CMPLX_ARG (btmp_data), n,
664 F77_DBLE_CMPLX_ARG (palpha),
665 F77_DBLE_CMPLX_ARG (pbeta),
666 F77_DBLE_CMPLX_ARG (pvl), n,
667 F77_DBLE_CMPLX_ARG (pvr), n,
668 F77_DBLE_CMPLX_ARG (&dummy_work), lwork, prwork,
669 info
670 F77_CHAR_ARG_LEN (1)
671 F77_CHAR_ARG_LEN (1)));
672
673 if (info != 0)
674 (*current_liboctave_error_handler) ("zggev workspace query failed");
675
676 lwork = static_cast<F77_INT> (dummy_work.real ());
677 Array<Complex> work (dim_vector (lwork, 1));
678 Complex *pwork = work.fortran_vec ();
679
680 F77_XFCN (zggev, ZGGEV, (F77_CONST_CHAR_ARG2 (calc_lev ? "V" : "N", 1),
681 F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
682 n, F77_DBLE_CMPLX_ARG (atmp_data), n,
683 F77_DBLE_CMPLX_ARG (btmp_data), n,
684 F77_DBLE_CMPLX_ARG (palpha),
685 F77_DBLE_CMPLX_ARG (pbeta),
686 F77_DBLE_CMPLX_ARG (pvl), n,
687 F77_DBLE_CMPLX_ARG (pvr), n,
688 F77_DBLE_CMPLX_ARG (pwork), lwork, prwork, info
689 F77_CHAR_ARG_LEN (1)
690 F77_CHAR_ARG_LEN (1)));
691
692 if (info < 0)
693 (*current_liboctave_error_handler) ("unrecoverable error in zggev");
694
695 if (info > 0)
696 (*current_liboctave_error_handler) ("zggev failed to converge");
697
698 lambda.resize (n);
699
700 for (F77_INT j = 0; j < n; j++)
701 lambda.elem (j) = alpha.elem (j) / beta.elem (j);
702
703 v = vrtmp;
704 w = vltmp;
705
706 return info;
707 }
708
709 octave_idx_type
hermitian_init(const ComplexMatrix & a,const ComplexMatrix & b,bool calc_rev,bool calc_lev)710 EIG::hermitian_init (const ComplexMatrix& a, const ComplexMatrix& b,
711 bool calc_rev, bool calc_lev)
712 {
713 F77_INT n = octave::to_f77_int (a.rows ());
714 F77_INT nb = octave::to_f77_int (b.rows ());
715
716 F77_INT a_nc = octave::to_f77_int (a.cols ());
717 F77_INT b_nc = octave::to_f77_int (b.cols ());
718
719 if (n != a_nc || nb != b_nc)
720 (*current_liboctave_error_handler) ("EIG requires square matrix");
721
722 if (n != nb)
723 (*current_liboctave_error_handler) ("EIG requires same size matrices");
724
725 F77_INT info = 0;
726
727 ComplexMatrix atmp = a;
728 Complex *atmp_data = atmp.fortran_vec ();
729
730 ComplexMatrix btmp = b;
731 Complex *btmp_data = btmp.fortran_vec ();
732
733 ColumnVector wr (n);
734 double *pwr = wr.fortran_vec ();
735
736 F77_INT lwork = -1;
737 Complex dummy_work;
738
739 F77_INT lrwork = 3*n;
740 Array<double> rwork (dim_vector (lrwork, 1));
741 double *prwork = rwork.fortran_vec ();
742
743 F77_XFCN (zhegv, ZHEGV, (1, F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
744 F77_CONST_CHAR_ARG2 ("U", 1),
745 n, F77_DBLE_CMPLX_ARG (atmp_data), n,
746 F77_DBLE_CMPLX_ARG (btmp_data), n,
747 pwr, F77_DBLE_CMPLX_ARG (&dummy_work), lwork,
748 prwork, info
749 F77_CHAR_ARG_LEN (1)
750 F77_CHAR_ARG_LEN (1)));
751
752 if (info != 0)
753 (*current_liboctave_error_handler) ("zhegv workspace query failed");
754
755 lwork = static_cast<F77_INT> (dummy_work.real ());
756 Array<Complex> work (dim_vector (lwork, 1));
757 Complex *pwork = work.fortran_vec ();
758
759 F77_XFCN (zhegv, ZHEGV, (1, F77_CONST_CHAR_ARG2 (calc_rev ? "V" : "N", 1),
760 F77_CONST_CHAR_ARG2 ("U", 1),
761 n, F77_DBLE_CMPLX_ARG (atmp_data), n,
762 F77_DBLE_CMPLX_ARG (btmp_data), n,
763 pwr, F77_DBLE_CMPLX_ARG (pwork), lwork, prwork, info
764 F77_CHAR_ARG_LEN (1)
765 F77_CHAR_ARG_LEN (1)));
766
767 if (info < 0)
768 (*current_liboctave_error_handler) ("unrecoverable error in zhegv");
769
770 if (info > 0)
771 (*current_liboctave_error_handler) ("zhegv failed to converge");
772
773 lambda = ComplexColumnVector (wr);
774 v = (calc_rev ? ComplexMatrix (atmp) : ComplexMatrix ());
775 w = (calc_lev ? ComplexMatrix (atmp) : ComplexMatrix ());
776
777 return info;
778 }
779