1 #ifdef COMPILATION// -*-indent-tabs-mode:t;c-basic-offset:4;tab-width:4;-*-
2 $CXXX $CXXFLAGS $0 -o $0x -lboost_unit_test_framework `pkg-config --libs blas` \
3 `#-Wl,-rpath,/usr/local/Wolfram/Mathematica/12.0/SystemFiles/Libraries/Linux-x86-64 -L/usr/local/Wolfram/Mathematica/12.0/SystemFiles/Libraries/Linux-x86-64 -lmkl_intel_ilp64 -lmkl_intel_thread -lmkl_core -liomp5` \
4 -lboost_timer &&$0x&&rm $0x; exit
5 #endif
6 // © Alfredo A. Correa 2019-2020
7
8 #ifndef MULTI_ADAPTORS_BLAS_HERK_HPP
9 #define MULTI_ADAPTORS_BLAS_HERK_HPP
10
11 #include "../blas/core.hpp"
12 #include "../blas/copy.hpp"
13 //#include "../blas/scal.hpp"
14 #include "../blas/syrk.hpp" // fallback to real case
15
16 #include "../blas/side.hpp"
17 #include "../blas/filling.hpp"
18
19 #include "../blas/operations.hpp"
20
21 #include "../../config/NODISCARD.hpp"
22
23 //#include<iostream> //debug
24 //#include<type_traits> // void_t
25
26 namespace boost{
27 namespace multi{namespace blas{
28
29 template<class A, std::enable_if_t<not is_conjugated<A>{}, int> =0>
30 auto base_aux(A&& a)
31 ->decltype(base(a)){
32 return base(a);}
33
34 template<class A, std::enable_if_t< is_conjugated<A>{}, int> =0>
35 auto base_aux(A&& a)
36 ->decltype(underlying(base(a))){
37 return underlying(base(a));}
38
39 using core::herk;
40
41 template<class AA, class BB, class A2D, class C2D, class = typename A2D::element_ptr, std::enable_if_t<is_complex_array<C2D>{}, int> =0>
42 C2D&& herk(filling c_side, AA alpha, A2D const& a, BB beta, C2D&& c)
43 //->decltype(herk('\0', '\0', c.size(), a.size(), &alpha, base_aux(a), stride(a.rotated()), &beta, base_aux(c), stride(c)), std::forward<C2D>(c))
44 {
45 assert( a.size() == c.size() );
46 assert( c.size() == rotated(c).size() );
47 if(c.size()==0) return std::forward<C2D>(c);
48 if constexpr(is_conjugated<C2D>{}){herk(flip(c_side), alpha, a, beta, hermitized(c)); return std::forward<C2D>(c);}
49 {
50 auto base_a = base_aux(a);
51 auto base_c = base_aux(c); // static_assert( not is_conjugated<C2D>{}, "!" );
52 if constexpr(is_conjugated<A2D>{}){
53 // auto& ctxt = *blas::default_context_of(underlying(a.base()));
54 // if you get an error here might be due to lack of inclusion of a header file with the backend appropriate for your type of iterator
55 if(stride(a)==1 and stride(c)!=1) herk(c_side==filling::upper?'L':'U', 'N', size(c), size(rotated(a)), &alpha, base_a, stride(rotated(a)), &beta, base_c, stride(c));
56 else if(stride(a)==1 and stride(c)==1){
57 if(size(a)==1) herk(c_side==filling::upper?'L':'U', 'N', size(c), size(rotated(a)), &alpha, base_a, stride(rotated(a)), &beta, base_c, stride(c));
58 else assert(0);
59 }
60 else if(stride(a)!=1 and stride(c)==1) herk(c_side==filling::upper?'U':'L', 'C', size(c), size(rotated(a)), &alpha, base_a, stride( a ), &beta, base_c, stride(rotated(c)));
61 else if(stride(a)!=1 and stride(c)!=1) herk(c_side==filling::upper?'L':'U', 'C', size(c), size(rotated(a)), &alpha, base_a, stride( a ), &beta, base_c, stride( c ));
62 else assert(0);
63 }else{
64 // auto& ctxt = *blas::default_context_of( a.base() );
65 ;;;; if(stride(a)!=1 and stride(c)!=1) herk(c_side==filling::upper?'L':'U', 'C', size(c), size(rotated(a)), &alpha, base_a, stride( a ), &beta, base_c, stride(c));
66 else if(stride(a)!=1 and stride(c)==1){
67 if(size(a)==1) herk(c_side==filling::upper?'L':'U', 'N', size(c), size(rotated(a)), &alpha, base_a, stride(rotated(a)), &beta, base_c, stride(rotated(c)));
68 else assert(0);
69 }
70 else if(stride(a)==1 and stride(c)!=1) assert(0);//case not implemented, herk(c_side==filling::upper?'L':'U', 'N', size(c), size(rotated(a)), alpha, base_a, stride(rotated(a)), beta, base(c), stride(c));
71 else if(stride(a)==1 and stride(c)==1) herk(c_side==filling::upper?'U':'L', 'N', size(c), size(rotated(a)), &alpha, base_a, stride(rotated(a)), &beta, base_c, stride(rotated(c)));
72 else assert(0);
73 }
74 }
75 return std::forward<C2D>(c);
76 }
77
78 template<class AA, class BB, class A2D, class C2D, class = typename A2D::element_ptr, std::enable_if_t<not is_complex_array<C2D>{}, int> =0>
79 auto herk(filling c_side, AA alpha, A2D const& a, BB beta, C2D&& c)
80 ->decltype(syrk(c_side, alpha, a, beta, std::forward<C2D>(c))){
81 return syrk(c_side, alpha, a, beta, std::forward<C2D>(c));}
82
83 //template<class AA, class BB, class A2D, class C2D, class = typename A2D::element_ptr>
84 //auto herk(filling c_side, AA alpha, A2D const& a, BB beta, C2D&& c)
85 //->decltype(herk_aux(c_side, alpha, a, beta, std::forward<C2D>(c), is_complex<C2D>{})){
86 // return herk_aux(c_side, alpha, a, beta, std::forward<C2D>(c), is_complex<C2D>{});}
87
88 template<class AA, class A2D, class C2D, class = typename A2D::element_ptr>
herk(filling c_side,AA alpha,A2D const & a,C2D && c)89 auto herk(filling c_side, AA alpha, A2D const& a, C2D&& c)
90 ->decltype(herk(c_side, alpha, a, 0., std::forward<C2D>(c))){
91 return herk(c_side, alpha, a, 0., std::forward<C2D>(c));}
92
93 template<typename AA, class A2D, class C2D>
herk(AA alpha,A2D const & a,C2D && c)94 auto herk(AA alpha, A2D const& a, C2D&& c)
95 ->decltype(herk(filling::lower, alpha, a, herk(filling::upper, alpha, a, std::forward<C2D>(c)))){
96 return herk(filling::lower, alpha, a, herk(filling::upper, alpha, a, std::forward<C2D>(c)));}
97
// Overload without scale factor: same as `herk(1., a, c)`.
template<class A2D, class C2D>
auto herk(A2D const& a, C2D&& c)
	-> decltype(herk(1., a, std::forward<C2D>(c)))
{
	return herk(1., a, std::forward<C2D>(c));
}
102
103 /*
104 template<class A2D, class C2D>
105 NODISCARD("when last argument is const")
106 auto herk(A2D const& a, C2D const& c)
107 ->decltype(herk(1., a, decay(c))){
108 return herk(1., a, decay(c));}
109 */
110
111 template<class AA, class A2D, class Ret = typename A2D::decay_type>
112 NODISCARD("when argument is read-only")
herk(AA alpha,A2D const & a)113 auto herk(AA alpha, A2D const& a)//->std::decay_t<decltype(herk(alpha, a, Ret({size(a), size(a)}, get_allocator(a))))>{
114 {
115 return herk(alpha, a, Ret({size(a), size(a)}));//Ret({size(a), size(a)}));//, get_allocator(a)));
116 }
117
// Thin extension of `std::numeric_limits` that also knows how to build a quiet NaN for
// `std::complex<T>` (the standard trait is not specialized for complex numbers).
// For every other `T` this is exactly `std::numeric_limits<T>`.
template<class T> struct numeric_limits : std::numeric_limits<T>{};

template<class T>
struct numeric_limits<std::complex<T>> : std::numeric_limits<std::complex<T>>{
	// The unspecialized standard base reports `false` here; since `quiet_NaN()` is provided
	// below, advertise it consistently whenever the scalar type has a quiet NaN.
	static constexpr bool has_quiet_NaN = numeric_limits<T>::has_quiet_NaN;

	// Complex value whose real and imaginary parts are both the scalar quiet NaN.
	static std::complex<T> quiet_NaN(){
		auto const n = numeric_limits<T>::quiet_NaN();
		return {n, n};
	}
};
122
123 template<class AA, class A2D, class Ret = typename A2D::decay_type>
124 NODISCARD("because argument is read-only")
herk(filling cs,AA alpha,A2D const & a)125 auto herk(filling cs, AA alpha, A2D const& a)
126 ->std::decay_t<
127 decltype(herk(cs, alpha, a, Ret({size(a), size(a)}, 0., get_allocator(a))))>{
128 return herk(cs, alpha, a, Ret({size(a), size(a)},
129 #ifdef NDEBUG
130 numeric_limits<typename Ret::element_type>::quiet_NaN(),
131 #endif
132 get_allocator(a)
133 ));
134 }
135
herk(filling s,A2D const & a)136 template<class A2D> auto herk(filling s, A2D const& a)
137 ->decltype(herk(s, 1., a)){
138 return herk(s, 1., a);}
139
// Value-returning overload with unit scale: `herk(1., a)`.
// The return type is left to deduction (no trailing `decltype`).
template<class A2D>
auto herk(A2D const& a){
	return herk(1., a);
}
143
144 }}
145
146 }
147
148 #if not __INCLUDE_LEVEL__ // _TEST_MULTI_ADAPTORS_BLAS_HERK
149
150 #define BOOST_TEST_MODULE "C++ Unit Tests for Multi cuBLAS herk"
151 #define BOOST_TEST_DYN_LINK
152 #include<boost/test/unit_test.hpp>
153
154 #include "../../array.hpp"
155 #include "../../adaptors/blas/gemm.hpp"
156 #include "../../adaptors/blas/nrm2.hpp"
157
158 #include<iostream>
159 #include<numeric>
160
161 namespace utf = boost::unit_test;
162 namespace multi = boost::multi;
163
// Debugging aid: intentionally deleted, so any call `what(expr)` is a compile error
// (presumably used to make the compiler print the deduced type of `expr`).
template<class Unknown>
void what(Unknown&&) = delete;
165
print(M const & C)166 template<class M> decltype(auto) print(M const& C){
167 using std::cout;
168 using boost::multi::size;
169 for(int i = 0; i != size(C); ++i){
170 for(int j = 0; j != size(C[i]); ++j) cout << C[i][j] << ' ';
171 cout << std::endl;
172 }
173 return cout << std::endl;
174 }
175
// Real 4x3 input: every flavour of `herk` must reproduce the explicit product
// mat * transpose(mat) computed with `gemm`.
BOOST_AUTO_TEST_CASE(inq_case){
	using namespace multi::blas;
	multi::array<double, 2> const mat = {
		{0,  1,  2},
		{3,  4,  5},
		{6,  7,  8},
		{9, 10, 11}
	};
	BOOST_REQUIRE( gemm(mat, T(mat))[1][2] == 86. ); // 3*6 + 4*7 + 5*8

	{	// result written into a caller-provided array
		multi::array<double, 2> out({4, 4});
		herk(1.0, mat, out);
		BOOST_REQUIRE( out == gemm(mat, T(mat)) );
	}
	{	// result returned by value
		multi::array<double, 2> out = herk(1.0, mat);
		BOOST_REQUIRE( out == gemm(mat, T(mat)) );
	}
	{	// implicit unit scale
		BOOST_REQUIRE( herk(mat) == gemm(mat, T(mat)) );
	}
	{	// non-unit scale
		BOOST_REQUIRE( herk(2.0, mat) == gemm(2.0, mat, T(mat)) );
	}
}
201
// Real 2x3 input: both off-diagonal entries of the 2x2 result are filled, and the
// value-returning overload agrees with the in-place one.
BOOST_AUTO_TEST_CASE(multi_blas_herk_real){
	namespace blas = multi::blas;
	multi::array<double, 2> const mat = {
		{ 1., 3., 4.},
		{ 9., 7., 1.}
	};
	{
		multi::array<double, 2> gram({2, 2}, 9999);
		blas::herk(1., mat, gram);
		BOOST_REQUIRE( gram[1][0] == 34 ); // 9*1 + 7*3 + 1*4
		BOOST_REQUIRE( gram[0][1] == 34 );

		multi::array<double, 2> const gram_by_value = blas::herk(1., mat);
		BOOST_REQUIRE( gram == gram_by_value );
	}
}
218
// A single real row yields a 1x1 result holding the sum of squares of its elements.
BOOST_AUTO_TEST_CASE(multi_blas_herk1x1_case){
	namespace blas = multi::blas;
	multi::array<double, 2> const row = {{1., 2., 3.}};
	multi::array<double, 2> gram = blas::herk(row);
	BOOST_REQUIRE( size(gram) == 1 );
	BOOST_REQUIRE( gram[0][0] == 1.*1. + 2.*2. + 3.*3. );
}
226
// Same single-row real case, now with a scale factor of 0.1 applied to the result.
BOOST_AUTO_TEST_CASE(multi_blas_herk1x1_case_scale){
	namespace blas = multi::blas;
	multi::array<double, 2> const row = {{1., 2., 3.}};
	multi::array<double, 2> gram = blas::herk(0.1, row);
	BOOST_REQUIRE( size(gram) == 1 );
	BOOST_TEST( gram[0][0] == (1.*1. + 2.*2. + 3.*3.)*0.1 );
}
234
// Complex element type holding purely real values: the 1x1 result is the real sum of squares.
BOOST_AUTO_TEST_CASE(multi_blas_herk1x1_complex_real_case){
	namespace blas = multi::blas;
	multi::array<complex, 2> const row = {{1., 2., 3.}};
	multi::array<complex, 2> gram = blas::herk(1.0, row);
	BOOST_REQUIRE( size(gram) == 1 );
	BOOST_REQUIRE( gram[0][0] == 1.*1. + 2.*2. + 3.*3. );
}
242
243 BOOST_AUTO_TEST_CASE(multi_blas_herk1x1_complex_real_case_scale, *utf::tolerance(0.00001)){
244 namespace blas = multi::blas;
245 multi::array<complex, 2> const A = {{1., 2., 3.}};
246 multi::array<complex, 2> B = blas::herk(0.1, A);
247 BOOST_REQUIRE( size(B) == 1 );
248 BOOST_TEST( real( B[0][0]/0.1 ) == 1.*1. + 2.*2. + 3.*3. );
249 }
250
// Genuinely complex single row: the 1x1 result is the sum of squared moduli, and its
// square root is compared with the 2-norm of the row.
BOOST_AUTO_TEST_CASE(multi_blas_herk1x1_complex_case){
	namespace blas = multi::blas;
	multi::array<complex, 2> const row = {{1. + 2.*I, 2.+3.*I, 3. + 4.*I}};
	multi::array<complex, 2> gram = blas::herk(row);
	BOOST_REQUIRE( size(gram) == 1 );
	BOOST_REQUIRE( gram[0][0] == std::norm(1. + 2.*I) + std::norm(2.+3.*I) + std::norm(3. + 4.*I) );

	BOOST_TEST( std::sqrt(real(blas::herk(row)[0][0])) == blas::nrm2(row[0])() );
}
260
// Column input passed as a hermitized view, with the 1x1 result written into an
// output parameter through the full five-argument interface.
BOOST_AUTO_TEST_CASE(multi_blas_herk1x1_complex_case_hermitized_out_param){
	namespace blas = multi::blas;
	multi::array<complex, 2> const col = {{1. + 2.*I}, {2.+3.*I}, {3. + 4.*I}};
	multi::array<complex, 2> gram({1, 1});
	BOOST_REQUIRE( size(gram) == 1 );

	blas::herk(blas::filling::upper, 1.0, blas::H(col), 0.0, gram);

	BOOST_REQUIRE( gram[0][0] == std::norm(1. + 2.*I) + std::norm(2.+3.*I) + std::norm(3. + 4.*I) );

	BOOST_TEST( std::sqrt(real(gram[0][0])) == blas::nrm2(blas::T(col)[0])() );
}
273
// Column input passed as a hermitized view, result returned by value.
BOOST_AUTO_TEST_CASE(multi_blas_herk1x1_complex_case_hermitized){
	multi::array<complex, 2> col = {{1. + 2.*I}, {2.+3.*I}, {3. + 4.*I}};
	namespace blas = multi::blas;
	multi::array<complex, 2> gram = blas::herk(blas::H(col));
	BOOST_REQUIRE( size(gram) == 1 );
	BOOST_REQUIRE( gram[0][0] == std::norm(1. + 2.*I) + std::norm(2.+3.*I) + std::norm(3. + 4.*I) );

	BOOST_TEST( std::sqrt(real(blas::herk(blas::H(col))[0][0])) == blas::nrm2(rotated(col)[0])() );
}
283
// Same hermitized-column case, but the result type is deduced with `auto` and checked
// statically to be a plain `multi::array<complex, 2>`.
BOOST_AUTO_TEST_CASE(multi_blas_herk1x1_complex_case_hermitized_auto){
	namespace blas = multi::blas;

	multi::array<complex, 2> col = {{1. + 2.*I}, {2.+3.*I}, {3. + 4.*I}};
	auto gram = blas::herk(1., blas::hermitized(col));
	static_assert( std::is_same<decltype(gram), multi::array<complex, 2>>{}, "!" );
	BOOST_REQUIRE( size(gram) == 1 );
	BOOST_REQUIRE( gram[0][0] == std::norm(1. + 2.*I) + std::norm(2.+3.*I) + std::norm(3. + 4.*I) );

	BOOST_TEST( std::sqrt(real(blas::herk(blas::H(col))[0][0])) == blas::nrm2(rotated(col)[0])() );
}
295
296 #if 1
297 #if 1
298
// Complex 2x3 input exercised through several layouts of input and output
// (plain, transposed, hermitized). In each sub-case the output starts filled with the
// sentinel 9999.; the requested triangle must hold the product and the opposite
// off-diagonal entry must keep the sentinel (unless both triangles are requested).
BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_identity){
	namespace blas = multi::blas;
	multi::array<complex, 2> const a = {
		{ 1. + 3.*I, 3.- 2.*I, 4.+ 1.*I},
		{ 9. + 1.*I, 7.- 8.*I, 1.- 3.*I}
	};

	{	// out = a a†, lower triangle, plain output
		multi::array<complex, 2> out({2, 2}, 9999.);
		blas::herk(blas::filling::lower, 1., a, 0., out);
		BOOST_REQUIRE( out[1][0]==complex(50., -49.) );
		BOOST_REQUIRE( out[0][1]==9999. );
	}
	{	// out† = a a†, lower triangle of the hermitized view of the output
		multi::array<complex, 2> out({2, 2}, 9999.);
		static_assert(blas::is_conjugated<decltype(blas::H(out))>{}, "!" );

		blas::herk(blas::filling::lower, 1., a, 0., blas::H(out));

		BOOST_REQUIRE( blas::H(out)[1][0]==complex(50., -49.) );
		BOOST_REQUIRE( blas::H(out)[0][1]==9999. );
	}
	{	// disabled: plain input, transposed output
	//	multi::array<complex, 2> out({2, 2}, 9999.);
	//	blas::herk(blas::filling::lower, 1., a, 0., blas::T(out));
	//	BOOST_REQUIRE( transposed(out)[1][0]==complex(50., -49.) );
	//	BOOST_REQUIRE( transposed(out)[0][1]==9999. );
	}
	{	// disabled: transposed input, plain output (marked "not supported" by the author)
		multi::array<complex, 2> out({3, 3}, 9999.);
	//	herk(filling::lower, 1., transposed(a), 0., out);
	//	print(out);
	//	BOOST_REQUIRE( out[1][0]==complex(52., -90.) );
	//	BOOST_REQUIRE( out[0][1]==9999. );
	}
	{	// disabled: transposed input, hermitized output (marked "not supported" by the author)
		multi::array<complex, 2> out({3, 3}, 9999.);
	//	herk(filling::lower, 1., transposed(a), 0., hermitized(out));
	//	BOOST_REQUIRE( hermitized(out)[1][0]==complex(52., -90.) );
	//	BOOST_REQUIRE( hermitized(out)[0][1]==9999. );
	}
	{	// transposed input and transposed output
		multi::array<complex, 2> out({3, 3}, 9999.);
		herk(blas::filling::lower, 1., blas::T(a), 0., blas::T(out));
		BOOST_REQUIRE( transposed(out)[1][0]==complex(52., -90.) );
		BOOST_REQUIRE( transposed(out)[0][1]==9999. );
	}
	{	// transposed input, hermitized-transposed output
		multi::array<complex, 2> out({3, 3}, 9999.);
		blas::herk(blas::filling::lower, 1., blas::T(a), 0., blas::H(blas::T(out)));
		BOOST_REQUIRE( blas::H(blas::T(out))[1][0]==complex(52., -90.) );
		BOOST_REQUIRE( blas::H(blas::T(out))[0][1]==9999. );
	}
	{	// disabled: transposed input, plain output
	//	multi::array<complex, 2> out({3, 3}, 9999.);
	//	using namespace multi::blas;
	//	blas::herk(blas::filling::lower, 1., blas::T(a), 0., out);
	//	BOOST_REQUIRE( out[1][0]==complex(50., -49.) );
	//	BOOST_REQUIRE( out[0][1]==9999. );
	}
#if 1
	{	// out = a a†, upper triangle
		multi::array<complex, 2> out({2, 2}, 9999.);
		blas::herk(blas::U, 1., a, 0., out);
		BOOST_REQUIRE( out[0][1]==complex(50., +49.) );
		BOOST_REQUIRE( out[1][0]==9999. );
	}
	{	// out = a a†, both triangles
		multi::array<complex, 2> out({2, 2}, 9999.);
		blas::herk(1., a, out);
		BOOST_REQUIRE( out[0][1]==complex(50., +49.) );
		BOOST_REQUIRE( out[1][0]==complex(50., -49.) );
	}
	{	// out = a† a (hermitized input), lower triangle
		multi::array<complex, 2> out({3, 3}, 9999.);
		blas::herk(blas::L, 1., blas::H(a), 0., out);
		BOOST_REQUIRE( out[1][0]==complex(52., 90.) );
		BOOST_REQUIRE( out[0][1]==9999. );
	}
	{	// disabled: transposed input, plain output
	//	multi::array<complex, 2> out({3, 3}, 9999.);
	//	using namespace multi::blas;
	//	herk(filling::lower, 1., transposed(a), 0., out);
	//	BOOST_REQUIRE( out[0][1]==9999. );
	//	BOOST_REQUIRE( out[1][0]==complex(52., 90.) );
	}
#endif
}
387
388 #if 0
389
390 BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_real_case){
391 multi::array<complex, 2> const a = {
392 { 1., 3., 4.},
393 { 9., 7., 1.}
394 };
395 namespace blas = multi::blas;
396 using blas::filling;
397 using blas::transposed;
398 using blas::hermitized;
399 {
400 multi::array<complex, 2> c({3, 3}, 9999.);
401
402 herk(filling::lower, 1., hermitized(a), 0., c);//c†=c=a†a=(a†a)†, `c` in lower triangular
403 BOOST_REQUIRE( c[2][1]==complex(19.,0.) );
404 BOOST_REQUIRE( c[1][2]==9999. );
405 }
406 {
407 multi::array<complex, 2> c({3, 3}, 9999.);
408 herk(filling::upper, 1., hermitized(a), 0., c);//c†=c=a†a=(a†a)†, `c` in lower triangular
409 BOOST_REQUIRE( c[1][2]==complex(19.,0.) );
410 BOOST_REQUIRE( c[2][1]==9999. );
411 }
412 {
413 multi::array<complex, 2> c({3, 3}, 9999.);
414 // herk(filling::upper, 1., hermitized(a), 0., transposed(c));//c†=c=a†a=(a†a)†, `c` in lower triangular
415 // print(transposed(c));
416 // BOOST_REQUIRE( c[1][2]==complex(19.,0.) );
417 // BOOST_REQUIRE( c[2][1]==9999. );
418 }
419 {
420 multi::array<complex, 2> c({3, 3}, 9999.);
421 using blas::transposed;
422 // herk(filling::upper, 1., transposed(a), 0., c);//c_†=c_=a_†a_=(a_†a_)†, `c_` in lower triangular
423 // BOOST_REQUIRE( c[2][1] == 9999. );
424 // BOOST_REQUIRE( c[1][2] == 19. );
425 }
426 }
427
428 BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_basic_transparent_interface){
429 multi::array<complex, 2> const a = {
430 { 1. + 3.*I, 3.- 2.*I, 4.+ 1.*I},
431 { 9. + 1.*I, 7.- 8.*I, 1.- 3.*I}
432 };
433 namespace blas = multi::blas;
434 using blas::filling;
435 using blas::hermitized;
436 {
437 multi::array<complex, 2> c({3, 3}, 9999.);
438 herk(filling::lower, 1., hermitized(a), 0., c); // c†=c=a†a=(a†a)†, information in `c` lower triangular
439 BOOST_REQUIRE( c[2][1]==complex(41.,2.) );
440 BOOST_REQUIRE( c[1][2]==9999. );
441 }
442 {
443 multi::array<complex, 2> c({3, 3}, 9999.);
444 using multi::blas::herk;
445 herk(filling::upper, 1., hermitized(a), 0., c); // c†=c=a†a=(a†a)†, `c` in upper triangular
446 BOOST_REQUIRE( c[1][2]==complex(41., -2.) );
447 BOOST_REQUIRE( c[2][1]==9999. );
448 }
449 {
450 multi::array<complex, 2> c({2, 2}, 9999.);
451 using multi::blas::herk;
452 herk(filling::lower, 1., a, 0., c); // c†=c=aa†, `a` and `c` are c-ordering, information in c lower triangular
453 BOOST_REQUIRE( c[1][0]==complex(50., -49.) );
454 BOOST_REQUIRE( c[0][1]==9999. );
455 }
456 {
457 multi::array<complex, 2> c({2, 2}, 9999.);
458 using multi::blas::herk;
459 herk(filling::upper, 1., a, 0., c); //c†=c=aa†, `c` in upper triangular
460 BOOST_REQUIRE( c[0][1]==complex(50., 49.) );
461 BOOST_REQUIRE( c[1][0]==9999. );
462 }
463 }
464
465 BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_basic_enum_interface){
466 multi::array<complex, 2> const a = {
467 { 1. + 3.*I, 3.- 2.*I, 4.+ 1.*I},
468 { 9. + 1.*I, 7.- 8.*I, 1.- 3.*I}
469 };
470 namespace blas = multi::blas;
471 using blas::filling;
472 using blas::hermitized;
473 using blas::transposed;
474 {
475 // multi::array<complex, 2> c({2, 2}, 8888.);
476 // std::cerr << "here" << std::endl;
477 // herk(filling::lower, 1., hermitized(transposed(a)), 0., c); //c†=c=a†a=(a†a)†, `c` in lower triangular
478 // print(c) << std::endl;
479 // std::cerr << "there" << std::endl;
480 // BOOST_REQUIRE( c[0][1]==complex(41.,2.) );
481 // BOOST_REQUIRE( c[1][0]==8888. );
482 }
483 {
484 multi::array<complex, 2> c({3, 3}, 9999.);
485 herk(filling::lower, 1., hermitized(a), 0., c); //c†=c=a†a=(a†a)†, `c` in lower triangular
486 BOOST_REQUIRE( c[2][1]==complex(41.,2.) );
487 BOOST_REQUIRE( c[1][2]==9999. );
488 }
489 {
490 multi::array<complex, 2> c({3, 3}, 9999.);
491 using namespace multi::blas;
492 herk(filling::upper, 1., hermitized(a), 0., c); //c†=c=a†a=(a†a)†, `c` in upper triangular
493 BOOST_REQUIRE( c[1][2]==complex(41., -2.) );
494 BOOST_REQUIRE( c[2][1]==9999. );
495 }
496 {
497 multi::array<complex, 2> c({2, 2}, 9999.);
498 using namespace multi::blas;
499 herk(filling::lower, 1., a, 0., c); // c†=c=aa†, `c` in lower triangular
500 BOOST_REQUIRE( c[1][0]==complex(50., -49.) );
501 BOOST_REQUIRE( c[0][1]==9999. );
502 }
503 {
504 multi::array<complex, 2> c({2, 2}, 9999.);
505 using namespace multi::blas;
506 herk(filling::upper, 1., a, 0., c); // c†=c=aa†, `c` in upper triangular
507 BOOST_REQUIRE( c[0][1]==complex(50., 49.) );
508 BOOST_REQUIRE( c[1][0]==9999. );
509 }
510 }
511
512 BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_basic_explicit_enum_interface){
513 multi::array<complex, 2> const a = {
514 { 1. + 3.*I, 3.- 2.*I, 4.+ 1.*I},
515 { 9. + 1.*I, 7.- 8.*I, 1.- 3.*I}
516 };
517 using namespace multi::blas;
518 {
519 multi::array<complex, 2> c({3, 3}, 9999.);
520 herk(filling::lower, 1., hermitized(a), 0., c); // c†=c=a†a=(a†a)†, `c` in lower triangular
521 BOOST_REQUIRE( c[2][1]==complex(41.,2.) );
522 BOOST_REQUIRE( c[1][2]==9999. );
523 }
524 BOOST_REQUIRE( herk(hermitized(a)) == gemm(hermitized(a), a) );
525 {
526 multi::array<complex, 2> c({3, 3}, 9999.);
527 // herk(filling::lower, 1., hermitized(a), 0., transposed(c)); // c†=c=a†a=(a†a)†, `c` in lower triangular
528 // print(transposed(c));
529 // BOOST_REQUIRE( c[2][1]==complex(41.,2.) );
530 // BOOST_REQUIRE( c[1][2]==9999. );
531 }
532 {
533 multi::array<complex, 2> c({2, 2}, 9999.);
534 herk(filling::lower, 1., hermitized(transposed(a)), 0., transposed(c)); // c†=c=a†a=(a†a)†, `c` in lower triangular
535 BOOST_REQUIRE( transposed(c)[1][0]==complex(50.,+49.) );
536 BOOST_REQUIRE( transposed(c)[0][1]==9999. );
537 }
538 // BOOST_REQUIRE( herk(hermitized(transposed(a))) == gemm(hermitized(transposed(a)), transposed(a)) );
539 {
540 multi::array<complex, 2> c({3, 3}, 9999.);
541 herk(filling::upper, 1., hermitized(a), 0., c); // c†=c=a†a=(a†a)†, `c` in upper triangular
542 BOOST_REQUIRE( c[1][2]==complex(41., -2.) );
543 BOOST_REQUIRE( c[2][1]==9999. );
544 BOOST_REQUIRE( herk(hermitized(a)) == gemm(hermitized(a), a) );
545 }
546 {
547 multi::array<complex, 2> c({2, 2}, 9999.);
548 herk(filling::lower, 1., a, 0., c); // c†=c=aa†=(aa†)†, `c` in lower triangular
549 BOOST_REQUIRE( c[1][0]==complex(50., -49.) );
550 BOOST_REQUIRE( c[0][1]==9999. );
551 BOOST_REQUIRE( herk(a) == gemm(a, hermitized(a)) );
552 }
553 {
554 multi::array<complex, 2> c({2, 2}, 9999.);
555 herk(filling::upper, 1., a, 0., c); // c†=c=aa†=(aa†)†, `c` in upper triangular
556 BOOST_REQUIRE( c[0][1]==complex(50., 49.) );
557 BOOST_REQUIRE( c[1][0]==9999. );
558 BOOST_REQUIRE( herk(a) == gemm(a, hermitized(a)) );
559 }
560 {
561 multi::array<complex, 2> c({2, 2}, 9999.);
562 herk(filling::upper, 2., a, 0., c); // c†=c=aa†=(aa†)†, `c` in upper triangular
563 BOOST_REQUIRE( c[0][1]==complex(100., 98.) );
564 BOOST_REQUIRE( c[1][0]==9999. );
565 BOOST_REQUIRE( herk(2., a) == gemm(2., a, hermitized(a)) );
566 }
567 {
568 multi::array<complex, 2> c({2, 2}, 9999.);
569 herk(filling::upper, 1., a, 0., c); // c†=c=aa†=(aa†)†, `c` in upper triangular
570 BOOST_REQUIRE( c[0][1]==complex(50., 49.) );
571 BOOST_REQUIRE( c[1][0]==9999. );
572 BOOST_REQUIRE( herk(1., a) == gemm(1., a, hermitized(a)) );
573 }
574 }
575
576 BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_automatic_operator_interface){
577 multi::array<complex, 2> const a = {
578 { 1. + 3.*I, 3.- 2.*I, 4.+ 1.*I},
579 { 9. + 1.*I, 7.- 8.*I, 1.- 3.*I}
580 };
581 {
582 multi::array<complex, 2> c({3, 3}, 9999.);
583 namespace blas = multi::blas;
584 using blas::filling;
585 using blas::hermitized;
586 herk(filling::lower, 1., hermitized(a), 0., c); // c=c†=a†a, `c` in lower triangular
587 BOOST_REQUIRE( c[2][1]==complex(41., 2.) );
588 BOOST_REQUIRE( c[1][2]==9999. );
589 }
590 {
591 multi::array<complex, 2> c({2, 2}, 9999.);
592 using multi:: blas::filling;
593 herk(filling::lower, 1., a, 0., c); // c=c†=aa†, `c` in lower triangular
594 BOOST_REQUIRE( c[1][0]==complex(50., -49.) );
595 BOOST_REQUIRE( c[0][1]==9999. );
596 }
597 {
598 multi::array<complex, 2> c({2, 2}, 9999.);
599 using multi::blas::herk;
600 herk(1., a, c); // c=c†=aa†
601 BOOST_REQUIRE( c[1][0]==complex(50., -49.) );
602 BOOST_REQUIRE( c[0][1]==complex(50., +49.) );
603 }
604 {
605 multi::array<complex, 2> c({3, 3}, 9999.);
606 namespace blas = multi::blas;
607 using blas::filling;
608 using blas::hermitized;
609 herk(filling::lower, 1., hermitized(a), 0., c); // c=c†=a†a, `c` in lower triangular
610 herk(filling::upper, 1., hermitized(a), 0., c);
611 BOOST_REQUIRE( c[2][1]==complex(41., 2.) );
612 BOOST_REQUIRE( c[1][2]==complex(41., -2.) );
613 }
614 }
615
616 BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_automatic_operator_interface_implicit_no_sum){
617 multi::array<complex, 2> const a = {
618 { 1. + 3.*I, 3.- 2.*I, 4.+ 1.*I},
619 { 9. + 1.*I, 7.- 8.*I, 1.- 3.*I}
620 };
621 {
622 multi::array<complex, 2> c({3, 3}, 9999.);
623 namespace blas = multi::blas;
624 using blas::filling;
625 using blas::hermitized;
626 herk(filling::lower, 1., hermitized(a), c); // c=c†=a†a, `c` in lower triangular
627 BOOST_REQUIRE( c[2][1]==complex(41., 2.) );
628 BOOST_REQUIRE( c[1][2]==9999. );
629 }
630 {
631 multi::array<complex, 2> c({2, 2}, 9999.);
632 using multi::blas::filling;
633 herk(filling::lower, 1., a, c); // c=c†=aa†, `c` in lower triangular
634 BOOST_REQUIRE( c[1][0]==complex(50., -49.) );
635 BOOST_REQUIRE( c[0][1]==9999. );
636 }
637 }
638
639 BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_automatic_ordering_and_symmetrization){
640
641 multi::array<complex, 2> const a = {
642 { 1. + 3.*I, 3.- 2.*I, 4.+ 1.*I},
643 { 9. + 1.*I, 7.- 8.*I, 1.- 3.*I}
644 };
645 namespace blas = multi::blas;
646 using blas::herk;
647 using blas::hermitized;
648 using blas::filling;
649 {
650 multi::array<complex, 2> c({3, 3}, 9999.);
651 herk(filling::upper, 1., hermitized(a), c); // c†=c=a†a
652 BOOST_REQUIRE( c[2][1]==9999. );
653 BOOST_REQUIRE( c[1][2]==complex(41., -2.) );
654 }
655 {
656 multi::array<complex, 2> c({3, 3}, 9999.);
657 herk(1., hermitized(a), c); // c†=c=a†a
658 BOOST_REQUIRE( c[2][1]==complex(41., +2.) );
659 BOOST_REQUIRE( c[1][2]==complex(41., -2.) );
660 }
661 {
662 multi::array<complex, 2> c({2, 2}, 9999.);
663 herk(filling::upper, 1., a, c); // c†=c=aa† // c implicit hermitic in upper
664 BOOST_REQUIRE( c[1][0] == 9999. );
665 BOOST_REQUIRE( c[0][1] == complex(50., +49.) );
666 }
667 {
668 multi::array<complex, 2> c({2, 2}, 9999.);
669 herk(1., a, c); // c†=c=aa†
670 BOOST_REQUIRE( c[1][0] == complex(50., -49.) );
671 BOOST_REQUIRE( c[0][1] == complex(50., +49.) );
672 }
673 {
674 multi::array<complex, 2> c = herk(filling::upper, 1., a); // c†=c=aa†
675 // BOOST_REQUIRE( c[1][0] == complex(50., -49.) );
676 BOOST_REQUIRE( c[0][1] == complex(50., +49.) );
677 }
678 {
679 using multi::blas::herk;
680 using multi::blas::filling;
681 multi::array<complex, 2> c = herk(1., a); // c†=c=aa†
682 BOOST_REQUIRE( c[1][0] == complex(50., -49.) );
683 BOOST_REQUIRE( c[0][1] == complex(50., +49.) );
684 }
685 {
686 using multi::blas::herk;
687 using multi::blas::hermitized;
688 using multi::blas::filling;
689 multi::array<complex, 2> c = herk(filling::upper, 1., hermitized(a)); // c†=c=a†a
690
691 BOOST_REQUIRE( size(hermitized(a))==3 );
692 // BOOST_REQUIRE( c[2][1] == complex(41., +2.) );
693 BOOST_REQUIRE( c[1][2] == complex(41., -2.) );
694 }
695 {
696 using multi::blas::herk;
697 using multi::blas::filling;
698 multi::array<complex, 2> c = herk(filling::upper, a); // c†=c=a†a
699 // what(multi::pointer_traits<decltype(base(a))>::default_allocator_of(base(a)));
700 // BOOST_REQUIRE( c[1][0] == complex(50., -49.) );
701 BOOST_REQUIRE( c[0][1] == complex(50., +49.) );
702 }
703 {
704 using multi::blas::herk;
705 using multi::blas::hermitized;
706 using multi::blas::filling;
707 multi::array<complex, 2> c = herk(filling::upper, hermitized(a)); // c†=c=a†a
708 // BOOST_REQUIRE( c[2][1] == complex(41., +2.) );
709 BOOST_REQUIRE( c[1][2] == complex(41., -2.) );
710 }
711 }
712
713 BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_size1_real_case){
714 multi::array<complex, 2> const a = {
715 {1., 3., 4.}
716 };
717 using namespace multi::blas;
718 {
719 multi::array<complex, 2> c({1, 1}, 9999.);
720 herk(filling::upper, 1., a, c); // c†=c=aa†
721 BOOST_TEST( c[0][0] == 26. );
722 }
723 BOOST_TEST( herk(a) == gemm(a, hermitized(a)) );
724 }
725
726 BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_size1){
727 multi::array<complex, 2> const a = {
728 {1. + 4.*I, 3. + 2.*I, 4. - 1.*I}
729 };
730 using namespace multi::blas;
731 {
732 multi::array<complex, 2> c({1, 1}, 9999.);
733 herk(filling::upper, 1., a, c); // c†=c=aa†
734 BOOST_TEST( c[0][0] == 47. );
735 }
736 }
737
738 BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_size0){
739 multi::array<complex, 2> const a;
740 using namespace multi::blas;
741 {
742 multi::array<complex, 2> c;
743 herk(filling::upper, 1., a, c); // c†=c=aa†
744 // BOOST_TEST( c[0][0] == 47. );
745 }
746 }
747
748
749 BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_automatic_ordering_and_symmetrization_real_case){
750
751 multi::array<complex, 2> const a = {
752 { 1., 3., 4.},
753 { 9., 7., 1.}
754 };
755 using namespace multi::blas;
756 {
757 multi::array<complex, 2> c({3, 3}, 9999.);
758 herk(filling::upper, 1., hermitized(a), c); // c†=c=a†a
759 // BOOST_REQUIRE( c[2][1]==19. );
760 BOOST_REQUIRE( c[1][2]==19. );
761 }
762 {
763 multi::array<complex, 2> c({2, 2}, 9999.);
764 herk(filling::upper, 1., a, c); // c†=c=aa†
765 // BOOST_REQUIRE( c[1][0] == 34. );
766 BOOST_REQUIRE( c[0][1] == 34. );
767 }
768 {
769 multi::array<complex, 2> c = herk(filling::upper, 1., a); // c†=c=aa†
770 // BOOST_REQUIRE( c[1][0] == 34. );
771 BOOST_REQUIRE( c[0][1] == 34. );
772 }
773 {
774 multi::array<complex, 2> c = herk(filling::upper, 1., hermitized(a)); // c†=c=a†a
775 BOOST_REQUIRE( size(hermitized(a))==3 );
776 // BOOST_REQUIRE( c[2][1]==19. );
777 BOOST_REQUIRE( c[1][2]==19. );
778 }
779 {
780 multi::array<complex, 2> c = herk(filling::upper, a); // c†=c=a†a
781 // BOOST_REQUIRE( c[1][0] == 34. );
782 BOOST_REQUIRE( c[0][1] == 34. );
783 }
784 {
785 multi::array<complex, 2> c = herk(filling::upper, hermitized(a)); // c†=c=a†a
786 // BOOST_REQUIRE( c[2][1]==19. );
787 BOOST_REQUIRE( c[1][2]==19. );
788 }
789 }
790
791
792 BOOST_AUTO_TEST_CASE(multi_blas_herk_real_automatic_ordering_and_symmetrization_real_case){
793
794 multi::array<double, 2> const a = {
795 { 1., 3., 4.},
796 { 9., 7., 1.}
797 };
798 {
799 multi::array<double, 2> c({3, 3}, 9999.);
800 using multi::blas::hermitized;
801 using multi::blas::herk;
802 using multi::blas::filling;
803 // herk(filling::upper, 1., hermitized(a), c); // c†=c=a†a
804 // BOOST_REQUIRE( c[2][1]==19. );
805 // BOOST_REQUIRE( c[1][2]==19. );
806 }
807 {
808 multi::array<double, 2> c({2, 2}, 9999.);
809 using multi::blas::herk;
810 using multi::blas::filling;
811 herk(filling::upper, 1., a, c); // c†=c=aa†
812 // BOOST_REQUIRE( c[1][0] == 34. );
813 BOOST_REQUIRE( c[0][1] == 34. );
814 }
815 {
816 multi::array<double, 2> c({2, 2}, 9999.);
817 using multi::blas::herk;
818 using multi::blas::filling;
819 herk(filling::upper, 1., a, c); // c†=c=aa†
820 // BOOST_REQUIRE( c[1][0] == 34. );
821 BOOST_REQUIRE( c[0][1] == 34. );
822 }
823 {
824 using multi::blas::herk;
825 using multi::blas::filling;
826 multi::array<double, 2> c = herk(filling::upper, 1., a); // c†=c=aa†
827 // BOOST_REQUIRE( c[1][0] == 34. );
828 BOOST_REQUIRE( c[0][1] == 34. );
829 }
830 {
831 using multi::blas::herk;
832 multi::array<complex, 2> c = herk(a); // c†=c=a†a
833 BOOST_REQUIRE( c[1][0] == 34. );
834 BOOST_REQUIRE( c[0][1] == 34. );
835 }
836 {
837 using multi::blas::herk;
838 using multi::blas::hermitized;
839 multi::array<complex, 2> c = herk(hermitized(a)); // c†=c=a†a
840 BOOST_REQUIRE( c[2][1]==19. );
841 BOOST_REQUIRE( c[1][2]==19. );
842 }
843 }
844
845 BOOST_AUTO_TEST_CASE(multi_blas_herk_real_case){
846 multi::array<double, 2> const a = {
847 { 1., 3., 4.},
848 { 9., 7., 1.}
849 };
850 using multi::blas::filling;
851 {
852 static_assert( not boost::multi::blas::is_complex_array<multi::array<double, 2>>{} , "!");
853 multi::array<double, 2> c({2, 2}, 9999.);
854 syrk(filling::lower, 1., a, 0., c);//c†=c=aa†=(aa†)†, `c` in lower triangular
855 }
856 {
857 multi::array<double, 2> c({2, 2}, 9999.);
858 herk(filling::lower, 1., a, 0., c);//c†=c=aa†=(aa†)†, `c` in lower triangular
859 }
860 {
861 static_assert( not boost::multi::blas::is_complex_array<multi::array<double, 2>>{} , "!");
862 multi::array<double, 2> c = herk(filling::upper, a);//c†=c=aa†=(aa†)†, `c` in lower triangular
863 }
864 }
865
866 BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_real_case_1d){
867 multi::array<complex, 2> const a = {
868 { 1., 3., 4.},
869 };
870 namespace blas = multi::blas;
871 using blas::filling;
872 using blas::transposed;
873 using blas::hermitized;
874 {
875 multi::array<complex, 2> c({3, 3}, 9999.);
876 herk(filling::lower, 1., hermitized(a), 0., c);//c†=c=a†a=(a†a)†, `c` in lower triangular
877 print(c);
878 BOOST_REQUIRE( c[2][1]==complex(12.,0.) );
879 BOOST_REQUIRE( c[1][2]==9999. );
880 }
881 {
882 multi::array<complex, 2> c({3, 3}, 9999.);
883 herk(2., hermitized(a), c);//c†=c=a†a=(a†a)†, `c` in lower triangular
884
885 BOOST_REQUIRE( c[2][1]==complex(24.,0.) );
886 BOOST_REQUIRE( c[1][2]==complex(24.,0.) );
887 multi::array<complex, 2> c_gemm({3, 3});
888 // gemm(2., hermitized(a), a, c_gemm);
889 }
890 }
891 #endif
892
893 #if 0
894 BOOST_AUTO_TEST_CASE(multi_blas_herk_complex_timing){
895 multi::array<complex, 2> const a({4000, 4000}); std::iota(data_elements(a), data_elements(a) + num_elements(a), 0.2);
896 multi::array<complex, 2> c({4000, 4000}, 9999.);
897 boost::timer::auto_cpu_timer t;
898 using multi::blas::herk;
899 using multi::blas::hermitized;
900 using multi::blas::filling;
901 herk(filling::upper, 1., hermitized(a), c); // c†=c=a†a
902 }
903 #endif
904 #endif
905 #endif
906
907 #endif
908 #endif
909
910