1 #ifdef COMPILATION_INSTRUCTIONS
2 (echo "#include\""$0"\"" > $0x.cpp) && clang++ `#-DNDEBUG` -O3 -std=c++14 -Wall -Wextra -Wpedantic -Wfatal-errors -D_TEST_MULTI_ADAPTORS_LAPACK_CORE -DADD_ $0x.cpp -o $0x.x -lblas -llapack && time $0x.x $@ && rm -f $0x.x $0x.cpp; exit
3 #endif
4 // Alfredo A. Correa 2019 ©
5
6 #ifndef MULTI_ADAPTORS_LAPACK_CORE_HPP
7 #define MULTI_ADAPTORS_LAPACK_CORE_HPP
8
//#include<iostream>
#include<algorithm> // std::max, std::copy_n
#include<cassert>
#include<complex>
#include<stdexcept> // std::logic_error
#include<utility>   // std::forward
#include<vector>    // internal workspace buffers

//#include <cblas/cblas.h>
#include<lapacke.h>
15
// Single-letter element-type shorthands matching LAPACK's s/d/c/z precision prefixes.
// They are used both as types and, unexpanded, as token-paste prefixes (e.g. T##potrf -> dpotrf).
// All five are #undef'd once the core:: wrappers have been generated.
#define s float
#define d double
#define c std::complex<s>
#define z std::complex<d>
#define v void

// Integer argument spelled as a const reference (scalar input argument of the prototypes below).
#define INT int
#define INTEGER INT const&

//#define N INTEGER n
// Shorthands for single-character option arguments (UPLO = 'U'/'L', JOBZ = 'N'/'V').
#define CHARACTER char const&
#define UPLO CHARACTER
#define JOBZ CHARACTER
// Fortran symbol name: routine name followed by a trailing underscore (e.g. dpotrf -> dpotrf_).
// NOTE(review): this matches the -DADD_ naming convention used in the compile line; other
// Fortran name-mangling schemes are not handled here.
#define LAPACK(NamE) NamE##_
// Workspace-size arguments and the integer workspace pointer.
#define LWORK INTEGER lwork
#define LIWORK INTEGER liwork
#define IWORK int*

// Prototype generators for ?POTRF, ?SYEV, ?SYEVD and ?HEEV. They are only referenced from the
// commented-out extern "C" block further down.
// NOTE(review): the reference ?HEEV routine also takes a real RWORK array and returns real
// eigenvalues in W; xHEEV omits RWORK and types W as T* -- confirm before enabling it.
#define xPOTRF(T) v LAPACK(T##potrf)(UPLO, int const& N, T*, int const& LDA, int& INFO)
#define xSYEV(T) v LAPACK(T##syev) (JOBZ, UPLO, int const& N, T*, int const& LDA, T*, T*, LWORK, int& INFO)
#define xSYEVD(T) v LAPACK(T##syevd)(JOBZ, UPLO, int const& N, T*, int const& LDA, T*, T*, LWORK, IWORK, LIWORK, int& INFO)
#define xHEEV(T) v LAPACK(T##heev) (JOBZ, UPLO, int const& N, T*, int const& LDA, T*, T*, LWORK, int& INFO)

