1 #ifndef __VMML__VMMLIB_LAPACK_LINEAR_LEAST_SQUARES__HPP__
2 #define __VMML__VMMLIB_LAPACK_LINEAR_LEAST_SQUARES__HPP__
3
4 #include <vmmlib/matrix.hpp>
5 #include <vmmlib/vector.hpp>
6 #include <vmmlib/exception.hpp>
7
8 #include <vmmlib/lapack_types.hpp>
9 #include <vmmlib/lapack_includes.hpp>
10
11 #include <string>
12
/**
 *
 * Thin C++ wrappers for the following LAPACK drivers:
 *
 * xGELS - linear least squares via QR/LQ factorization
 * xGESV - solution of a square linear system A X = B via LU factorization
 *
 */
21
22
23 namespace vmml
24 {
25
26 // XYYZZZ
27 // X = data type: S - float, D - double
28 // YY = matrix type, GE - general, TR - triangular
29 // ZZZ = function name
30
31 namespace lapack
32 {
33
34 //
35 //
36 // SGELS/DGELS
37 //
38 //
39
40
41 // parameter struct
42 template< typename float_t >
43 struct llsq_params_xgels
44 {
45 char trans; // 'N'->A, 'T'->Atransposed
46 lapack_int m; // number of rows, M >= 0
47 lapack_int n; // number of columns, N >= 0
48 lapack_int nrhs; // number of columns of B/X
49 float_t* a; // input A
50 lapack_int lda; // leading dimension of A (number of rows)
51 float_t* b; // input B, output X
52 lapack_int ldb; // leading dimension of b
53 float_t* work; // workspace
54 lapack_int lwork; // workspace size
55 lapack_int info; // 'return' value
56
operator <<(std::ostream & os,const llsq_params_xgels<float_t> & p)57 friend std::ostream& operator << ( std::ostream& os,
58 const llsq_params_xgels< float_t >& p )
59 {
60 os
61 << " m " << p.m
62 << " n " << p.n
63 << " nrhs " << p.nrhs
64 << " lda " << p.lda
65 << " ldb " << p.ldb
66 << " lwork " << p.lwork
67 << " info " << p.info
68 << std::endl;
69 return os;
70 }
71
72 };
73
74 // call wrappers
75
76 #if 0
77 void dgels_(const char *trans, const int *M, const int *N, const int *nrhs,
78 double *A, const int *lda, double *b, const int *ldb, double *work,
79 const int * lwork, int *info);
80 #endif
81
82 template< typename float_t >
83 inline void
llsq_call_xgels(llsq_params_xgels<float_t> & p)84 llsq_call_xgels( llsq_params_xgels< float_t >& p )
85 {
86 VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
87 }
88
89 template<>
90 inline void
llsq_call_xgels(llsq_params_xgels<float> & p)91 llsq_call_xgels( llsq_params_xgels< float >& p )
92 {
93 sgels_(
94 &p.trans,
95 &p.m,
96 &p.n,
97 &p.nrhs,
98 p.a,
99 &p.lda,
100 p.b,
101 &p.ldb,
102 p.work,
103 &p.lwork,
104 &p.info
105 );
106 }
107
108 template<>
109 inline void
llsq_call_xgels(llsq_params_xgels<double> & p)110 llsq_call_xgels( llsq_params_xgels< double >& p )
111 {
112 dgels_(
113 &p.trans,
114 &p.m,
115 &p.n,
116 &p.nrhs,
117 p.a,
118 &p.lda,
119 p.b,
120 &p.ldb,
121 p.work,
122 &p.lwork,
123 &p.info
124 );
125
126 }
127
128
129 template< size_t M, size_t N, typename float_t >
130 struct linear_least_squares_xgels
131 {
132 void compute(
133 const matrix< M, N, float_t >& A,
134 const vector< M, float_t >& B,
135 vector< N, float_t >& x );
136
137 linear_least_squares_xgels();
138 ~linear_least_squares_xgels();
139
get_paramsvmml::lapack::linear_least_squares_xgels140 const lapack::llsq_params_xgels< float_t >& get_params(){ return p; };
141
get_factorized_Avmml::lapack::linear_least_squares_xgels142 matrix< M, N, float_t >& get_factorized_A() { return _A; }
143
144 protected:
145 matrix< M, N, float_t > _A;
146 vector< M, float_t > _b;
147
148 llsq_params_xgels< float_t > p;
149
150 };
151
152
153
154 template< size_t M, size_t N, typename float_t >
155 void
compute(const matrix<M,N,float_t> & A,const vector<M,float_t> & B,vector<N,float_t> & x)156 linear_least_squares_xgels< M, N, float_t >::compute(
157 const matrix< M, N, float_t >& A,
158 const vector< M, float_t >& B,
159 vector< N, float_t >& x )
160 {
161 _A = A;
162 _b = B;
163
164 llsq_call_xgels( p );
165
166 // success
167 if ( p.info == 0 )
168 {
169 for( size_t index = 0; index < N; ++index )
170 {
171 x( index ) = _b( index );
172 }
173
174 return;
175 }
176 if ( p.info < 0 )
177 {
178 VMMLIB_ERROR( "xGELS - invalid argument.", VMMLIB_HERE );
179 }
180 else
181 {
182 VMMLIB_ERROR( "least squares solution could not be computed.",
183 VMMLIB_HERE );
184 }
185
186 }
187
188
189
190 template< size_t M, size_t N, typename float_t >
191 linear_least_squares_xgels< M, N, float_t >::
linear_least_squares_xgels()192 linear_least_squares_xgels()
193 {
194 p.trans = 'N';
195 p.m = M;
196 p.n = N;
197 p.nrhs = 1;
198 p.a = _A.array;
199 p.lda = M;
200 p.b = _b.array;
201 p.ldb = M;
202 p.work = new float_t();
203 p.lwork = -1;
204
205 // workspace query
206 llsq_call_xgels( p );
207
208 p.lwork = static_cast< lapack_int > ( p.work[0] );
209 delete p.work;
210
211 p.work = new float_t[ p.lwork ];
212 }
213
214
215
216 template< size_t M, size_t N, typename float_t >
217 linear_least_squares_xgels< M, N, float_t >::
~linear_least_squares_xgels()218 ~linear_least_squares_xgels()
219 {
220 delete[] p.work;
221 }
222
223
224
225 //
226 //
227 // SGESV/DGESV
228 //
229 //
230
231 template< typename float_t >
232 struct llsq_params_xgesv
233 {
234 lapack_int n; // order of matrix A = M * N
235 lapack_int nrhs; // number of columns of B
236 float_t* a; // input A, output P*L*U
237 lapack_int lda; // leading dimension of A (for us: number of rows)
238 lapack_int* ipiv; // pivot indices, integer array of size N
239 float_t* b; // input b, output X
240 lapack_int ldb; // leading dimension of b
241 lapack_int info;
242
operator <<(std::ostream & os,const llsq_params_xgesv<float_t> & p)243 friend std::ostream& operator << ( std::ostream& os,
244 const llsq_params_xgesv< float_t >& p )
245 {
246 os
247 << "n " << p.n
248 << " nrhs " << p.nrhs
249 << " lda " << p.lda
250 << " ldb " << p.ldvt
251 << " info " << p.info
252 << std::endl;
253 return os;
254 }
255
256 };
257
258
259 #if 0
260 /* Subroutine */ int dgesv_(integer *n, integer *nrhs, doublereal *a, integer
261 *lda, integer *ipiv, doublereal *b, integer *ldb, integer *info);
262 #endif
263
264
265 template< typename float_t >
266 inline void
llsq_call_xgesv(llsq_params_xgesv<float_t> & p)267 llsq_call_xgesv( llsq_params_xgesv< float_t >& p )
268 {
269 VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
270 }
271
272
273 template<>
274 inline void
llsq_call_xgesv(llsq_params_xgesv<float> & p)275 llsq_call_xgesv( llsq_params_xgesv< float >& p )
276 {
277 sgesv_(
278 &p.n,
279 &p.nrhs,
280 p.a,
281 &p.lda,
282 p.ipiv,
283 p.b,
284 &p.ldb,
285 &p.info
286 );
287
288 }
289
290
291 template<>
292 inline void
llsq_call_xgesv(llsq_params_xgesv<double> & p)293 llsq_call_xgesv( llsq_params_xgesv< double >& p )
294 {
295 dgesv_(
296 &p.n,
297 &p.nrhs,
298 p.a,
299 &p.lda,
300 p.ipiv,
301 p.b,
302 &p.ldb,
303 &p.info
304 );
305 }
306
307
308 template< size_t M, size_t N, typename float_t >
309 struct linear_least_squares_xgesv
310 {
311 // computes x ( Ax = b ). x replaces b on output.
312 void compute(
313 matrix< N, N, float_t >& A,
314 matrix< N, M, float_t >& b
315 );
316
317 linear_least_squares_xgesv();
318 ~linear_least_squares_xgesv();
319
get_paramsvmml::lapack::linear_least_squares_xgesv320 const lapack::llsq_params_xgesv< float_t >& get_params() { return p; }
321
322 lapack::llsq_params_xgesv< float_t > p;
323
324 }; // struct lapack_linear_least_squares
325
326
327 template< size_t M, size_t N, typename float_t >
328 void
329 linear_least_squares_xgesv< M, N, float_t >::
compute(matrix<N,N,float_t> & A,matrix<N,M,float_t> & b)330 compute(
331 matrix< N, N, float_t >& A,
332 matrix< N, M, float_t >& b
333 )
334 {
335 p.a = A.array;
336 p.b = b.array;
337
338 lapack::llsq_call_xgesv( p );
339
340 if ( p.info != 0 )
341 {
342 if ( p.info < 0 )
343 VMMLIB_ERROR( "invalid value in input matrix", VMMLIB_HERE );
344 else
345 VMMLIB_ERROR( "factor U is exactly singular, solution could not be computed.", VMMLIB_HERE );
346 }
347 }
348
349
350
351 template< size_t M, size_t N, typename float_t >
352 linear_least_squares_xgesv< M, N, float_t >::
linear_least_squares_xgesv()353 linear_least_squares_xgesv()
354 {
355 p.n = N;
356 p.nrhs = M;
357 p.lda = N;
358 p.ldb = N;
359 p.ipiv = new lapack_int[ N ];
360
361 }
362
363
364
365 template< size_t M, size_t N, typename float_t >
366 linear_least_squares_xgesv< M, N, float_t >::
~linear_least_squares_xgesv()367 ~linear_least_squares_xgesv()
368 {
369 delete[] p.ipiv;
370 }
371
372
373 } // namespace lapack
374
375 } // namespace vmml
376
377 #endif
378
379