1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17 
18 
19 
20 //! \addtogroup op_pinv
21 //! @{
22 
23 
24 
25 template<typename T1>
26 inline
27 void
apply(Mat<typename T1::elem_type> & out,const Op<T1,op_pinv> & in)28 op_pinv::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_pinv>& in)
29   {
30   arma_extra_debug_sigprint();
31 
32   typedef typename T1::pod_type T;
33 
34   const T     tol       = access::tmp_real(in.aux);
35   const uword method_id = in.aux_uword_a;
36 
37   const bool status = op_pinv::apply_direct(out, in.m, tol, method_id);
38 
39   if(status == false)
40     {
41     out.soft_reset();
42     arma_stop_runtime_error("pinv(): svd failed");
43     }
44   }
45 
46 
47 
48 template<typename T1>
49 inline
50 bool
apply_direct(Mat<typename T1::elem_type> & out,const Base<typename T1::elem_type,T1> & expr,typename T1::pod_type tol,const uword method_id)51 op_pinv::apply_direct(Mat<typename T1::elem_type>& out, const Base<typename T1::elem_type,T1>& expr, typename T1::pod_type tol, const uword method_id)
52   {
53   arma_extra_debug_sigprint();
54 
55   typedef typename T1::elem_type eT;
56   typedef typename T1::pod_type   T;
57 
58   arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0");
59 
60   // method_id = 0 -> default setting
61   // method_id = 1 -> use standard algorithm
62   // method_id = 2 -> use divide and conquer algorithm
63 
64   Mat<eT> A(expr.get_ref());
65 
66   const uword n_rows = A.n_rows;
67   const uword n_cols = A.n_cols;
68 
69   if(A.is_empty())  { out.set_size(n_cols,n_rows); return true; }
70 
71   #if defined(ARMA_OPTIMISE_SYMPD)
72     const bool try_sympd = (auxlib::crippled_lapack(A) == false) && (tol == T(0)) && (method_id == uword(0)) && sympd_helper::guess_sympd_anysize(A);
73   #else
74     const bool try_sympd = false;
75   #endif
76 
77   if(try_sympd)
78     {
79     arma_extra_debug_print("op_pinv: attempting sympd optimisation");
80 
81     out = A;
82 
83     const T rcond_threshold = T((std::max)(uword(100), uword(A.n_rows))) * std::numeric_limits<T>::epsilon();
84 
85     const bool status = auxlib::inv_sympd_rcond(out, rcond_threshold);
86 
87     if(status)  { return true; }
88 
89     arma_extra_debug_print("op_pinv: sympd optimisation failed");
90     // auxlib::inv_sympd_rcond() will fail if A isn't really positive definite or its rcond is below rcond_threshold
91     }
92 
93   // economical SVD decomposition
94   Mat<eT> U;
95   Col< T> s;
96   Mat<eT> V;
97 
98   if(n_cols > n_rows)  { A = trans(A); }
99 
100   const bool status = ((method_id == uword(0)) || (method_id == uword(2))) ? auxlib::svd_dc_econ(U, s, V, A) : auxlib::svd_econ(U, s, V, A, 'b');
101 
102   if(status == false)  { return false; }
103 
104   const uword s_n_elem = s.n_elem;
105   const T*    s_mem    = s.memptr();
106 
107   // set tolerance to default if it hasn't been specified
108   if( (tol == T(0)) && (s_n_elem > 0) )
109     {
110     tol = (std::max)(n_rows, n_cols) * s_mem[0] * std::numeric_limits<T>::epsilon();
111     }
112 
113 
114   uword count = 0;
115 
116   for(uword i = 0; i < s_n_elem; ++i)  { count += (s_mem[i] >= tol) ? uword(1) : uword(0); }
117 
118   if(count == 0)  { out.zeros(n_cols, n_rows); return true; }
119 
120   Col<T> s2(count, arma_nozeros_indicator());
121 
122   T* s2_mem = s2.memptr();
123 
124   uword count2 = 0;
125 
126   for(uword i=0; i < s_n_elem; ++i)
127     {
128     const T val = s_mem[i];
129 
130     if(val >= tol)  { s2_mem[count2] = (val > T(0)) ? T(T(1) / val) : T(0); ++count2; }
131     }
132 
133 
134   Mat<eT> tmp;
135 
136   if(n_rows >= n_cols)
137     {
138     // out = ( (V.n_cols > count) ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( (U.n_cols > count) ? U.cols(0,count-1) : U );
139 
140     if(count < V.n_cols)
141       {
142       tmp = V.cols(0,count-1) * diagmat(s2);
143       }
144     else
145       {
146       tmp = V * diagmat(s2);
147       }
148 
149     if(count < U.n_cols)
150       {
151       out = tmp * trans(U.cols(0,count-1));
152       }
153     else
154       {
155       out = tmp * trans(U);
156       }
157     }
158   else
159     {
160     // out = ( (U.n_cols > count) ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( (V.n_cols > count) ? V.cols(0,count-1) : V );
161 
162     if(count < U.n_cols)
163       {
164       tmp = U.cols(0,count-1) * diagmat(s2);
165       }
166     else
167       {
168       tmp = U * diagmat(s2);
169       }
170 
171     if(count < V.n_cols)
172       {
173       out = tmp * trans(V.cols(0,count-1));
174       }
175     else
176       {
177       out = tmp * trans(V);
178       }
179     }
180 
181   return true;
182   }
183 
184 
185 
186 //! @}
187