1 #ifndef __VMML__VMMLIB_LAPACK_LINEAR_LEAST_SQUARES__HPP__
2 #define __VMML__VMMLIB_LAPACK_LINEAR_LEAST_SQUARES__HPP__
3 
4 #include <vmmlib/matrix.hpp>
5 #include <vmmlib/vector.hpp>
6 #include <vmmlib/exception.hpp>
7 
8 #include <vmmlib/lapack_types.hpp>
9 #include <vmmlib/lapack_includes.hpp>
10 
11 #include <string>
12 
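/**
*
* This is a wrapper for the following LAPACK routines:
*
* xGELS - linear least squares (overdetermined systems) via QR factorization
* xGESV - solution of a square linear system A * X = B via LU factorization
*
*/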
21 
22 
23 namespace vmml
24 {
25 
26 // XYYZZZ
27 // X    = data type: S - float, D - double
28 // YY   = matrix type, GE - general, TR - triangular
29 // ZZZ  = function name
30 
31 namespace lapack
32 {
33 
34 //
35 //
36 // SGELS/DGELS
37 //
38 //
39 
40 
41 // parameter struct
42 template< typename float_t >
43 struct llsq_params_xgels
44 {
45     char            trans;  // 'N'->A, 'T'->Atransposed
46     lapack_int      m;      // number of rows,      M >= 0
47     lapack_int      n;      // number of columns,   N >= 0
48     lapack_int      nrhs;   // number of columns of B/X
49     float_t*        a;      // input A
50     lapack_int      lda;    // leading dimension of A (number of rows)
51     float_t*        b;      // input B, output X
52     lapack_int      ldb;    // leading dimension of b
53     float_t*        work;   // workspace
54     lapack_int      lwork;  // workspace size
55     lapack_int      info;   // 'return' value
56 
operator <<(std::ostream & os,const llsq_params_xgels<float_t> & p)57     friend std::ostream& operator << ( std::ostream& os,
58         const llsq_params_xgels< float_t >& p )
59     {
60         os
61             << " m "        << p.m
62             << " n "        << p.n
63             << " nrhs "     << p.nrhs
64             << " lda "      << p.lda
65             << " ldb "      << p.ldb
66             << " lwork "    << p.lwork
67             << " info "     << p.info
68             << std::endl;
69         return os;
70     }
71 
72 };
73 
74 // call wrappers
75 
76 #if 0
77 void dgels_(const char *trans, const int *M, const int *N, const int *nrhs,
78     double *A, const int *lda, double *b, const int *ldb, double *work,
79     const int * lwork, int *info);
80 #endif
81 
82 template< typename float_t >
83 inline void
llsq_call_xgels(llsq_params_xgels<float_t> & p)84 llsq_call_xgels( llsq_params_xgels< float_t >& p )
85 {
86     VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
87 }
88 
89 template<>
90 inline void
llsq_call_xgels(llsq_params_xgels<float> & p)91 llsq_call_xgels( llsq_params_xgels< float >& p )
92 {
93     sgels_(
94         &p.trans,
95         &p.m,
96         &p.n,
97         &p.nrhs,
98         p.a,
99         &p.lda,
100         p.b,
101         &p.ldb,
102         p.work,
103         &p.lwork,
104         &p.info
105     );
106 }
107 
108 template<>
109 inline void
llsq_call_xgels(llsq_params_xgels<double> & p)110 llsq_call_xgels( llsq_params_xgels< double >& p )
111 {
112     dgels_(
113         &p.trans,
114         &p.m,
115         &p.n,
116         &p.nrhs,
117         p.a,
118         &p.lda,
119         p.b,
120         &p.ldb,
121         p.work,
122         &p.lwork,
123         &p.info
124     );
125 
126 }
127 
128 
129 template< size_t M, size_t N, typename float_t >
130 struct linear_least_squares_xgels
131 {
132     bool compute(
133         const matrix< M, N, float_t >& A,
134         const vector< M, float_t >& B,
135         vector< N, float_t >& x );
136 
137     linear_least_squares_xgels();
138     ~linear_least_squares_xgels();
139 
get_paramsvmml::lapack::linear_least_squares_xgels140     const lapack::llsq_params_xgels< float_t >& get_params(){ return p; };
141 
get_factorized_Avmml::lapack::linear_least_squares_xgels142     matrix< M, N, float_t >& get_factorized_A() { return _A; }
143 
144 protected:
145     matrix< M, N, float_t > _A;
146     vector< M, float_t >    _b;
147 
148     llsq_params_xgels< float_t > p;
149 
150 };
151 
152 
153 
154 template< size_t M, size_t N, typename float_t >
155 bool
compute(const matrix<M,N,float_t> & A,const vector<M,float_t> & B,vector<N,float_t> & x)156 linear_least_squares_xgels< M, N, float_t >::compute(
157     const matrix< M, N, float_t >& A,
158     const vector< M, float_t >& B,
159     vector< N, float_t >& x )
160 {
161     _A = A;
162     _b = B;
163 
164     llsq_call_xgels( p );
165 
166     // success
167     if ( p.info == 0 )
168     {
169         for( size_t index = 0; index < N; ++index )
170         {
171             x( index ) = _b( index );
172         }
173 
174         return true;
175     }
176     if ( p.info < 0 )
177     {
178         VMMLIB_ERROR( "xGELS - invalid argument.", VMMLIB_HERE );
179     }
180     else
181     {
182         std::cout << "A\n" << A << std::endl;
183         std::cout << "B\n" << B << std::endl;
184 
185         VMMLIB_ERROR( "least squares solution could not be computed.",
186             VMMLIB_HERE );
187     }
188     return false;
189 }
190 
191 
192 
193 template< size_t M, size_t N, typename float_t >
194 linear_least_squares_xgels< M, N, float_t >::
linear_least_squares_xgels()195 linear_least_squares_xgels()
196 {
197     p.trans = 'N';
198     p.m     = M;
199     p.n     = N;
200     p.nrhs  = 1;
201     p.a     = _A.array;
202     p.lda   = M;
203     p.b     = _b.array;
204     p.ldb   = M;
205     p.work  = new float_t();
206     p.lwork = -1;
207 
208     // workspace query
209     llsq_call_xgels( p );
210 
211     p.lwork = static_cast< lapack_int > ( p.work[0] );
212     delete p.work;
213 
214     p.work = new float_t[ p.lwork ];
215 }
216 
217 
218 
219 template< size_t M, size_t N, typename float_t >
220 linear_least_squares_xgels< M, N, float_t >::
~linear_least_squares_xgels()221 ~linear_least_squares_xgels()
222 {
223     delete[] p.work;
224 }
225 
226 
227 
228 //
229 //
230 // SGESV/DGESV
231 //
232 //
233 
234 template< typename float_t >
235 struct llsq_params_xgesv
236 {
237     lapack_int      n; // order of matrix A = M * N
238     lapack_int      nrhs; // number of columns of B
239     float_t*        a;   // input A, output P*L*U
240     lapack_int      lda; // leading dimension of A (for us: number of rows)
241     lapack_int*     ipiv; // pivot indices, integer array of size N
242     float_t*        b;  // input b, output X
243     lapack_int      ldb; // leading dimension of b
244     lapack_int      info;
245 
operator <<(std::ostream & os,const llsq_params_xgesv<float_t> & p)246     friend std::ostream& operator << ( std::ostream& os,
247         const llsq_params_xgesv< float_t >& p )
248     {
249         os
250             << "n "         << p.n
251             << " nrhs "     << p.nrhs
252             << " lda "      << p.lda
253             << " ldb "      << p.ldvt
254             << " info "     << p.info
255             << std::endl;
256         return os;
257     }
258 
259 };
260 
261 
262 #if 0
263 /* Subroutine */ int dgesv_(integer *n, integer *nrhs, doublereal *a, integer
264 	*lda, integer *ipiv, doublereal *b, integer *ldb, integer *info);
265 #endif
266 
267 
268 template< typename float_t >
269 inline void
llsq_call_xgesv(llsq_params_xgesv<float_t> & p)270 llsq_call_xgesv( llsq_params_xgesv< float_t >& p )
271 {
272     VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
273 }
274 
275 
276 template<>
277 inline void
llsq_call_xgesv(llsq_params_xgesv<float> & p)278 llsq_call_xgesv( llsq_params_xgesv< float >& p )
279 {
280     sgesv_(
281         &p.n,
282         &p.nrhs,
283         p.a,
284         &p.lda,
285         p.ipiv,
286         p.b,
287         &p.ldb,
288         &p.info
289     );
290 
291 }
292 
293 
294 template<>
295 inline void
llsq_call_xgesv(llsq_params_xgesv<double> & p)296 llsq_call_xgesv( llsq_params_xgesv< double >& p )
297 {
298     dgesv_(
299         &p.n,
300         &p.nrhs,
301         p.a,
302         &p.lda,
303         p.ipiv,
304         p.b,
305         &p.ldb,
306         &p.info
307     );
308 }
309 
310 
311 template< size_t M, size_t N, typename float_t >
312 struct linear_least_squares_xgesv
313 {
314     // computes x ( Ax = b ). x replaces b on output.
315     void compute(
316         matrix< N, N, float_t >& A,
317         matrix< N, M, float_t >& b
318         );
319 
320     linear_least_squares_xgesv();
321     ~linear_least_squares_xgesv();
322 
get_paramsvmml::lapack::linear_least_squares_xgesv323     const lapack::llsq_params_xgesv< float_t >& get_params() { return p; }
324 
325     lapack::llsq_params_xgesv< float_t > p;
326 
327 }; // struct lapack_linear_least_squares
328 
329 
330 template< size_t M, size_t N, typename float_t >
331 void
332 linear_least_squares_xgesv< M, N, float_t >::
compute(matrix<N,N,float_t> & A,matrix<N,M,float_t> & b)333 compute(
334         matrix< N, N, float_t >& A,
335         matrix< N, M, float_t >& b
336         )
337 {
338     p.a = A.array;
339     p.b = b.array;
340 
341     lapack::llsq_call_xgesv( p );
342 
343     if ( p.info != 0 )
344     {
345         if ( p.info < 0 )
346             VMMLIB_ERROR( "invalid value in input matrix", VMMLIB_HERE );
347         else
348             VMMLIB_ERROR( "factor U is exactly singular, solution could not be computed.", VMMLIB_HERE );
349     }
350 }
351 
352 
353 
354 template< size_t M, size_t N, typename float_t >
355 linear_least_squares_xgesv< M, N, float_t >::
linear_least_squares_xgesv()356 linear_least_squares_xgesv()
357 {
358     p.n     = N;
359     p.nrhs  = M;
360     p.lda   = N;
361     p.ldb   = N;
362     p.ipiv = new lapack_int[ N ];
363 
364 }
365 
366 
367 
368 template< size_t M, size_t N, typename float_t >
369 linear_least_squares_xgesv< M, N, float_t >::
~linear_least_squares_xgesv()370 ~linear_least_squares_xgesv()
371 {
372     delete[] p.ipiv;
373 }
374 
375 
376 } // namespace lapack
377 
378 } // namespace vmml
379 
380 #endif
381 
382