1 #ifndef __VMML__VMMLIB_LAPACK_LINEAR_LEAST_SQUARES__HPP__
2 #define __VMML__VMMLIB_LAPACK_LINEAR_LEAST_SQUARES__HPP__
3 
4 #include <vmmlib/matrix.hpp>
5 #include <vmmlib/vector.hpp>
6 #include <vmmlib/exception.hpp>
7 
8 #include <vmmlib/lapack_types.hpp>
9 #include <vmmlib/lapack_includes.hpp>
10 
11 #include <string>
12 
/**
*
* this is a wrapper for the following lapack routines:
*
* xGELS (linear least squares, QR/LQ based)
* xGESV (square linear systems, LU based)
*
*/
21 
22 
23 namespace vmml
24 {
25 
26 // XYYZZZ
27 // X    = data type: S - float, D - double
28 // YY   = matrix type, GE - general, TR - triangular
29 // ZZZ  = function name
30 
31 namespace lapack
32 {
33 
34 //
35 //
36 // SGELS/DGELS
37 //
38 //
39 
40 
41 // parameter struct
42 template< typename float_t >
43 struct llsq_params_xgels
44 {
45     char            trans;  // 'N'->A, 'T'->Atransposed
46     lapack_int      m;      // number of rows,      M >= 0
47     lapack_int      n;      // number of columns,   N >= 0
48     lapack_int      nrhs;   // number of columns of B/X
49     float_t*        a;      // input A
50     lapack_int      lda;    // leading dimension of A (number of rows)
51     float_t*        b;      // input B, output X
52     lapack_int      ldb;    // leading dimension of b
53     float_t*        work;   // workspace
54     lapack_int      lwork;  // workspace size
55     lapack_int      info;   // 'return' value
56 
operator <<(std::ostream & os,const llsq_params_xgels<float_t> & p)57     friend std::ostream& operator << ( std::ostream& os,
58         const llsq_params_xgels< float_t >& p )
59     {
60         os
61             << " m "        << p.m
62             << " n "        << p.n
63             << " nrhs "     << p.nrhs
64             << " lda "      << p.lda
65             << " ldb "      << p.ldb
66             << " lwork "    << p.lwork
67             << " info "     << p.info
68             << std::endl;
69         return os;
70     }
71 
72 };
73 
74 // call wrappers
75 
76 #if 0
77 void dgels_(const char *trans, const int *M, const int *N, const int *nrhs,
78     double *A, const int *lda, double *b, const int *ldb, double *work,
79     const int * lwork, int *info);
80 #endif
81 
82 template< typename float_t >
83 inline void
llsq_call_xgels(llsq_params_xgels<float_t> & p)84 llsq_call_xgels( llsq_params_xgels< float_t >& p )
85 {
86     VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
87 }
88 
89 template<>
90 inline void
llsq_call_xgels(llsq_params_xgels<float> & p)91 llsq_call_xgels( llsq_params_xgels< float >& p )
92 {
93     sgels_(
94         &p.trans,
95         &p.m,
96         &p.n,
97         &p.nrhs,
98         p.a,
99         &p.lda,
100         p.b,
101         &p.ldb,
102         p.work,
103         &p.lwork,
104         &p.info
105     );
106 }
107 
108 template<>
109 inline void
llsq_call_xgels(llsq_params_xgels<double> & p)110 llsq_call_xgels( llsq_params_xgels< double >& p )
111 {
112     dgels_(
113         &p.trans,
114         &p.m,
115         &p.n,
116         &p.nrhs,
117         p.a,
118         &p.lda,
119         p.b,
120         &p.ldb,
121         p.work,
122         &p.lwork,
123         &p.info
124     );
125 
126 }
127 
128 
129 template< size_t M, size_t N, typename float_t >
130 struct linear_least_squares_xgels
131 {
132     void compute(
133         const matrix< M, N, float_t >& A,
134         const vector< M, float_t >& B,
135         vector< N, float_t >& x );
136 
137     linear_least_squares_xgels();
138     ~linear_least_squares_xgels();
139 
get_paramsvmml::lapack::linear_least_squares_xgels140     const lapack::llsq_params_xgels< float_t >& get_params(){ return p; };
141 
get_factorized_Avmml::lapack::linear_least_squares_xgels142     matrix< M, N, float_t >& get_factorized_A() { return _A; }
143 
144 protected:
145     matrix< M, N, float_t > _A;
146     vector< M, float_t >    _b;
147 
148     llsq_params_xgels< float_t > p;
149 
150 };
151 
152 
153 
154 template< size_t M, size_t N, typename float_t >
155 void
compute(const matrix<M,N,float_t> & A,const vector<M,float_t> & B,vector<N,float_t> & x)156 linear_least_squares_xgels< M, N, float_t >::compute(
157     const matrix< M, N, float_t >& A,
158     const vector< M, float_t >& B,
159     vector< N, float_t >& x )
160 {
161     _A = A;
162     _b = B;
163 
164     llsq_call_xgels( p );
165 
166     // success
167     if ( p.info == 0 )
168     {
169         for( size_t index = 0; index < N; ++index )
170         {
171             x( index ) = _b( index );
172         }
173 
174         return;
175     }
176     if ( p.info < 0 )
177     {
178         VMMLIB_ERROR( "xGELS - invalid argument.", VMMLIB_HERE );
179     }
180     else
181     {
182         VMMLIB_ERROR( "least squares solution could not be computed.",
183             VMMLIB_HERE );
184     }
185 
186 }
187 
188 
189 
190 template< size_t M, size_t N, typename float_t >
191 linear_least_squares_xgels< M, N, float_t >::
linear_least_squares_xgels()192 linear_least_squares_xgels()
193 {
194     p.trans = 'N';
195     p.m     = M;
196     p.n     = N;
197     p.nrhs  = 1;
198     p.a     = _A.array;
199     p.lda   = M;
200     p.b     = _b.array;
201     p.ldb   = M;
202     p.work  = new float_t();
203     p.lwork = -1;
204 
205     // workspace query
206     llsq_call_xgels( p );
207 
208     p.lwork = static_cast< lapack_int > ( p.work[0] );
209     delete p.work;
210 
211     p.work = new float_t[ p.lwork ];
212 }
213 
214 
215 
216 template< size_t M, size_t N, typename float_t >
217 linear_least_squares_xgels< M, N, float_t >::
~linear_least_squares_xgels()218 ~linear_least_squares_xgels()
219 {
220     delete[] p.work;
221 }
222 
223 
224 
225 //
226 //
227 // SGESV/DGESV
228 //
229 //
230 
231 template< typename float_t >
232 struct llsq_params_xgesv
233 {
234     lapack_int      n; // order of matrix A = M * N
235     lapack_int      nrhs; // number of columns of B
236     float_t*        a;   // input A, output P*L*U
237     lapack_int      lda; // leading dimension of A (for us: number of rows)
238     lapack_int*     ipiv; // pivot indices, integer array of size N
239     float_t*        b;  // input b, output X
240     lapack_int      ldb; // leading dimension of b
241     lapack_int      info;
242 
operator <<(std::ostream & os,const llsq_params_xgesv<float_t> & p)243     friend std::ostream& operator << ( std::ostream& os,
244         const llsq_params_xgesv< float_t >& p )
245     {
246         os
247             << "n "         << p.n
248             << " nrhs "     << p.nrhs
249             << " lda "      << p.lda
250             << " ldb "      << p.ldvt
251             << " info "     << p.info
252             << std::endl;
253         return os;
254     }
255 
256 };
257 
258 
259 #if 0
260 /* Subroutine */ int dgesv_(integer *n, integer *nrhs, doublereal *a, integer
261 	*lda, integer *ipiv, doublereal *b, integer *ldb, integer *info);
262 #endif
263 
264 
265 template< typename float_t >
266 inline void
llsq_call_xgesv(llsq_params_xgesv<float_t> & p)267 llsq_call_xgesv( llsq_params_xgesv< float_t >& p )
268 {
269     VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
270 }
271 
272 
273 template<>
274 inline void
llsq_call_xgesv(llsq_params_xgesv<float> & p)275 llsq_call_xgesv( llsq_params_xgesv< float >& p )
276 {
277     sgesv_(
278         &p.n,
279         &p.nrhs,
280         p.a,
281         &p.lda,
282         p.ipiv,
283         p.b,
284         &p.ldb,
285         &p.info
286     );
287 
288 }
289 
290 
291 template<>
292 inline void
llsq_call_xgesv(llsq_params_xgesv<double> & p)293 llsq_call_xgesv( llsq_params_xgesv< double >& p )
294 {
295     dgesv_(
296         &p.n,
297         &p.nrhs,
298         p.a,
299         &p.lda,
300         p.ipiv,
301         p.b,
302         &p.ldb,
303         &p.info
304     );
305 }
306 
307 
308 template< size_t M, size_t N, typename float_t >
309 struct linear_least_squares_xgesv
310 {
311     // computes x ( Ax = b ). x replaces b on output.
312     void compute(
313         matrix< N, N, float_t >& A,
314         matrix< N, M, float_t >& b
315         );
316 
317     linear_least_squares_xgesv();
318     ~linear_least_squares_xgesv();
319 
get_paramsvmml::lapack::linear_least_squares_xgesv320     const lapack::llsq_params_xgesv< float_t >& get_params() { return p; }
321 
322     lapack::llsq_params_xgesv< float_t > p;
323 
324 }; // struct lapack_linear_least_squares
325 
326 
327 template< size_t M, size_t N, typename float_t >
328 void
329 linear_least_squares_xgesv< M, N, float_t >::
compute(matrix<N,N,float_t> & A,matrix<N,M,float_t> & b)330 compute(
331         matrix< N, N, float_t >& A,
332         matrix< N, M, float_t >& b
333         )
334 {
335     p.a = A.array;
336     p.b = b.array;
337 
338     lapack::llsq_call_xgesv( p );
339 
340     if ( p.info != 0 )
341     {
342         if ( p.info < 0 )
343             VMMLIB_ERROR( "invalid value in input matrix", VMMLIB_HERE );
344         else
345             VMMLIB_ERROR( "factor U is exactly singular, solution could not be computed.", VMMLIB_HERE );
346     }
347 }
348 
349 
350 
351 template< size_t M, size_t N, typename float_t >
352 linear_least_squares_xgesv< M, N, float_t >::
linear_least_squares_xgesv()353 linear_least_squares_xgesv()
354 {
355     p.n     = N;
356     p.nrhs  = M;
357     p.lda   = N;
358     p.ldb   = N;
359     p.ipiv = new lapack_int[ N ];
360 
361 }
362 
363 
364 
365 template< size_t M, size_t N, typename float_t >
366 linear_least_squares_xgesv< M, N, float_t >::
~linear_least_squares_xgesv()367 ~linear_least_squares_xgesv()
368 {
369     delete[] p.ipiv;
370 }
371 
372 
373 } // namespace lapack
374 
375 } // namespace vmml
376 
377 #endif
378 
379