1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17 
18 
19 namespace newarp
20 {
21 
22 
23 template<typename eT>
24 inline
TridiagEigen()25 TridiagEigen<eT>::TridiagEigen()
26   : n(0)
27   , computed(false)
28   {
29   arma_extra_debug_sigprint();
30   }
31 
32 
33 
34 template<typename eT>
35 inline
TridiagEigen(const Mat<eT> & mat_obj)36 TridiagEigen<eT>::TridiagEigen(const Mat<eT>& mat_obj)
37   : n(mat_obj.n_rows)
38   , computed(false)
39   {
40   arma_extra_debug_sigprint();
41 
42   compute(mat_obj);
43   }
44 
45 
46 
47 template<typename eT>
48 inline
49 void
compute(const Mat<eT> & mat_obj)50 TridiagEigen<eT>::compute(const Mat<eT>& mat_obj)
51   {
52   arma_extra_debug_sigprint();
53 
54   arma_debug_check( (mat_obj.is_square() == false), "newarp::TridiagEigen::compute(): matrix must be square" );
55 
56   n = blas_int(mat_obj.n_rows);
57 
58   main_diag = mat_obj.diag();
59   sub_diag  = mat_obj.diag(-1);
60 
61   evecs.set_size(n, n);
62 
63   char     compz      = 'I';
64   blas_int lwork      = blas_int(-1);
65   eT       lwork_opt  = eT(0);
66 
67   blas_int liwork     = blas_int(-1);
68   blas_int liwork_opt = blas_int(0);
69   blas_int info       = blas_int(0);
70 
71   // query for lwork and liwork
72   lapack::stedc(&compz, &n, main_diag.memptr(), sub_diag.memptr(), evecs.memptr(), &n, &lwork_opt, &lwork, &liwork_opt, &liwork, &info);
73 
74   if(info == 0)
75     {
76     lwork  = blas_int(lwork_opt);
77     liwork = liwork_opt;
78     }
79   else
80     {
81     lwork  = 1 + 4 * n + n * n;
82     liwork = 3 + 5 * n;
83     }
84 
85   info = blas_int(0);
86 
87   podarray<eT>        work(static_cast<uword>(lwork) );
88   podarray<blas_int> iwork(static_cast<uword>(liwork));
89 
90   lapack::stedc(&compz, &n, main_diag.memptr(), sub_diag.memptr(), evecs.memptr(), &n, work.memptr(), &lwork, iwork.memptr(), &liwork, &info);
91 
92   if(info < 0)  { arma_stop_logic_error("lapack::stedc(): illegal value"); return; }
93 
94   if(info > 0)  { arma_stop_runtime_error("lapack::stedc(): failed to compute all eigenvalues"); return; }
95 
96   computed = true;
97   }
98 
99 
100 
101 template<typename eT>
102 inline
103 Col<eT>
eigenvalues()104 TridiagEigen<eT>::eigenvalues()
105   {
106   arma_extra_debug_sigprint();
107 
108   arma_debug_check( (computed == false), "newarp::TridiagEigen::eigenvalues(): need to call compute() first" );
109 
110   // After calling compute(), main_diag will contain the eigenvalues.
111   return main_diag;
112   }
113 
114 
115 
116 template<typename eT>
117 inline
118 Mat<eT>
eigenvectors()119 TridiagEigen<eT>::eigenvectors()
120   {
121   arma_extra_debug_sigprint();
122 
123   arma_debug_check( (computed == false), "newarp::TridiagEigen::eigenvectors(): need to call compute() first" );
124 
125   return evecs;
126   }
127 
128 
129 }  // namespace newarp
130