1 #ifndef __VMML__VMMLIB_LAPACK_LINEAR_LEAST_SQUARES__HPP__
2 #define __VMML__VMMLIB_LAPACK_LINEAR_LEAST_SQUARES__HPP__
3
4 #include <vmmlib/matrix.hpp>
5 #include <vmmlib/vector.hpp>
6 #include <vmmlib/exception.hpp>
7
8 #include <vmmlib/lapack_types.hpp>
9 #include <vmmlib/lapack_includes.hpp>
10
11 #include <string>
12
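/**
 *
 * This is a wrapper for the following LAPACK routines:
 *
 * xGELS - least squares solution of A * X = B (QR or LQ factorization)
 * xGESV - solution of a square system A * X = B (LU factorization)
 *
 */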
21
22
23 namespace vmml
24 {
25
26 // XYYZZZ
27 // X = data type: S - float, D - double
28 // YY = matrix type, GE - general, TR - triangular
29 // ZZZ = function name
30
31 namespace lapack
32 {
33
34 //
35 //
36 // SGELS/DGELS
37 //
38 //
39
40
41 // parameter struct
42 template< typename float_t >
43 struct llsq_params_xgels
44 {
45 char trans; // 'N'->A, 'T'->Atransposed
46 lapack_int m; // number of rows, M >= 0
47 lapack_int n; // number of columns, N >= 0
48 lapack_int nrhs; // number of columns of B/X
49 float_t* a; // input A
50 lapack_int lda; // leading dimension of A (number of rows)
51 float_t* b; // input B, output X
52 lapack_int ldb; // leading dimension of b
53 float_t* work; // workspace
54 lapack_int lwork; // workspace size
55 lapack_int info; // 'return' value
56
operator <<(std::ostream & os,const llsq_params_xgels<float_t> & p)57 friend std::ostream& operator << ( std::ostream& os,
58 const llsq_params_xgels< float_t >& p )
59 {
60 os
61 << " m " << p.m
62 << " n " << p.n
63 << " nrhs " << p.nrhs
64 << " lda " << p.lda
65 << " ldb " << p.ldb
66 << " lwork " << p.lwork
67 << " info " << p.info
68 << std::endl;
69 return os;
70 }
71
72 };
73
74 // call wrappers
75
76 #if 0
77 void dgels_(const char *trans, const int *M, const int *N, const int *nrhs,
78 double *A, const int *lda, double *b, const int *ldb, double *work,
79 const int * lwork, int *info);
80 #endif
81
82 template< typename float_t >
83 inline void
llsq_call_xgels(llsq_params_xgels<float_t> & p)84 llsq_call_xgels( llsq_params_xgels< float_t >& p )
85 {
86 VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
87 }
88
89 template<>
90 inline void
llsq_call_xgels(llsq_params_xgels<float> & p)91 llsq_call_xgels( llsq_params_xgels< float >& p )
92 {
93 sgels_(
94 &p.trans,
95 &p.m,
96 &p.n,
97 &p.nrhs,
98 p.a,
99 &p.lda,
100 p.b,
101 &p.ldb,
102 p.work,
103 &p.lwork,
104 &p.info
105 );
106 }
107
108 template<>
109 inline void
llsq_call_xgels(llsq_params_xgels<double> & p)110 llsq_call_xgels( llsq_params_xgels< double >& p )
111 {
112 dgels_(
113 &p.trans,
114 &p.m,
115 &p.n,
116 &p.nrhs,
117 p.a,
118 &p.lda,
119 p.b,
120 &p.ldb,
121 p.work,
122 &p.lwork,
123 &p.info
124 );
125
126 }
127
128
129 template< size_t M, size_t N, typename float_t >
130 struct linear_least_squares_xgels
131 {
132 bool compute(
133 const matrix< M, N, float_t >& A,
134 const vector< M, float_t >& B,
135 vector< N, float_t >& x );
136
137 linear_least_squares_xgels();
138 ~linear_least_squares_xgels();
139
get_paramsvmml::lapack::linear_least_squares_xgels140 const lapack::llsq_params_xgels< float_t >& get_params(){ return p; };
141
get_factorized_Avmml::lapack::linear_least_squares_xgels142 matrix< M, N, float_t >& get_factorized_A() { return _A; }
143
144 protected:
145 matrix< M, N, float_t > _A;
146 vector< M, float_t > _b;
147
148 llsq_params_xgels< float_t > p;
149
150 };
151
152
153
154 template< size_t M, size_t N, typename float_t >
155 bool
compute(const matrix<M,N,float_t> & A,const vector<M,float_t> & B,vector<N,float_t> & x)156 linear_least_squares_xgels< M, N, float_t >::compute(
157 const matrix< M, N, float_t >& A,
158 const vector< M, float_t >& B,
159 vector< N, float_t >& x )
160 {
161 _A = A;
162 _b = B;
163
164 llsq_call_xgels( p );
165
166 // success
167 if ( p.info == 0 )
168 {
169 for( size_t index = 0; index < N; ++index )
170 {
171 x( index ) = _b( index );
172 }
173
174 return true;
175 }
176 if ( p.info < 0 )
177 {
178 VMMLIB_ERROR( "xGELS - invalid argument.", VMMLIB_HERE );
179 }
180 else
181 {
182 std::cout << "A\n" << A << std::endl;
183 std::cout << "B\n" << B << std::endl;
184
185 VMMLIB_ERROR( "least squares solution could not be computed.",
186 VMMLIB_HERE );
187 }
188 return false;
189 }
190
191
192
193 template< size_t M, size_t N, typename float_t >
194 linear_least_squares_xgels< M, N, float_t >::
linear_least_squares_xgels()195 linear_least_squares_xgels()
196 {
197 p.trans = 'N';
198 p.m = M;
199 p.n = N;
200 p.nrhs = 1;
201 p.a = _A.array;
202 p.lda = M;
203 p.b = _b.array;
204 p.ldb = M;
205 p.work = new float_t();
206 p.lwork = -1;
207
208 // workspace query
209 llsq_call_xgels( p );
210
211 p.lwork = static_cast< lapack_int > ( p.work[0] );
212 delete p.work;
213
214 p.work = new float_t[ p.lwork ];
215 }
216
217
218
219 template< size_t M, size_t N, typename float_t >
220 linear_least_squares_xgels< M, N, float_t >::
~linear_least_squares_xgels()221 ~linear_least_squares_xgels()
222 {
223 delete[] p.work;
224 }
225
226
227
228 //
229 //
230 // SGESV/DGESV
231 //
232 //
233
234 template< typename float_t >
235 struct llsq_params_xgesv
236 {
237 lapack_int n; // order of matrix A = M * N
238 lapack_int nrhs; // number of columns of B
239 float_t* a; // input A, output P*L*U
240 lapack_int lda; // leading dimension of A (for us: number of rows)
241 lapack_int* ipiv; // pivot indices, integer array of size N
242 float_t* b; // input b, output X
243 lapack_int ldb; // leading dimension of b
244 lapack_int info;
245
operator <<(std::ostream & os,const llsq_params_xgesv<float_t> & p)246 friend std::ostream& operator << ( std::ostream& os,
247 const llsq_params_xgesv< float_t >& p )
248 {
249 os
250 << "n " << p.n
251 << " nrhs " << p.nrhs
252 << " lda " << p.lda
253 << " ldb " << p.ldvt
254 << " info " << p.info
255 << std::endl;
256 return os;
257 }
258
259 };
260
261
262 #if 0
263 /* Subroutine */ int dgesv_(integer *n, integer *nrhs, doublereal *a, integer
264 *lda, integer *ipiv, doublereal *b, integer *ldb, integer *info);
265 #endif
266
267
268 template< typename float_t >
269 inline void
llsq_call_xgesv(llsq_params_xgesv<float_t> & p)270 llsq_call_xgesv( llsq_params_xgesv< float_t >& p )
271 {
272 VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
273 }
274
275
276 template<>
277 inline void
llsq_call_xgesv(llsq_params_xgesv<float> & p)278 llsq_call_xgesv( llsq_params_xgesv< float >& p )
279 {
280 sgesv_(
281 &p.n,
282 &p.nrhs,
283 p.a,
284 &p.lda,
285 p.ipiv,
286 p.b,
287 &p.ldb,
288 &p.info
289 );
290
291 }
292
293
294 template<>
295 inline void
llsq_call_xgesv(llsq_params_xgesv<double> & p)296 llsq_call_xgesv( llsq_params_xgesv< double >& p )
297 {
298 dgesv_(
299 &p.n,
300 &p.nrhs,
301 p.a,
302 &p.lda,
303 p.ipiv,
304 p.b,
305 &p.ldb,
306 &p.info
307 );
308 }
309
310
311 template< size_t M, size_t N, typename float_t >
312 struct linear_least_squares_xgesv
313 {
314 // computes x ( Ax = b ). x replaces b on output.
315 void compute(
316 matrix< N, N, float_t >& A,
317 matrix< N, M, float_t >& b
318 );
319
320 linear_least_squares_xgesv();
321 ~linear_least_squares_xgesv();
322
get_paramsvmml::lapack::linear_least_squares_xgesv323 const lapack::llsq_params_xgesv< float_t >& get_params() { return p; }
324
325 lapack::llsq_params_xgesv< float_t > p;
326
327 }; // struct lapack_linear_least_squares
328
329
330 template< size_t M, size_t N, typename float_t >
331 void
332 linear_least_squares_xgesv< M, N, float_t >::
compute(matrix<N,N,float_t> & A,matrix<N,M,float_t> & b)333 compute(
334 matrix< N, N, float_t >& A,
335 matrix< N, M, float_t >& b
336 )
337 {
338 p.a = A.array;
339 p.b = b.array;
340
341 lapack::llsq_call_xgesv( p );
342
343 if ( p.info != 0 )
344 {
345 if ( p.info < 0 )
346 VMMLIB_ERROR( "invalid value in input matrix", VMMLIB_HERE );
347 else
348 VMMLIB_ERROR( "factor U is exactly singular, solution could not be computed.", VMMLIB_HERE );
349 }
350 }
351
352
353
354 template< size_t M, size_t N, typename float_t >
355 linear_least_squares_xgesv< M, N, float_t >::
linear_least_squares_xgesv()356 linear_least_squares_xgesv()
357 {
358 p.n = N;
359 p.nrhs = M;
360 p.lda = N;
361 p.ldb = N;
362 p.ipiv = new lapack_int[ N ];
363
364 }
365
366
367
368 template< size_t M, size_t N, typename float_t >
369 linear_least_squares_xgesv< M, N, float_t >::
~linear_least_squares_xgesv()370 ~linear_least_squares_xgesv()
371 {
372 delete[] p.ipiv;
373 }
374
375
376 } // namespace lapack
377
378 } // namespace vmml
379
380 #endif
381
382