1 // Copyright (C) 2009  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 // This code was adapted from code from the JAMA part of NIST's TNT library.
4 //    See: http://math.nist.gov/tnt/
5 #ifndef DLIB_MATRIX_LU_DECOMPOSITION_H
6 #define DLIB_MATRIX_LU_DECOMPOSITION_H
7 
8 #include "matrix.h"
9 #include "matrix_utilities.h"
10 #include "matrix_subexp.h"
11 #include "matrix_trsm.h"
12 #include <algorithm>
13 
14 #ifdef DLIB_USE_LAPACK
15 #include "lapack/getrf.h"
16 #endif
17 
18 
19 namespace dlib
20 {
21 
22     template <
23         typename matrix_exp_type
24         >
25     class lu_decomposition
26     {
27     public:
28 
29         const static long NR = matrix_exp_type::NR;
30         const static long NC = matrix_exp_type::NC;
31         typedef typename matrix_exp_type::type type;
32         typedef typename matrix_exp_type::mem_manager_type mem_manager_type;
33         typedef typename matrix_exp_type::layout_type layout_type;
34 
35         typedef matrix<type,0,0,mem_manager_type,layout_type>  matrix_type;
36         typedef matrix<type,NR,1,mem_manager_type,layout_type> column_vector_type;
37         typedef matrix<long,NR,1,mem_manager_type,layout_type> pivot_column_vector_type;
38 
39         // You have supplied an invalid type of matrix_exp_type.  You have
40         // to use this object with matrices that contain float or double type data.
41         COMPILE_TIME_ASSERT((is_same_type<float, type>::value ||
42                              is_same_type<double, type>::value ));
43 
44         template <typename EXP>
45         lu_decomposition (
46             const matrix_exp<EXP> &A
47         );
48 
49         bool is_square (
50         ) const;
51 
52         bool is_singular (
53         ) const;
54 
55         long nr(
56         ) const;
57 
58         long nc(
59         ) const;
60 
61         const matrix_type get_l (
62         ) const;
63 
64         const matrix_type get_u (
65         ) const;
66 
67         const pivot_column_vector_type& get_pivot (
68         ) const;
69 
70         type det (
71         ) const;
72 
73         template <typename EXP>
74         const matrix_type solve (
75             const matrix_exp<EXP> &B
76         ) const;
77 
78     private:
79 
80         /* Array for internal storage of decomposition.  */
81         matrix<type,0,0,mem_manager_type,column_major_layout>  LU;
82         long m, n, pivsign;
83         pivot_column_vector_type piv;
84 
85 
86     };
87 
88 // ----------------------------------------------------------------------------------------
89 // ----------------------------------------------------------------------------------------
90 //                              Public member functions
91 // ----------------------------------------------------------------------------------------
92 // ----------------------------------------------------------------------------------------
93 
94     template <typename matrix_exp_type>
95     template <typename EXP>
96     lu_decomposition<matrix_exp_type>::
lu_decomposition(const matrix_exp<EXP> & A)97     lu_decomposition (
98         const matrix_exp<EXP>& A
99     ) :
100         LU(A),
101         m(A.nr()),
102         n(A.nc())
103     {
104         using namespace std;
105         using std::abs;
106 
107         COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
108 
109         // make sure requires clause is not broken
110         DLIB_ASSERT(A.size() > 0,
111             "\tlu_decomposition::lu_decomposition(A)"
112             << "\n\tInvalid inputs were given to this function"
113             << "\n\tA.size(): " << A.size()
114             << "\n\tthis:     " << this
115             );
116 
117 #ifdef DLIB_USE_LAPACK
118         matrix<lapack::integer,0,1,mem_manager_type,layout_type> piv_temp;
119         lapack::getrf(LU, piv_temp);
120 
121         pivsign = 1;
122 
123         // Turn the piv_temp vector into a more useful form.  This way we will have the identity
124         // rowm(A,piv) == L*U.  The permutation vector that comes out of LAPACK is somewhat
125         // different.
126         piv = trans(range(0,m-1));
127         for (long i = 0; i < piv_temp.size(); ++i)
128         {
129             // -1 because FORTRAN is indexed starting with 1 instead of 0
130             if (piv(piv_temp(i)-1) != piv(i))
131             {
132                 std::swap(piv(i), piv(piv_temp(i)-1));
133                 pivsign = -pivsign;
134             }
135         }
136 
137 #else
138 
139         // Use a "left-looking", dot-product, Crout/Doolittle algorithm.
140 
141 
142         piv = trans(range(0,m-1));
143         pivsign = 1;
144 
145         column_vector_type LUcolj(m);
146 
147         // Outer loop.
148         for (long j = 0; j < n; j++)
149         {
150 
151             // Make a copy of the j-th column to localize references.
152             LUcolj = colm(LU,j);
153 
154             // Apply previous transformations.
155             for (long i = 0; i < m; i++)
156             {
157                 // Most of the time is spent in the following dot product.
158                 const long kmax = std::min(i,j);
159                 type s;
160                 if (kmax > 0)
161                     s = rowm(LU,i, kmax)*colm(LUcolj,0,kmax);
162                 else
163                     s = 0;
164 
165                 LU(i,j) = LUcolj(i) -= s;
166             }
167 
168             // Find pivot and exchange if necessary.
169             long p = j;
170             for (long i = j+1; i < m; i++)
171             {
172                 if (abs(LUcolj(i)) > abs(LUcolj(p)))
173                 {
174                     p = i;
175                 }
176             }
177             if (p != j)
178             {
179                 long k=0;
180                 for (k = 0; k < n; k++)
181                 {
182                     type t = LU(p,k);
183                     LU(p,k) = LU(j,k);
184                     LU(j,k) = t;
185                 }
186                 k = piv(p);
187                 piv(p) = piv(j);
188                 piv(j) = k;
189                 pivsign = -pivsign;
190             }
191 
192             // Compute multipliers.
193             if ((j < m) && (LU(j,j) != 0.0))
194             {
195                 for (long i = j+1; i < m; i++)
196                 {
197                     LU(i,j) /= LU(j,j);
198                 }
199             }
200         }
201 
202 #endif
203     }
204 
205 // ----------------------------------------------------------------------------------------
206 
207     template <typename matrix_exp_type>
208     bool lu_decomposition<matrix_exp_type>::
is_square()209     is_square (
210     ) const
211     {
212         return m == n;
213     }
214 
215 // ----------------------------------------------------------------------------------------
216 
217     template <typename matrix_exp_type>
218     long lu_decomposition<matrix_exp_type>::
nr()219     nr (
220     ) const
221     {
222         return m;
223     }
224 
225 // ----------------------------------------------------------------------------------------
226 
227     template <typename matrix_exp_type>
228     long lu_decomposition<matrix_exp_type>::
nc()229     nc (
230     ) const
231     {
232         return n;
233     }
234 
235 // ----------------------------------------------------------------------------------------
236 
237     template <typename matrix_exp_type>
238     bool lu_decomposition<matrix_exp_type>::
is_singular()239     is_singular (
240     ) const
241     {
242         /* Is the matrix singular?
243           if upper triangular factor U (and hence A) is singular, false otherwise.
244         */
245         // make sure requires clause is not broken
246         DLIB_ASSERT(is_square() == true,
247             "\tbool lu_decomposition::is_singular()"
248             << "\n\tYou can only use this on square matrices"
249             << "\n\tthis: " << this
250             );
251 
252         type max_val, min_val;
253         find_min_and_max (abs(diag(LU)), min_val, max_val);
254         type eps = max_val;
255         if (eps != 0)
256             eps *= std::sqrt(std::numeric_limits<type>::epsilon())/10;
257         else
258             eps = 1;  // there is no max so just use 1
259 
260         return min_val < eps;
261     }
262 
263 // ----------------------------------------------------------------------------------------
264 
265     template <typename matrix_exp_type>
266     const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
get_l()267     get_l (
268     ) const
269     {
270         if (LU.nr() >= LU.nc())
271             return lowerm(LU,1.0);
272         else
273             return lowerm(subm(LU,0,0,m,m), 1.0);
274     }
275 
276 // ----------------------------------------------------------------------------------------
277 
278     template <typename matrix_exp_type>
279     const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
get_u()280     get_u (
281     ) const
282     {
283         if (LU.nr() >= LU.nc())
284             return upperm(subm(LU,0,0,n,n));
285         else
286             return upperm(LU);
287     }
288 
289 // ----------------------------------------------------------------------------------------
290 
291     template <typename matrix_exp_type>
292     const typename lu_decomposition<matrix_exp_type>::pivot_column_vector_type& lu_decomposition<matrix_exp_type>::
get_pivot()293     get_pivot (
294     ) const
295     {
296         return piv;
297     }
298 
299 // ----------------------------------------------------------------------------------------
300 
301     template <typename matrix_exp_type>
302     typename lu_decomposition<matrix_exp_type>::type lu_decomposition<matrix_exp_type>::
det()303     det (
304     ) const
305     {
306         // make sure requires clause is not broken
307         DLIB_ASSERT(is_square() == true,
308             "\ttype lu_decomposition::det()"
309             << "\n\tYou can only use this on square matrices"
310             << "\n\tthis: " << this
311             );
312 
313         // Check if it is singular and if it is just return 0.
314         // We want to do this because a prod() operation can easily
315         // overcome a single diagonal element that is effectively 0 when
316         // LU is a big enough matrix.
317         if (is_singular())
318             return 0;
319 
320         return prod(diag(LU))*static_cast<type>(pivsign);
321     }
322 
323 // ----------------------------------------------------------------------------------------
324 
325     template <typename matrix_exp_type>
326     template <typename EXP>
327     const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
solve(const matrix_exp<EXP> & B)328     solve (
329         const matrix_exp<EXP> &B
330     ) const
331     {
332         COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
333 
334         // make sure requires clause is not broken
335         DLIB_ASSERT(is_square() == true && B.nr() == nr(),
336             "\ttype lu_decomposition::solve()"
337             << "\n\tInvalid arguments to this function"
338             << "\n\tis_square():   " << (is_square()? "true":"false" )
339             << "\n\tB.nr():        " << B.nr()
340             << "\n\tnr():          " << nr()
341             << "\n\tthis:          " << this
342             );
343 
344         // Copy right hand side with pivoting
345         matrix<type,0,0,mem_manager_type,column_major_layout> X(rowm(B, piv));
346 
347         using namespace blas_bindings;
348         // Solve L*Y = B(piv,:)
349         triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasUnit, LU, X);
350         // Solve U*X = Y;
351         triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, LU, X);
352         return X;
353     }
354 
355 // ----------------------------------------------------------------------------------------
356 
357 }
358 
359 #endif // DLIB_MATRIX_LU_DECOMPOSITION_H
360 
361 
362