1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17
18
19
20 //! \addtogroup op_pinv
21 //! @{
22
23
24
25 template<typename T1>
26 inline
27 void
apply(Mat<typename T1::elem_type> & out,const Op<T1,op_pinv> & in)28 op_pinv::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_pinv>& in)
29 {
30 arma_extra_debug_sigprint();
31
32 typedef typename T1::pod_type T;
33
34 const T tol = access::tmp_real(in.aux);
35 const uword method_id = in.aux_uword_a;
36
37 const bool status = op_pinv::apply_direct(out, in.m, tol, method_id);
38
39 if(status == false)
40 {
41 out.soft_reset();
42 arma_stop_runtime_error("pinv(): svd failed");
43 }
44 }
45
46
47
48 template<typename T1>
49 inline
50 bool
apply_direct(Mat<typename T1::elem_type> & out,const Base<typename T1::elem_type,T1> & expr,typename T1::pod_type tol,const uword method_id)51 op_pinv::apply_direct(Mat<typename T1::elem_type>& out, const Base<typename T1::elem_type,T1>& expr, typename T1::pod_type tol, const uword method_id)
52 {
53 arma_extra_debug_sigprint();
54
55 typedef typename T1::elem_type eT;
56 typedef typename T1::pod_type T;
57
58 arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0");
59
60 // method_id = 0 -> default setting
61 // method_id = 1 -> use standard algorithm
62 // method_id = 2 -> use divide and conquer algorithm
63
64 Mat<eT> A(expr.get_ref());
65
66 const uword n_rows = A.n_rows;
67 const uword n_cols = A.n_cols;
68
69 if(A.is_empty()) { out.set_size(n_cols,n_rows); return true; }
70
71 #if defined(ARMA_OPTIMISE_SYMPD)
72 const bool try_sympd = (auxlib::crippled_lapack(A) == false) && (tol == T(0)) && (method_id == uword(0)) && sympd_helper::guess_sympd_anysize(A);
73 #else
74 const bool try_sympd = false;
75 #endif
76
77 if(try_sympd)
78 {
79 arma_extra_debug_print("op_pinv: attempting sympd optimisation");
80
81 out = A;
82
83 const T rcond_threshold = T((std::max)(uword(100), uword(A.n_rows))) * std::numeric_limits<T>::epsilon();
84
85 const bool status = auxlib::inv_sympd_rcond(out, rcond_threshold);
86
87 if(status) { return true; }
88
89 arma_extra_debug_print("op_pinv: sympd optimisation failed");
90 // auxlib::inv_sympd_rcond() will fail if A isn't really positive definite or its rcond is below rcond_threshold
91 }
92
93 // economical SVD decomposition
94 Mat<eT> U;
95 Col< T> s;
96 Mat<eT> V;
97
98 if(n_cols > n_rows) { A = trans(A); }
99
100 const bool status = ((method_id == uword(0)) || (method_id == uword(2))) ? auxlib::svd_dc_econ(U, s, V, A) : auxlib::svd_econ(U, s, V, A, 'b');
101
102 if(status == false) { return false; }
103
104 const uword s_n_elem = s.n_elem;
105 const T* s_mem = s.memptr();
106
107 // set tolerance to default if it hasn't been specified
108 if( (tol == T(0)) && (s_n_elem > 0) )
109 {
110 tol = (std::max)(n_rows, n_cols) * s_mem[0] * std::numeric_limits<T>::epsilon();
111 }
112
113
114 uword count = 0;
115
116 for(uword i = 0; i < s_n_elem; ++i) { count += (s_mem[i] >= tol) ? uword(1) : uword(0); }
117
118 if(count == 0) { out.zeros(n_cols, n_rows); return true; }
119
120 Col<T> s2(count, arma_nozeros_indicator());
121
122 T* s2_mem = s2.memptr();
123
124 uword count2 = 0;
125
126 for(uword i=0; i < s_n_elem; ++i)
127 {
128 const T val = s_mem[i];
129
130 if(val >= tol) { s2_mem[count2] = (val > T(0)) ? T(T(1) / val) : T(0); ++count2; }
131 }
132
133
134 Mat<eT> tmp;
135
136 if(n_rows >= n_cols)
137 {
138 // out = ( (V.n_cols > count) ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( (U.n_cols > count) ? U.cols(0,count-1) : U );
139
140 if(count < V.n_cols)
141 {
142 tmp = V.cols(0,count-1) * diagmat(s2);
143 }
144 else
145 {
146 tmp = V * diagmat(s2);
147 }
148
149 if(count < U.n_cols)
150 {
151 out = tmp * trans(U.cols(0,count-1));
152 }
153 else
154 {
155 out = tmp * trans(U);
156 }
157 }
158 else
159 {
160 // out = ( (U.n_cols > count) ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( (V.n_cols > count) ? V.cols(0,count-1) : V );
161
162 if(count < U.n_cols)
163 {
164 tmp = U.cols(0,count-1) * diagmat(s2);
165 }
166 else
167 {
168 tmp = U * diagmat(s2);
169 }
170
171 if(count < V.n_cols)
172 {
173 out = tmp * trans(V.cols(0,count-1));
174 }
175 else
176 {
177 out = tmp * trans(V);
178 }
179 }
180
181 return true;
182 }
183
184
185
186 //! @}
187