1 // Copyright (C) 2009 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 // This code was adapted from code from the JAMA part of NIST's TNT library. 4 // See: http://math.nist.gov/tnt/ 5 #ifndef DLIB_MATRIX_LU_DECOMPOSITION_H 6 #define DLIB_MATRIX_LU_DECOMPOSITION_H 7 8 #include "matrix.h" 9 #include "matrix_utilities.h" 10 #include "matrix_subexp.h" 11 #include "matrix_trsm.h" 12 #include <algorithm> 13 14 #ifdef DLIB_USE_LAPACK 15 #include "lapack/getrf.h" 16 #endif 17 18 19 namespace dlib 20 { 21 22 template < 23 typename matrix_exp_type 24 > 25 class lu_decomposition 26 { 27 public: 28 29 const static long NR = matrix_exp_type::NR; 30 const static long NC = matrix_exp_type::NC; 31 typedef typename matrix_exp_type::type type; 32 typedef typename matrix_exp_type::mem_manager_type mem_manager_type; 33 typedef typename matrix_exp_type::layout_type layout_type; 34 35 typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type; 36 typedef matrix<type,NR,1,mem_manager_type,layout_type> column_vector_type; 37 typedef matrix<long,NR,1,mem_manager_type,layout_type> pivot_column_vector_type; 38 39 // You have supplied an invalid type of matrix_exp_type. You have 40 // to use this object with matrices that contain float or double type data. 41 COMPILE_TIME_ASSERT((is_same_type<float, type>::value || 42 is_same_type<double, type>::value )); 43 44 template <typename EXP> 45 lu_decomposition ( 46 const matrix_exp<EXP> &A 47 ); 48 49 bool is_square ( 50 ) const; 51 52 bool is_singular ( 53 ) const; 54 55 long nr( 56 ) const; 57 58 long nc( 59 ) const; 60 61 const matrix_type get_l ( 62 ) const; 63 64 const matrix_type get_u ( 65 ) const; 66 67 const pivot_column_vector_type& get_pivot ( 68 ) const; 69 70 type det ( 71 ) const; 72 73 template <typename EXP> 74 const matrix_type solve ( 75 const matrix_exp<EXP> &B 76 ) const; 77 78 private: 79 80 /* Array for internal storage of decomposition. */ 81 matrix<type,0,0,mem_manager_type,column_major_layout> LU; 82 long m, n, pivsign; 83 pivot_column_vector_type piv; 84 85 86 }; 87 88 // ---------------------------------------------------------------------------------------- 89 // ---------------------------------------------------------------------------------------- 90 // Public member functions 91 // ---------------------------------------------------------------------------------------- 92 // ---------------------------------------------------------------------------------------- 93 94 template <typename matrix_exp_type> 95 template <typename EXP> 96 lu_decomposition<matrix_exp_type>:: lu_decomposition(const matrix_exp<EXP> & A)97 lu_decomposition ( 98 const matrix_exp<EXP>& A 99 ) : 100 LU(A), 101 m(A.nr()), 102 n(A.nc()) 103 { 104 using namespace std; 105 using std::abs; 106 107 COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value)); 108 109 // make sure requires clause is not broken 110 DLIB_ASSERT(A.size() > 0, 111 "\tlu_decomposition::lu_decomposition(A)" 112 << "\n\tInvalid inputs were given to this function" 113 << "\n\tA.size(): " << A.size() 114 << "\n\tthis: " << this 115 ); 116 117 #ifdef DLIB_USE_LAPACK 118 matrix<lapack::integer,0,1,mem_manager_type,layout_type> piv_temp; 119 lapack::getrf(LU, piv_temp); 120 121 pivsign = 1; 122 123 // Turn the piv_temp vector into a more useful form. This way we will have the identity 124 // rowm(A,piv) == L*U. The permutation vector that comes out of LAPACK is somewhat 125 // different. 126 piv = trans(range(0,m-1)); 127 for (long i = 0; i < piv_temp.size(); ++i) 128 { 129 // -1 because FORTRAN is indexed starting with 1 instead of 0 130 if (piv(piv_temp(i)-1) != piv(i)) 131 { 132 std::swap(piv(i), piv(piv_temp(i)-1)); 133 pivsign = -pivsign; 134 } 135 } 136 137 #else 138 139 // Use a "left-looking", dot-product, Crout/Doolittle algorithm. 140 141 142 piv = trans(range(0,m-1)); 143 pivsign = 1; 144 145 column_vector_type LUcolj(m); 146 147 // Outer loop. 148 for (long j = 0; j < n; j++) 149 { 150 151 // Make a copy of the j-th column to localize references. 152 LUcolj = colm(LU,j); 153 154 // Apply previous transformations. 155 for (long i = 0; i < m; i++) 156 { 157 // Most of the time is spent in the following dot product. 158 const long kmax = std::min(i,j); 159 type s; 160 if (kmax > 0) 161 s = rowm(LU,i, kmax)*colm(LUcolj,0,kmax); 162 else 163 s = 0; 164 165 LU(i,j) = LUcolj(i) -= s; 166 } 167 168 // Find pivot and exchange if necessary. 169 long p = j; 170 for (long i = j+1; i < m; i++) 171 { 172 if (abs(LUcolj(i)) > abs(LUcolj(p))) 173 { 174 p = i; 175 } 176 } 177 if (p != j) 178 { 179 long k=0; 180 for (k = 0; k < n; k++) 181 { 182 type t = LU(p,k); 183 LU(p,k) = LU(j,k); 184 LU(j,k) = t; 185 } 186 k = piv(p); 187 piv(p) = piv(j); 188 piv(j) = k; 189 pivsign = -pivsign; 190 } 191 192 // Compute multipliers. 193 if ((j < m) && (LU(j,j) != 0.0)) 194 { 195 for (long i = j+1; i < m; i++) 196 { 197 LU(i,j) /= LU(j,j); 198 } 199 } 200 } 201 202 #endif 203 } 204 205 // ---------------------------------------------------------------------------------------- 206 207 template <typename matrix_exp_type> 208 bool lu_decomposition<matrix_exp_type>:: is_square()209 is_square ( 210 ) const 211 { 212 return m == n; 213 } 214 215 // ---------------------------------------------------------------------------------------- 216 217 template <typename matrix_exp_type> 218 long lu_decomposition<matrix_exp_type>:: nr()219 nr ( 220 ) const 221 { 222 return m; 223 } 224 225 // ---------------------------------------------------------------------------------------- 226 227 template <typename matrix_exp_type> 228 long lu_decomposition<matrix_exp_type>:: nc()229 nc ( 230 ) const 231 { 232 return n; 233 } 234 235 // ---------------------------------------------------------------------------------------- 236 237 template <typename matrix_exp_type> 238 bool lu_decomposition<matrix_exp_type>:: is_singular()239 is_singular ( 240 ) const 241 { 242 /* Is the matrix singular? 243 if upper triangular factor U (and hence A) is singular, false otherwise. 244 */ 245 // make sure requires clause is not broken 246 DLIB_ASSERT(is_square() == true, 247 "\tbool lu_decomposition::is_singular()" 248 << "\n\tYou can only use this on square matrices" 249 << "\n\tthis: " << this 250 ); 251 252 type max_val, min_val; 253 find_min_and_max (abs(diag(LU)), min_val, max_val); 254 type eps = max_val; 255 if (eps != 0) 256 eps *= std::sqrt(std::numeric_limits<type>::epsilon())/10; 257 else 258 eps = 1; // there is no max so just use 1 259 260 return min_val < eps; 261 } 262 263 // ---------------------------------------------------------------------------------------- 264 265 template <typename matrix_exp_type> 266 const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>:: get_l()267 get_l ( 268 ) const 269 { 270 if (LU.nr() >= LU.nc()) 271 return lowerm(LU,1.0); 272 else 273 return lowerm(subm(LU,0,0,m,m), 1.0); 274 } 275 276 // ---------------------------------------------------------------------------------------- 277 278 template <typename matrix_exp_type> 279 const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>:: get_u()280 get_u ( 281 ) const 282 { 283 if (LU.nr() >= LU.nc()) 284 return upperm(subm(LU,0,0,n,n)); 285 else 286 return upperm(LU); 287 } 288 289 // ---------------------------------------------------------------------------------------- 290 291 template <typename matrix_exp_type> 292 const typename lu_decomposition<matrix_exp_type>::pivot_column_vector_type& lu_decomposition<matrix_exp_type>:: get_pivot()293 get_pivot ( 294 ) const 295 { 296 return piv; 297 } 298 299 // ---------------------------------------------------------------------------------------- 300 301 template <typename matrix_exp_type> 302 typename lu_decomposition<matrix_exp_type>::type lu_decomposition<matrix_exp_type>:: det()303 det ( 304 ) const 305 { 306 // make sure requires clause is not broken 307 DLIB_ASSERT(is_square() == true, 308 "\ttype lu_decomposition::det()" 309 << "\n\tYou can only use this on square matrices" 310 << "\n\tthis: " << this 311 ); 312 313 // Check if it is singular and if it is just return 0. 314 // We want to do this because a prod() operation can easily 315 // overcome a single diagonal element that is effectively 0 when 316 // LU is a big enough matrix. 317 if (is_singular()) 318 return 0; 319 320 return prod(diag(LU))*static_cast<type>(pivsign); 321 } 322 323 // ---------------------------------------------------------------------------------------- 324 325 template <typename matrix_exp_type> 326 template <typename EXP> 327 const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>:: solve(const matrix_exp<EXP> & B)328 solve ( 329 const matrix_exp<EXP> &B 330 ) const 331 { 332 COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value)); 333 334 // make sure requires clause is not broken 335 DLIB_ASSERT(is_square() == true && B.nr() == nr(), 336 "\ttype lu_decomposition::solve()" 337 << "\n\tInvalid arguments to this function" 338 << "\n\tis_square(): " << (is_square()? "true":"false" ) 339 << "\n\tB.nr(): " << B.nr() 340 << "\n\tnr(): " << nr() 341 << "\n\tthis: " << this 342 ); 343 344 // Copy right hand side with pivoting 345 matrix<type,0,0,mem_manager_type,column_major_layout> X(rowm(B, piv)); 346 347 using namespace blas_bindings; 348 // Solve L*Y = B(piv,:) 349 triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasUnit, LU, X); 350 // Solve U*X = Y; 351 triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, LU, X); 352 return X; 353 } 354 355 // ---------------------------------------------------------------------------------------- 356 357 } 358 359 #endif // DLIB_MATRIX_LU_DECOMPOSITION_H 360 361 362