// Fortran-flavoured vocabulary used by the xGETRF/xGETRS prototype generators:
// integer = scalar input, integer_out = scalar output, integer_ptr/integer_cptr = integer arrays.
#define subroutine void
#define integer int const&
#define integer_out int&
#define integer_ptr int*
#define integer_cptr int const*
#define character char const&
45
// http://www.netlib.org/lapack/explore-html/dd/d9a/group__double_g_ecomputational_ga0019443faea08275ca60a734d0593e60.html
// Expands to a prototype of the Fortran LU-factorization routine T##getrf_ for one of the s/d/c/z
// element types. It is only referenced from a commented-out extern "C" block below.
#define xGETRF(T) \
subroutine T##getrf_( \
integer M, /*The number of rows of the matrix A. M >= 0.*/ \
integer N, /*The number of columns of the matrix A. N >= 0.*/ \
T* A, /*On entry, the M-by-N matrix to be factored.*/ \
/*On exit, the factors L and U from the factorization*/ \
integer LDA, /*The leading dimension of the array A. LDA >= max(1,M).*/\
integer_ptr IPIV, /*The pivot indices; for 1 <= i <= min(M,N), row i of the matrix was interchanged with row IPIV(i).*/\
integer_out INFO /*= 0: successful exit*/\
/*< 0: if INFO = -i, the i-th argument had an illegal value*/\
/*> 0: if INFO = i, U(i,i) is exactly zero. The factorization has been completed, but the factor U is exactly singular, and division by zero will occur if it is used to solve a system of equations.*/\
)
59
// http://www.netlib.org/lapack/explore-html/d8/ddc/group__real_g_ecomputational_gaa00bcf4d83a118cb6f0b6619d6ffaa24.html
// Expands to a prototype of the Fortran routine T##getrs_, which solves a linear system using the
// LU factors produced by T##getrf_. INFO is written by the routine, so it is an integer_out (int&).
#define xGETRS(T) \
subroutine T##getrs_( \
	character TRANS,/*Specifies the form of the system of equations: */\
	                /* = 'N': A * X = B (No transpose) */\
	                /* = 'T': A**T* X = B (Transpose) */\
	                /* = 'C': A**H* X = B (Conjugate transpose; same as 'T' for real T) */\
	integer N,      /*The order of the matrix A. N >= 0. */\
	integer NRHS,   /*The number of right hand sides, i.e., the number of columns*/\
	                /*of the matrix B. NRHS >= 0. */\
	T const* A,     /*The factors L and U from the factorization A = P*L*U */\
	                /*as computed by xGETRF. */\
	integer LDA,    /*The leading dimension of the array A. LDA >= max(1,N). */\
	integer_cptr IPIV, /*The pivot indices from xGETRF; for 1<=i<=N, row i of the */\
	                   /*matrix was interchanged with row IPIV(i). */\
	T* B,           /*On entry, the right hand side matrix B. */\
	                /*On exit, the solution matrix X. */\
	integer LDB,    /*The leading dimension of the array B. LDB >= max(1,N). */\
	integer_out INFO /*= 0: successful exit */\
	                 /*< 0: if INFO = -i, the i-th argument had an illegal value */\
)
81
82 // TODO // http://www.netlib.org/lapack/explore-html/d7/d3b/group__double_g_esolve_ga5ee879032a8365897c3ba91e3dc8d512.html
83
84
extern "C"{
// Fortran prototypes for ?getrf_/?getrs_ generated by xGETRF/xGETRS.
// NOTE(review): left disabled, presumably because <lapacke.h> already declares these symbols
// (core::getrf/getrs below call them with pointer arguments), and reference-based redeclarations
// of the same extern "C" names would conflict -- confirm before enabling.
//xGETRF(s) ; xGETRF(d) ; xGETRF(c) ; xGETRF(z) ;
//xGETRS(s) ; xGETRS(d) ; xGETRS(c) ; xGETRS(z) ;
}
89
90 namespace core{
91 // http://www.netlib.org/lapack/explore-html/da/d30/a18643_ga5b625680e6251feb29e386193914981c.html
92
getrf(lapack_int m,lapack_int n,double * A,lapack_int lda,int * ipiv)93 int getrf(lapack_int m, lapack_int n, double* A, lapack_int lda, int* ipiv){
94 assert( m >= 0 );
95 assert( n >= 0 );
96 assert( lda >= std::max(lapack_int{1}, m) );
97 int info;
98 dgetrf_(&m, &n, A, &lda, ipiv, &info);
99 assert(info >= 0);
100 return info;
101 }
102
getrs(char trans,lapack_int const n,lapack_int const nrhs,double const * A,lapack_int const lda,int const * ipiv,double * B,lapack_int const ldb)103 void getrs(char trans, lapack_int const n, lapack_int const nrhs, double const* A, lapack_int const lda, int const* ipiv, double* B, lapack_int const ldb){
104 assert( trans == 'T' or trans == 'N' or trans == 'C' );
105 assert( n >= 0 );
106 assert( nrhs >= 0 );
107 assert( lda >= std::max(1, n) );
108 int info;
109 dgetrs_(&trans, &n, &nrhs, A, &lda, ipiv, B, &ldb, &info);
110 switch(info){
111 case -1: throw std::logic_error{"transa ≠ 'N', 'T', or 'C'"};
112 case -2: throw std::logic_error{"n < 0" };
113 case -3: throw std::logic_error{"nrhs < 0" };
114 case -4: throw std::logic_error{"n > lda" };
115 case -5: throw std::logic_error{"lda ≤ 0" };
116 case -6: throw std::logic_error{"n > ldb" };
117 case -7: throw std::logic_error{"ldb ≤ 0" };
118 case -8: throw std::logic_error{"error!" };
119 }
120 assert(info == 0 );
121 return;
122 }
123
124 }
125
126 namespace lapack{
127
128 struct context{
getrflapack::context129 template<class... Args> static auto getrf(Args&&... args)->decltype(core::getrf(args...)){return core::getrf(args...);}
getrslapack::context130 template<class... Args> static auto getrs(Args&&... args)->decltype(core::getrs(args...)){return core::getrs(args...);}
131 };
132
133 }
134
extern "C"{
// Fortran prototypes for ?potrf_, ?syev_, ?syevd_ and ?heev_ generated by the xPOTRF/xSYEV/
// xSYEVD/xHEEV macros. NOTE(review): left disabled, presumably because <lapacke.h> already
// declares these symbols and reference-based redeclarations would conflict -- confirm.
//xPOTRF(s) ; xPOTRF(d) ;
//xPOTRF(c) ; xPOTRF(z) ;

//xSYEV(s) ; xSYEV(d) ;
//xSYEVD(s) ; xSYEVD(d) ;
// xHEEV(c) ; xHEEV(z) ;
}
143
// The prototype vocabulary is no longer needed past this point; undefine all of it
// (not just part of it) so none of these generic names leak into files including this header.
#undef subroutine
#undef integer
#undef integer_out
#undef integer_ptr
#undef integer_cptr
#undef character

// INFO, N and LDA appear as parameter names in the prototype macros; this header does not
// define them as macros itself, so these #undefs only matter if an includer did.
#undef JOBZ
#undef UPLO
#undef INFO
#undef CHARACTER
#undef N
#undef LDA

#undef LWORK
#undef LIWORK
#undef IWORK
#undef INTEGER
#undef INT

// Prototype generators; only referenced from commented-out extern "C" blocks.
#undef xGETRF
#undef xGETRS
#undef xPOTRF
#undef xSYEV
#undef xSYEVD
#undef xHEEV
158
159 #define xpotrf(T) template<class S> v potrf(char uplo, S n, T *x, S incx, int& info){LAPACK(T##potrf)(uplo, n, x, incx, info);}
160
161 namespace core{
162 xpotrf(s) xpotrf(d)
163 xpotrf(c) xpotrf(z)
164 }
165
166 // http://www.netlib.org/lapack/explore-html/d2/d8a/group__double_s_yeigen_ga442c43fca5493590f8f26cf42fed4044.html
167 #define xsyev(T) template<class S> v syev(char jobz, char uplo, S n, T* a, S lda, T* w, T* work, S lwork, int& info){LAPACK(T##syev)(jobz, uplo, n, a, lda, w, work, lwork, info);}
168 // http://www.netlib.org/lapack/explore-html/d2/d8a/group__double_s_yeigen_ga77dfa610458b6c9bd7db52533bfd53a1.html
169 #define xsyevd(T) template<class S> v syevd(char jobz, char uplo, S n, T* a, S lda, T* w, T* work, S lwork, int* iwork, S liwork, int& info){ \
170 if(n <= 1 ){assert(lwork >= 1 ); assert(liwork >=1 );} \
171 if(jobz == 'N' and n > 1){assert(lwork >= 2*n+1 ); assert(liwork >= 1 );} \
172 if(jobz == 'V' and n > 1){assert(lwork >= 1 + 6*n + 2*n*n); assert(liwork >= 3 + 5*n);} \
173 LAPACK(T##syevd)(jobz, uplo, n, a, lda, w, work, lwork, iwork, liwork, info); \
174 }
175 #define xheev(T) template<class S> v heev(char jobz, char uplo, S n, T* a, S lda, T* w, T* work, S lwork, int& info){LAPACK(T##heev)(jobz, uplo, n, a, lda, w, work, lwork, info);}
176
177 namespace core{
178 xsyev (s) xsyev (d)
179 xsyevd(s) xsyevd(d)
180 xheev(c) xheev(z)
181 }
182
// Element-type shorthands are only needed to generate the core:: wrappers above.
#undef s
#undef d
#undef c
#undef z
#undef v

// Likewise the wrapper generators and the Fortran-name helper: undefine them so these generic
// names do not leak into every file that includes this header.
#undef xpotrf
#undef xsyev
#undef xsyevd
#undef xheev
#undef LAPACK

// NOTE(review): TRANS is not used anywhere in this header and leaks into every includer;
// it is kept only in case other code relies on it -- consider removing it.
#define TRANS const char& trans
190
191 ///////////////////////////////////////////////////////////////////////////////
192
193 #if _TEST_MULTI_ADAPTORS_LAPACK_CORE
194
195 #include "../../array.hpp"
196 #include "../../utility.hpp"
197
198 #include<iostream>
199 #include<numeric>
200 #include<vector>
201
202 namespace multi = boost::multi;
203 using std::cout;
204
main()205 int main(){
206 using core::potrf;
207
208 std::vector<double> v = {
209 2., 1.,
210 1., 2.
211 };
212 cout
213 << v[0] <<'\t'<< v[1] <<'\n'
214 << v[2] <<'\t'<< v[3] <<'\n' << std::endl
215 ;
216 int info;
217 potrf('U', 2, v.data(), 2, info);
218 cout << "error " << info << std::endl;
219 cout
220 << v[0] <<'\t'<< v[1] <<'\n'
221 << v[2] <<'\t'<< v[3] <<'\n'
222 ;
223 cout << std::endl;
224 }
225
226 #endif
227 #endif
228
229