1 // -*-indent-tabs-mode:t;c-basic-offset:4;tab-width:4;-*-
2 // © Alfredo A. Correa 2019-2020
3
4 #define BOOST_TEST_MODULE "C++ Unit Tests for Multi BLAS gemm"
5 #define BOOST_TEST_DYN_LINK
6 #include<boost/test/unit_test.hpp>
7
8 #include "../../../adaptors/blas/gemm.hpp"
9
10 #include "config.hpp"
11
12 #include "../../../array.hpp"
13
14 #include<random>
15
16 namespace multi = boost::multi;
17 namespace blas = multi::blas;
18
// Square (3x3) real operands. Each product is computed twice, once through the
// range interface (blas::gemm) and once through the iterator interface
// (blas::gemm_n), for the plain, transposed (blas::T) and hermitized (blas::H)
// views of a and b. c starts filled with 9999 and beta is 0, so the checked
// element has to come from the product alone.
BOOST_AUTO_TEST_CASE(multi_blas_gemm_square_real){
	multi::array<double, 2> const a = {
		{1, 3, 4},
		{9, 7, 1},
		{1, 2, 3}
	};
	multi::array<double, 2> const b = {
		{11, 12, 4},
		{ 7, 19, 1},
		{11, 12, 4}
	};
	{ // c=ab
		multi::array<double, 2> c({size(a), size(~b)}, 9999);
		blas::gemm(1., a, b, 0., c);
		BOOST_REQUIRE( c[2][1] == 86 );
	}
	{
		multi::array<double, 2> c({size(a), size(~b)}, 9999);
		BOOST_REQUIRE( size( a) == size( c) );
		BOOST_REQUIRE( size(~b) == size(~c) );
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[2][1] == 86 );
	}
	{ // c=ab⸆
		multi::array<double, 2> c({size(a), size(~b)}, 9999);
		blas::gemm(1., a, blas::T(b), 0., c);
		BOOST_REQUIRE( c[2][1] == 48 );
	}
	{
		multi::array<double, 2> c({size(a), size(~b)}, 9999);
		blas::gemm_n(1., a.begin(), a.size(), blas::T(b).begin(), 0., c.begin());
		BOOST_REQUIRE( c[2][1] == 48 );
	}
	{ // c=a⸆b
		multi::array<double, 2> c({size(a), size(~b)}, 9999);
		blas::gemm(1., blas::T(a), b, 0., c);
		BOOST_REQUIRE( c[2][1] == 103 );
	}
	{
		multi::array<double, 2> c({size(a), size(~b)}, 9999);
		blas::gemm_n(1., begin(blas::T(a)), size(blas::T(a)), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[2][1] == 103 );
	}
	{ // c=a⸆b⸆
		multi::array<double, 2> c({size(a), size(~b)}, 9999);
		blas::gemm(1., blas::T(a), blas::T(b), 0., c);
		BOOST_REQUIRE( c[2][1] == 50 );
	}
	{
		multi::array<double, 2> c({size(a), size(~b)}, 9999);
		blas::gemm_n(1., begin(blas::T(a)), size(blas::T(a)), begin(blas::T(b)), 0., begin(c));
		BOOST_REQUIRE( c[2][1] == 50 );
	}
	{ // c=ab⸆ (same product as above, gemm_n variant below uses free begin/size)
		multi::array<double, 2> c({size(a), size(~b)}, 9999);
		blas::gemm(1., a, blas::T(b), 0., c);
		BOOST_REQUIRE( c[2][1] == 48 );
	}
	{
		multi::array<double, 2> c({size(a), size(~b)}, 9999);
		blas::gemm_n(1., begin(a), size(a), begin(blas::T(b)), 0., begin(c));
		BOOST_REQUIRE( c[2][1] == 48 );
	}
	{ // c=a⸆b (repeated)
		multi::array<double, 2> c({size(a), size(~b)}, 9999);
		blas::gemm(1., blas::T(a), b, 0., c);
		BOOST_REQUIRE( c[2][1] == 103 );
	}
	{
		multi::array<double, 2> c({size(a), size(~b)}, 9999);
		blas::gemm_n(1., begin(blas::T(a)), size(blas::T(a)), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[2][1] == 103 );
	}
	{ // c=2a†b†; for real elements this equals 2a⸆b⸆ (twice the 50 above)
		multi::array<double, 2> c({size(a), size(rotated(b))}, 9999);
		blas::gemm(2., blas::H(a), blas::H(b), 0., c);
		BOOST_REQUIRE( c[2][1] == 100 );
	}
	{ // three-argument gemm used to initialize a new array
		multi::array<double, 2> c = blas::gemm(2., blas::H(a), blas::H(b));
		BOOST_REQUIRE( c[2][1] == 100 );
	}
	{
		multi::array<double, 2> const c = blas::gemm(2., blas::H(a), blas::H(b));
		BOOST_REQUIRE( c[2][1] == 100 );
	}
	{ // three-argument gemm assigned to an already-sized array
		multi::array<double, 2> c({size(a), size(rotated(b))}, 9999);
		c = blas::gemm(2., blas::H(a), blas::H(b));
		BOOST_REQUIRE( c[2][1] == 100 );
	}
	{ // three-argument gemm assigned to a default-constructed array
		multi::array<double, 2> c;
		c = blas::gemm(2., blas::H(a), blas::H(b));
		BOOST_REQUIRE( c[2][1] == 100 );
	}
//	{
//		multi::array<double, 2> c({size(a), size(rotated(b))}, 9999);
//		blas::gemm(2., blas::H(a), blas::H(b), 0., c);
//		BOOST_REQUIRE( c[2][1] == 100 );

//		multi::array<double, 2> const c_copy = blas::gemm(2., blas::H(a), blas::H(b));
//		BOOST_REQUIRE( c == c_copy );
//		multi::array<double, 2> const c_copy2 = blas::gemm(1., blas::H(a), blas::H(b));
//		BOOST_REQUIRE( c_copy2[2][1] == 50 );
//	}
	{
		multi::array<double, 2> c({size(a), size(rotated(b))}, 9999);
		blas::gemm_n(2., begin(blas::H(a)), size(blas::H(a)), begin(blas::H(b)), 0., begin(c));
		BOOST_REQUIRE( c[2][1] == 100 );
	}
}
131
// 2x2 real operands with operator~ (transposed view) applied to the inputs
// and/or to the output. Several gemm_n calls pass an explicit blas::context
// as first argument; the others use the context-free overload.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_square){
	multi::array<double, 2> const a = {
		{ 1, 3},
		{ 9, 7},
	};
	multi::array<double, 2> const b = {
		{ 11, 12},
		{  7, 19},
	};
	{
		multi::array<double, 2> c({2, 2});
		blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[1][0] == 148 );
	}
	{
		multi::array<double, 2> c({2, 2});
		blas::context ctxt;
		blas::gemm_n(ctxt, 1., begin(a), size(a), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == 148 );
	}
	{
		multi::array<double, 2> c({2, 2});
		blas::gemm(1., ~a, b, 0., c); // c=a⸆b, c⸆=b⸆a
		BOOST_REQUIRE(( c[1][1] == 169 and c[1][0] == 82 ));
	}
	{
		multi::array<double, 2> c({2, 2});
		blas::context ctxt;
		blas::gemm_n(ctxt, 1., begin(~a), size(~a), begin(b), 0., begin( c));
		BOOST_REQUIRE(( c[1][1] == 169 and c[1][0] == 82 ));
	}
	{ // same product written through the transposed view of c
		multi::array<double, 2> c({2, 2});
		blas::context ctxt;
		blas::gemm_n(ctxt, 1., begin(~a), size(~a), begin(b), 0., begin(~c));
		BOOST_REQUIRE( (~c)[1][1] == 169 );
		BOOST_REQUIRE( (~c)[1][0] ==  82 );
	}
	{
		multi::array<double, 2> c({2, 2});
		blas::gemm(1., a, ~b, 0., c); // c=ab⸆, c⸆=ba⸆
		BOOST_REQUIRE( c[1][0] == 183 );
	}
	{
		multi::array<double, 2> c({2, 2});
		blas::context ctxt;
		blas::gemm_n(ctxt, 1., begin(a), size(a), begin(~b), 0., begin(c)); // c=ab⸆, c⸆=ba⸆
		BOOST_REQUIRE( c[1][0] == 183 );
	}
	{
		multi::array<double, 2> c({2, 2});
		blas::gemm(1., a, ~b, 0., ~c); // c⸆=ab⸆, c=ba⸆
		BOOST_REQUIRE( (~c)[1][0] == 183 );
	}
	{
		multi::array<double, 2> c({2, 2});
		blas::gemm_n(1., begin(a), size(a), begin(~b), 0., begin(~c)); // c⸆=ab⸆, c=ba⸆
		BOOST_REQUIRE( (~c)[1][0] == 183 );
	}
	{
		multi::array<double, 2> c({2, 2});
		blas::gemm(1., ~a, ~b, 0., c); // c=a⸆b⸆, c⸆=ba
		BOOST_REQUIRE( c[1][0] == 117 );
	}
	{
		multi::array<double, 2> c({2, 2});
		blas::gemm_n(1., begin(~a), size(~a), begin(~b), 0., begin(c)); // c=a⸆b⸆, c⸆=ba
		BOOST_REQUIRE( c[1][0] == 117 );
	}
	{
		multi::array<double, 2> c({2, 2});
		blas::gemm(1., ~a, ~b, 0., ~c); // c⸆=a⸆b⸆, c=ba
		BOOST_REQUIRE( c[0][1] == 117 );
	}
	{
		multi::array<double, 2> c({2, 2});
		blas::gemm_n(1., begin(~a), size(~a), begin(~b), 0., begin(~c)); // c⸆=a⸆b⸆, c=ba
		BOOST_REQUIRE( c[0][1] == 117 );
	}
}
212
// Non-square left operand: a is 2x3, b is 3x3, so c is 2x3 (sizes hard-coded).
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_nonsquare){
	multi::array<double, 2> const a = {
		{ 1, 3, 1},
		{ 9, 7, 1},
	};
	multi::array<double, 2> const b = {
		{ 11, 12, 1},
		{  7, 19, 1},
		{  1,  1, 1}
	};
	{
		multi::array<double, 2> c({2, 3});
		blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[1][2] == 17 );
	}
	{
		multi::array<double, 2> c({2, 3});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c)); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[1][2] == 17 );
	}
}
234
// a is 2x3 and b is 3x4. Covers c sized from the operands, a non-unit alpha
// (0.1) compared with BOOST_REQUIRE_CLOSE (whose last argument is a tolerance
// in percent), and the three-argument gemm whose result is captured with
// `auto c =+ ...` or with class-template argument deduction.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_nonsquare_automatic){//, *utf::tolerance(0.00001)){
	namespace blas = multi::blas;
	multi::array<double, 2> const a = {
		{ 1., 3., 1.},
		{ 9., 7., 1.},
	};
	multi::array<double, 2> const b = {
		{ 11., 12., 4., 8.},
		{  7., 19., 2., 7.},
		{  5.,  3., 3., 1.}
	};
	{
		multi::array<double, 2> c({size(a), size(~b)});
		blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[1][2] == 53 );
	}
	{
		multi::array<double, 2> c({size(a), size(~b)});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c)); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[1][2] == 53 );
	}
	{
		multi::array<double, 2> c({2, 4});
		blas::gemm(0.1, a, b, 0., c); // c=0.1ab
		BOOST_REQUIRE_CLOSE( c[1][2] , 5.3 , 0.00001 );
	}
	{
		multi::array<double, 2> c({2, 4});
		blas::gemm_n(0.1, begin(a), size(a), begin(b), 0., begin(c)); // c=0.1ab
		BOOST_REQUIRE_CLOSE( c[1][2] , 5.3 , 0.00001 );
	}
	{
		// NOTE(review): unary + presumably turns the gemm result into an owning array — confirm
		auto c =+ blas::gemm(0.1, a, b); // c=0.1ab
		BOOST_REQUIRE_CLOSE( c[1][2] , 5.3 , 0.00001 );
	}
	{
		multi::array c = blas::gemm(0.1, a, b); // element type and rank deduced (CTAD)
		BOOST_REQUIRE_CLOSE( c[1][2] , 5.3 , 0.00001 );
	}
}
275
// Complex 2x2 product of a with its own hermitized view, c=aa†. The result is
// Hermitian, which the paired checks on c[1][0] and c[0][1] reflect. The same
// product is obtained through every supported way of receiving the result.
BOOST_AUTO_TEST_CASE(multi_blas_gemm_nh){
	using complex = std::complex<double>; complex const I{0,1};
	multi::array<complex, 2> const a = {
		{1.-2.*I, 9.-1.*I},
		{2.+3.*I, 1.-2.*I}
	};
	{
		auto c =+ blas::gemm(1., a, blas::H(a)); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[1][0] == 7.-10.*I );
		BOOST_REQUIRE( c[0][1] == 7.+10.*I );
	}
	{ // CTAD
		multi::array c = blas::gemm(1., a, blas::H(a)); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[1][0] == 7.-10.*I );
		BOOST_REQUIRE( c[0][1] == 7.+10.*I );
	}
	{ // initialization of an explicitly typed array
		multi::array<complex, 2> c = blas::gemm(1., a, blas::H(a)); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[1][0] == 7.-10.*I );
		BOOST_REQUIRE( c[0][1] == 7.+10.*I );
	}
	{ // assignment to an existing array
		multi::array<complex, 2> c({2, 2}, 9999.);
		c = blas::gemm(1., a, blas::H(a)); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[1][0] == 7.-10.*I );
		BOOST_REQUIRE( c[0][1] == 7.+10.*I );
	}
	{ // assignment through c(), i.e. to the elements of an existing array
		multi::array<complex, 2> c({2, 2}, 9999.);
		c() = blas::gemm(1., a, blas::H(a)); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[1][0] == 7.-10.*I );
		BOOST_REQUIRE( c[0][1] == 7.+10.*I );
	}
	{ // five-argument form writing into c
		multi::array<complex, 2> c({2, 2}, 9999.);
		blas::gemm(1., a, blas::H(a), 0., c); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[1][0] == 7.-10.*I );
		BOOST_REQUIRE( c[0][1] == 7.+10.*I );
	}
	{ // iterator form
		multi::array<complex, 2> c({2, 2}, 9999.);
		blas::gemm_n(1., begin(a), size(a), begin(blas::H(a)), 0., begin(c)); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[1][0] == 7.-10.*I );
		BOOST_REQUIRE( c[0][1] == 7.+10.*I );
	}
}
322
#if CUDA_FOUND
#include<thrust/complex.h>
// Same c=aa† checks as multi_blas_gemm_nh, with thrust::complex<double> as the
// element type. Only compiled when CUDA_FOUND evaluates to non-zero.
BOOST_AUTO_TEST_CASE(multi_blas_gemm_nh_thrust){
	using complex = thrust::complex<double>; complex const I{0, 1};
	multi::array<complex, 2> const a = {
		{1.-2.*I, 9.-1.*I},
		{2.+3.*I, 1.-2.*I}
	};
	{
		auto c =+ blas::gemm(1., a, blas::hermitized(a)); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[1][0] == 7.-10.*I );
		BOOST_REQUIRE( c[0][1] == 7.+10.*I );
	}
	{
		multi::array c = blas::gemm(1., a, blas::hermitized(a)); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[1][0] == 7.-10.*I );
		BOOST_REQUIRE( c[0][1] == 7.+10.*I );
	}
	{
		multi::array<complex, 2> c = blas::gemm(1., a, blas::hermitized(a)); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[1][0] == 7.-10.*I );
		BOOST_REQUIRE( c[0][1] == 7.+10.*I );
	}
	{
		multi::array<complex, 2> c({2, 2});
		c = blas::gemm(1., a, blas::hermitized(a)); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[1][0] == 7.-10.*I );
		BOOST_REQUIRE( c[0][1] == 7.+10.*I );
	}
	{
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., a, blas::hermitized(a), 0., c); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[1][0] == 7.-10.*I );
		BOOST_REQUIRE( c[0][1] == 7.+10.*I );
	}
	{
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(a), size(a), begin(blas::H(a)), 0., begin(c)); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[1][0] == 7.-10.*I );
		BOOST_REQUIRE( c[0][1] == 7.+10.*I );
	}
}
#endif
366
// Single-row complex operand (1x2): aa† is 1x1 and equals the sum of the
// squared moduli of the two elements, 5 + 82 = 87.
BOOST_AUTO_TEST_CASE(multi_blas_gemm_elongated){
	using complex = std::complex<double>; complex const I{0, 1};
	multi::array<complex, 2> const a = {
		{1.-2.*I, 9.-1.*I}
	};
	{
		multi::array<complex, 2> c({1, 1});
		blas::gemm(1., a, blas::H(a), 0., c); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[0][0] == 87. + 0.*I );
	}
	{
		multi::array<complex, 2> c({1, 1});
		blas::gemm_n(1., begin(a), size(a), begin(blas::H(a)), 0., begin(c)); // c=aa†, c†=aa†
		BOOST_REQUIRE( c[0][0] == 87. + 0.*I );
	}
}
383
// Both operands hermitized: a is 3x1 and b is 1x3, so H(a) is 1x3, H(b) is 3x1
// and the product a†b† is 1x1.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_3x1_3x1_bisbis){
	using complex = std::complex<double>; complex const I{0, 1};
	multi::array<complex, 2> const a = {
		{1. + 2.*I},
		{9. - 1.*I},
		{1. + 1.*I}
	};
	multi::array<complex, 2> const b = {
		{ 11. - 2.*I, 7. - 3.*I, 8. - 1.*I}
	};
	{
		multi::array<complex, 2> c({1, 1});

		// shape checks on the hermitized views before multiplying
		BOOST_REQUIRE( size(blas::H(a)) == 1 );
		BOOST_REQUIRE( size(blas::H(b)[0]) == 1 );

		blas::gemm(1., blas::H(a), blas::H(b), 0., c); // c=a†b†, c†=ba
		BOOST_REQUIRE( c[0][0] == 84.+7.*I );
	}
	{
		multi::array<complex, 2> c({1, 1});
		blas::gemm_n(1., begin(blas::H(a)), size(blas::H(a)), begin(blas::H(b)), 0., begin(c)); // c=a†b†, c†=ba
		BOOST_REQUIRE( c[0][0] == 84.+7.*I );
	}
}
409
// Empty operands. The two gemm calls are followed by no assertion on c: the
// test only requires that the calls complete for arrays with no elements.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_empty){
	multi::array<double, 2> const a({0, 5});
	BOOST_REQUIRE( size( a) == 0 );
	BOOST_REQUIRE( size(~a) == 5 );
	BOOST_REQUIRE( a.is_empty() );

	// NOTE(review): b is built with extensions {5, 0} yet size(b) is required
	// to be 0 (while size(~a) above is 5) — confirm this asymmetry is the
	// library's intended convention for arrays without elements.
	multi::array<double, 2> const b({5, 0});
	BOOST_REQUIRE( size( b) == 0 );
	BOOST_REQUIRE( size(~b) == 0 );
	BOOST_REQUIRE( b.is_empty() );
	{
		multi::array<double, 2> c;
		blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
	}
	{
		multi::array<double, 2> c;
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c)); // c=ab, c⸆=b⸆a⸆
	}
}
429
// a is 3x2, b is 2x2. The product ab is written into c directly and into the
// transposed view of a 2x3 array. The second half repeats this with `~ar`,
// where ar is a copy of ~a, so the left operand has the same values as a but
// comes from a transposed view of separately stored data.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_nonsquare2){
	multi::array<double, 2> const a = {
		{ 1, 3},
		{ 9, 7},
		{ 1, 1}
	};
	multi::array<double, 2> const b = {
		{ 11, 12},
		{  7, 19}
	};
	{
		multi::array<double, 2> c({size(a), size(~b)});
		blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[2][1] == 31 );
	}
	{
		multi::array<double, 2> c({size(a), size(~b)});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c)); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[2][1] == 31 );
	}
	{
		multi::array<double, 2> c({size(~b), size(a)});
		blas::gemm(1., a, b, 0., ~c); // c⸆=ab, c=b⸆a⸆
		BOOST_REQUIRE( c[1][2] == 31 );
	}
	{
		multi::array<double, 2> c({size(~b), size(a)});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(~c)); // c⸆=ab, c=b⸆a⸆
		BOOST_REQUIRE( c[1][2] == 31 );
	}
	{
		auto ar = +~a;
		multi::array<double, 2> c({3, 2});
		blas::gemm(1., ~ar, b, 0., c); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[2][1] == 31 );
	}
	{
		auto ar = +~a;
		multi::array<double, 2> c({3, 2});
		blas::gemm_n(1., begin(~ar), size(~ar), begin(b), 0., begin(c)); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[2][1] == 31 );
	}
	{
		auto ar = +~a;
		multi::array<double, 2> c({2, 3});
		blas::gemm(1., ~ar, b, 0., ~c); // c⸆=ab, c=b⸆a⸆
		BOOST_REQUIRE( c[1][2] == 31 );
	}
	{
		auto ar = +~a;
		multi::array<double, 2> c({2, 3});
		blas::gemm_n(1., begin(~ar), size(~ar), begin(b), 0., begin(~c)); // c⸆=ab, c=b⸆a⸆
		BOOST_REQUIRE( c[1][2] == 31 );
	}
}
485
// 2x2 real operands: a⸆b written first into c and then into ~c of the same
// array, so the checked value moves from c[1][0] to c[0][1].
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_2x2_2x2){
	multi::array<double, 2> const a = {
		{ 1, 3},
		{ 9, 4},
	};
	multi::array<double, 2> const b = {
		{ 11, 12},
		{  7, 19},
	};
	{
		multi::array<double, 2> c({2, 2});
		blas::gemm(1., ~a, b, 0.,  c); // c=a⸆b, c⸆=b⸆a
		BOOST_REQUIRE( c[1][0] == 61 );

		blas::gemm(1., ~a, b, 0., ~c); // c⸆=a⸆b, c=b⸆a
		BOOST_REQUIRE( c[0][1] == 61 );
	}
	{
		multi::array<double, 2> c({2, 2});
		blas::gemm_n(1., begin(~a), size(~a), begin(b), 0., begin( c)); // c=a⸆b, c⸆=b⸆a
		BOOST_REQUIRE( c[1][0] == 61 );

		blas::gemm_n(1., begin(~a), size(~a), begin(b), 0., begin(~c)); // c⸆=a⸆b, c=b⸆a
		BOOST_REQUIRE( c[0][1] == 61 );
	}
}
512
// a and b are both stored as 3x2; the left operand ~a is the 2x3 transposed
// view, giving a 2x2 product that is written into c and then into ~c.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_2x3_3x2){
	multi::array<double, 2> const a = {
		{ 1, 3},
		{ 9, 4},
		{ 1, 5}
	};
	multi::array<double, 2> const b = {
		{ 11, 12},
		{  7, 19},
		{  8, 1 }
	};
	{
		multi::array<double, 2> c({2, 2});
		blas::gemm(1., ~a, b, 0.,  c); // c=a⸆b, c⸆=b⸆a
		BOOST_REQUIRE( c[1][0] == 101 );

		blas::gemm(1., ~a, b, 0., ~c); // c⸆=a⸆b, c=b⸆a
		BOOST_REQUIRE( c[0][1] == 101 );
	}
	{
		multi::array<double, 2> c({2, 2});
		blas::gemm_n(1., begin(~a), size(~a), begin(b), 0., begin( c)); // c=a⸆b, c⸆=b⸆a
		BOOST_REQUIRE( c[1][0] == 101 );

		blas::gemm_n(1., begin(~a), size(~a), begin(b), 0., begin(~c)); // c⸆=a⸆b, c=b⸆a
		BOOST_REQUIRE( c[0][1] == 101 );
	}
}
541
// Single-row left operand (1x3) times a 3x2 matrix. The stride assertions
// record the layouts that the calls are exercised with (several strides are 1
// at once because one of the dimensions has a single element).
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_1x3_3x2){
	multi::array<double, 2> const a = {
		{1, 9, 1}
	};
	BOOST_REQUIRE( stride(~a) == 1 );
	BOOST_REQUIRE( stride( a) == 3 );
	multi::array<double, 2> const b = {
		{ 11, 12},
		{  7, 19},
		{  8, 1 }
	};
	{
		multi::array<double, 2> c({size(a), size(~b)});
		blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[0][1] == 184 );
	}
	{
		multi::array<double, 2> c({size(a), size(~b)});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c)); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[0][1] == 184 );
	}
	{
		auto ar = +~a; // copy of the transposed view; ~ar has the same values as a
		multi::array<double, 2> c({size(~b), size(~ar)});
		blas::gemm(1., ~ar, b, 0., ~c); // c⸆=ab, c=b⸆a⸆
		BOOST_REQUIRE( c[1][0] == 184 );
	}
	{
		auto ar = +~a;
		BOOST_REQUIRE( size(~ar) == 1 );
		BOOST_REQUIRE( begin(~ar).stride() == 1 );
		BOOST_REQUIRE( begin(~ar)->stride() == 1 );
		BOOST_REQUIRE( begin( ar)->stride() == 1 );

		multi::array<double, 2> c({size(~b), size(~ar)});
		BOOST_REQUIRE( begin( c).stride() == 1 );
		BOOST_REQUIRE( begin(~c).stride() == 1 );
		BOOST_REQUIRE( begin(c)->stride() == 1 );

		BOOST_REQUIRE( begin(b) );
		blas::gemm_n(1., begin(~ar), size(~ar), begin(b), 0., begin(~c)); // c⸆=ab, c=b⸆a⸆
		BOOST_REQUIRE( c[1][0] == 184 );
	}
}
586
// Same 1x3 by 3x2 products as the real test of that shape, with
// std::complex<double> elements whose imaginary parts are all zero.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complexreal_1x3_3x2){
	using complex = std::complex<double>;
	multi::array<complex, 2> const a = {
		{1, 9, 1}
	};
	BOOST_REQUIRE( stride(~a) == 1 );
	BOOST_REQUIRE( stride( a) == 3 );
	multi::array<complex, 2> const b = {
		{ 11, 12},
		{  7, 19},
		{  8, 1 }
	};
	{
		multi::array<complex, 2> c({size(a), size(~b)});
		blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[0][1] == 184. );
	}
	{
		multi::array<complex, 2> c({size(a), size(~b)});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c)); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[0][1] == 184. );
	}
	{
		auto ar = +~a; // copy of the transposed view; ~ar has the same values as a
		multi::array<complex, 2> c({size(~b), size(~ar)});
		blas::gemm(1., ~ar, b, 0., ~c); // c⸆=ab, c=b⸆a⸆
		BOOST_REQUIRE( c[1][0] == 184. );
	}
	{
		auto ar = +~a;
		multi::array<complex, 2> c({size(~b), size(~ar)});
		blas::gemm_n(1., begin(~ar), size(~ar), begin(b), 0., begin(~c)); // c⸆=ab, c=b⸆a⸆
		BOOST_REQUIRE( c[1][0] == 184. );
	}
}
622
// Left operand is a sub-block of a larger array: a({0, 1}) restricts the 2x3
// array a to its first row, and ar(extension(ar), {0, 1}) restricts the copy
// of ~a to its first column. The expected 184 is the first row of a times the
// second column of b.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_1x3_part_3x2){
	multi::array<double, 2> const a = {
		{1, 9, 1},
		{3, 3, 3}
	};
	BOOST_REQUIRE( stride(~a) == 1 );
	BOOST_REQUIRE( stride( a) == 3 );
	multi::array<double, 2> const b = {
		{ 11, 12},
		{  7, 19},
		{  8, 1 }
	};
	{
		multi::array<double, 2> c({size(a({0, 1})), size(~b)});
		blas::gemm(1., a({0, 1}), b, 0., c); // c=ab (first row of a only)
		BOOST_REQUIRE( c[0][1] == 184 );
	}
	{
		multi::array<double, 2> c({size(a({0, 1})), size(~b)});
		blas::gemm_n(1., begin(a({0, 1})), size(a({0, 1})), begin(b), 0., begin(c)); // c=ab (first row of a only)
		BOOST_REQUIRE( c[0][1] == 184 );
	}
	{
		auto ar = +~a;
		multi::array<double, 2> c({size(~b), size(~ar(extension(ar), {0, 1}))});
		blas::gemm(1., ~(ar(extension(ar), {0, 1})), b, 0., ~c); // c⸆=ab (first row of a only), c=b⸆a⸆
		BOOST_REQUIRE( c[1][0] == 184 );
	}
	{
		auto ar = +~a;
		multi::array<double, 2> c({size(~b), size(~ar(extension(ar), {0, 1}))});
		blas::gemm_n(1., begin(~(ar(extension(ar), {0, 1}))), size(~(ar(extension(ar), {0, 1}))), begin(b), 0., begin(~c)); // c⸆=ab (first row of a only), c=b⸆a⸆
		BOOST_REQUIRE( c[1][0] == 184 );
	}
}
658
// Complex-typed version of the sub-block test: the left operand is the first
// row of a (or the first column of a copy of ~a, transposed back); all
// imaginary parts are zero.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complexreal_1x3_part_3x2){
	using complex = std::complex<double>;
	multi::array<complex, 2> const a = {
		{1., 9., 1.},
		{3., 3., 3.}
	};
	BOOST_REQUIRE( stride(~a) == 1 );
	BOOST_REQUIRE( stride( a) == 3 );
	multi::array<complex, 2> const b = {
		{ 11., 12.},
		{  7., 19.},
		{  8.,  1.}
	};
	{ // c=ab (first row of a only)
		multi::array<complex, 2> c({size(a({0, 1})), size(~b)});
		blas::gemm(1., a({0, 1}), b, 0., c);
		BOOST_REQUIRE( c[0][1] == 184. );
	}
	{
		multi::array<complex, 2> c({size(a({0, 1})), size(~b)});
		blas::gemm_n(1., begin(a({0, 1})), size(a({0, 1})), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[0][1] == 184. );
	}
	{ // same product written through ~c
		auto ar = +~a;
		multi::array<complex, 2> c({size(~b), size(~ar(extension(ar), {0, 1}))});
		blas::gemm(1., ~(ar(extension(ar), {0, 1})), b, 0., ~c);
		BOOST_REQUIRE( c[1][0] == 184. );
	}
	{
		auto ar = +~a;
		multi::array<complex, 2> c({size(~b), size(~ar(extension(ar), {0, 1}))});
		blas::gemm_n(1., begin(~(ar(extension(ar), {0, 1}))), size(~(ar(extension(ar), {0, 1}))), begin(b), 0., begin(~c));
		BOOST_REQUIRE( c[1][0] == 184. );
	}
}
695
// Single-column right operand: a is 2x3, b is 3x1. The last two cases use
// only the first row of a (taken from a copy of ~a) and write through ~c of a
// 1x1 array.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_2x3_3x1){
	multi::array<double, 2> const a = {
		{1, 9, 1},
		{3, 3, 3}
	};
	BOOST_REQUIRE( stride(~a) == 1 );
	BOOST_REQUIRE( stride( a) == 3 );
	multi::array<double, 2> const b = {
		{ 11},
		{  7},
		{  8}
	};
	{
		multi::array<double, 2> c({size(a), size(~b)});
		blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[0][0] == 82 );
		BOOST_REQUIRE( c[1][0] == 78 );
	}
	{
		multi::array<double, 2> c({size(a), size(~b)});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c)); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[0][0] == 82 );
		BOOST_REQUIRE( c[1][0] == 78 );
	}
	{
		auto ar = +~a;
		multi::array<double, 2> c({size(~b), size(~ar(extension(ar), {0, 1}))});
		blas::gemm(1., ~(ar(extension(ar), {0, 1})), b, 0., ~c); // c⸆=ab (first row of a only)
		BOOST_REQUIRE( c[0][0] == 82 );
	}
	{
		auto ar = +~a;
		multi::array<double, 2> c({size(~b), size(~ar(extension(ar), {0, 1}))});
		blas::gemm_n(1., begin(~(ar(extension(ar), {0, 1}))), size(~(ar(extension(ar), {0, 1}))), begin(b), 0., begin(~c)); // c⸆=ab (first row of a only)
		BOOST_REQUIRE( c[0][0] == 82 );
	}
}
733
734
// 2x3 times 3x1 with the 2x1 result received either directly (c is 2x1) or
// through the transposed view of a 1x2 array.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_2x3_3x1_bis){
	multi::array<double, 2> const a = {
		{1, 9, 1},
		{3, 4, 5}
	};
	multi::array<double, 2> const b = {
		{ 11},
		{  7},
		{  8}
	};

	{
		multi::array<double, 2> c({1, 2});
		blas::gemm(1., a, b, 0., ~c); // c⸆=ab, c=b⸆a⸆
		BOOST_REQUIRE( (~c)[0][0] ==  82 );
		BOOST_REQUIRE( (~c)[1][0] == 101 );
	}
	{
		multi::array<double, 2> c({1, 2});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(~c)); // c⸆=ab, c=b⸆a⸆
		BOOST_REQUIRE( (~c)[0][0] ==  82 );
		BOOST_REQUIRE( (~c)[1][0] == 101 );
	}
	{
		multi::array<double, 2> c({2, 1});
		blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( (~c)[0][1] == 101 );
		BOOST_REQUIRE(   c [1][0] == 101 );
	}
	{
		multi::array<double, 2> c({2, 1});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c)); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( (~c)[0][1] == 101 );
		BOOST_REQUIRE(   c [1][0] == 101 );
	}
	{
		multi::array<double, 2> c({1, 2});
		auto ar = +~a; // copy of the transposed view; ~ar has the same values as a
		blas::gemm(1., ~ar, b, 0., ~c); // c⸆=ab, c=b⸆a⸆
		BOOST_REQUIRE( c[0][1] == 101 );
	}
	{
		multi::array<double, 2> c({1, 2});
		auto ar = +~a;
		blas::gemm_n(1., begin(~ar), size(~ar), begin(b), 0., begin(~c)); // c⸆=ab, c=b⸆a⸆
		BOOST_REQUIRE( c[0][1] == 101 );
	}
}
783
// Row vector (1x3) times column vector (3x1): a 1x1 result, 82. Each operand
// is also supplied as the transposed (or, for real data, hermitized) view of
// a copy of its transpose.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_1x3_3x1){
	multi::array<double, 2> const a = {
		{1, 9, 1}
	};
	multi::array<double, 2> const b = {
		{ 11},
		{  7},
		{  8}
	};
	{
		multi::array<double, 2> c({1, 1});
		blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[0][0] == 82 );
	}
	{
		multi::array<double, 2> c({1, 1});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[0][0] == 82 );
	}
	{ // left operand as ~ar, with ar a copy of ~a
		multi::array<double, 2> c({1, 1});
		auto ar = +~a;
		blas::gemm(1., ~ar, b, 0., c);
		BOOST_REQUIRE( c[0][0] == 82 );
	}
	{
		multi::array<double, 2> c({1, 1});
		auto ar = +~a;
		blas::gemm_n(1., begin(~ar), size(~ar), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[0][0] == 82 );
	}
	{ // right operand as ~br, with br a copy of ~b
		multi::array<double, 2> c({1, 1});
		auto br = +~b;
		blas::gemm(1., a, ~br, 0., c);
		BOOST_REQUIRE( c[0][0] == 82 );
	}
	{ // same as above through gemm_n, after recording the strides involved
		multi::array<double, 2> c({1, 1});
		BOOST_REQUIRE( begin(c). stride() == 1 );
		BOOST_REQUIRE( begin(c)->stride() == 1 );

		auto br = +~b;
	//	BOOST_REQUIRE( begin(br). stride() == 1 );
		BOOST_REQUIRE( begin( br)->stride() == 1 );

		BOOST_REQUIRE(begin(a)->stride() == 1);
		BOOST_REQUIRE( begin(~br). stride() == 1 );
	//	BOOST_REQUIRE( begin(~br)->stride() == 1 );
		BOOST_REQUIRE(begin(c)->stride() == 1);
		BOOST_REQUIRE(begin(c).stride() == 1);
		BOOST_REQUIRE(size(a) == 1);

		blas::gemm_n(1., begin(a), size(a), begin(~br), 0., begin(c));
		BOOST_REQUIRE( c[0][0] == 82 );
	}
	{ // right operand as H(br); the expected value is the same as with ~br
		multi::array<double, 2> c({1, 1});
		auto br = +~b;
		blas::gemm(1., a, blas::H(br), 0., c);
		BOOST_REQUIRE( c[0][0] == 82 );
	}
	{
		multi::array<double, 2> c({1, 1});
		auto br = +~b;
		blas::gemm_n(1., begin(a), size(a), begin(blas::H(br)), 0., begin(c));
		BOOST_REQUIRE( c[0][0] == 82 );
	}
}
853
// 2x2 complex operands with transposition only (no conjugation): operator~ and
// blas::T on the inputs, and blas::T on the output in the last two cases.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_square){
	using complex = std::complex<double>; constexpr complex I{0, 1};
	multi::array<complex, 2> const a = {
		{ 1.+3.*I, 3.+2.*I},
		{ 9.+1.*I, 7.+1.*I},
	};
	multi::array<complex, 2> const b = {
		{11.+2.*I, 12.+4.*I},
		{ 7.+1.*I, 19.-9.*I},
	};
	{
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[1][0] == 145. + 43.*I );
	}
	{
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c)); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[1][0] == 145. + 43.*I );
	}
	{
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., ~a, b, 0., c); // c=a⸆b, c⸆=b⸆a
		BOOST_REQUIRE(( c[1][1] == 170.-8.*I and c[1][0] == 77.+42.*I ));
	}
	{
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(~a), size(~a), begin(b), 0., begin(c)); // c=a⸆b, c⸆=b⸆a
		BOOST_REQUIRE(( c[1][1] == 170.-8.*I and c[1][0] == 77.+42.*I ));
	}
	{
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., a, ~b, 0., c); // c=ab⸆, c⸆=ba⸆
		BOOST_REQUIRE( c[1][0] == 177.+69.*I );
	}
	{
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(a), size(a), begin(~b), 0., begin(c)); // c=ab⸆, c⸆=ba⸆
		BOOST_REQUIRE( c[1][0] == 177.+69.*I );
	}
	{
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., blas::T(a), blas::T(b), 0., c); // c=a⸆b⸆, c⸆=ba
		BOOST_REQUIRE( c[1][0] == 109. + 68.*I );
	}
	{
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(blas::T(a)), size(blas::T(a)), begin(blas::T(b)), 0., begin(c)); // c=a⸆b⸆, c⸆=ba
		BOOST_REQUIRE( c[1][0] == 109. + 68.*I );
	}
	{
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., blas::T(a), blas::T(b), 0., blas::T(c)); // c⸆=a⸆b⸆, c=ba
		BOOST_REQUIRE( c[0][1] == 109.+68.*I );
	}
	{
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(blas::T(a)), size(blas::T(a)), begin(blas::T(b)), 0., begin(blas::T(c))); // c⸆=a⸆b⸆, c=ba
		BOOST_REQUIRE( c[0][1] == 109.+68.*I );
	}
}
915
// Complex row vector (1x3) times column vector (3x1). With plain or transposed
// operands the 1x1 result is 84-7i; with the right operand given as H(br),
// where br is a copy of ~b, its elements are conjugated and the result is
// 80+53i.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_1x3_3x1){
	using complex = std::complex<double>; complex const I{0, 1};
	multi::array<complex, 2> const a = {
		{1. + 2.*I, 9. - 1.*I, 1. + 1.*I}
	};
	multi::array<complex, 2> const b = {
		{ 11. - 2.*I},
		{  7. - 3.*I},
		{  8. - 1.*I}
	};
	{
		multi::array<complex, 2> c({1, 1});
		blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[0][0] == 84.-7.*I );
	}
	{
		multi::array<complex, 2> c({1, 1});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c)); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[0][0] == 84.-7.*I );
	}
	{
		multi::array<complex, 2> c({1, 1});
		auto ar = +~a; // copy of the transposed view; ~ar has the same values as a
		blas::gemm(1., ~ar, b, 0., c); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[0][0] == 84.-7.*I );
	}
	{
		multi::array<complex, 2> c({1, 1});
		auto ar = +~a;
		blas::gemm_n(1., begin(~ar), size(~ar), begin(b), 0., begin(c)); // c=ab, c⸆=b⸆a⸆
		BOOST_REQUIRE( c[0][0] == 84.-7.*I );
	}
	{
		multi::array<complex, 2> c({1, 1});
		auto br = +~b; // copy of the transposed view; ~br has the same values as b
		blas::gemm(1., a, ~br, 0., c);
		BOOST_REQUIRE( c[0][0] == 84.-7.*I );
	}
	{
		multi::array<complex, 2> c({1, 1});
		auto br = +~b;
		blas::context ctxt;
		blas::gemm_n(ctxt, 1., begin(a), size(a), begin(~br), 0., begin(c));
		BOOST_REQUIRE( c[0][0] == 84.-7.*I );
	}
	{ // H(br) conjugates the elements of b, hence the different expected value
		multi::array<complex, 2> c({1, 1});
		auto br = +~b;
		blas::gemm(1., a, blas::H(br), 0., ~c);
		BOOST_REQUIRE( c[0][0] == 80. + 53.*I );
	}
	{
		multi::array<complex, 2> c({1, 1});
		auto br = +~b;
		blas::gemm_n(1., begin(a), size(a), begin(blas::H(br)), 0., begin(~c));
		BOOST_REQUIRE( c[0][0] == 80. + 53.*I );
	}
}
974
// 2x2 complex products with plain and hermitized (blas::H) operands.
// Each stanza starts from a fresh result array and checks one element against
// a precomputed value, through blas::gemm and through the iterator-based
// blas::gemm_n.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_hermitized_square) {
	using complex = std::complex<double>;
	constexpr complex I{0, 1};

	multi::array<complex, 2> const a = {
		{1. + 3.*I, 3. + 2.*I},
		{9. + 1.*I, 7. + 1.*I},
	};
	multi::array<complex, 2> const b = {
		{11. + 2.*I, 12. + 4.*I},
		{ 7. + 1.*I, 19. - 9.*I},
	};

	{  // c = a b
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., a, b, 0., c);
		BOOST_REQUIRE( c[1][0] == 145. + 43.*I );
	}
	{  // c = a b, iterator interface
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == 145. + 43.*I );
	}
	{  // c = a† b†
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., blas::H(a), blas::H(b), 0., c);
		BOOST_REQUIRE( c[1][0] == 109. - 68.*I );
	}
	{  // c = a† b†, iterator interface
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(blas::H(a)), size(blas::H(a)), begin(blas::H(b)), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == 109. - 68.*I );
	}
	{  // c† = a† b†, i.e. c = b a (the output itself is hermitized)
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., blas::H(a), blas::H(b), 0., blas::H(c));
		BOOST_REQUIRE( c[1][0] == 184. - 40.*I );
	}
	// Disabled in the original: same case through gemm_n with an explicit context.
//	{
//		multi::array<complex, 2> c({2, 2});
//		blas::context ctxt;
//		blas::gemm_n(ctxt, 1., begin(blas::H(a)), size(blas::H(a)), begin(blas::H(b)), 0., begin(blas::H(c))); // c†=a†b†, c=ba
//		BOOST_REQUIRE( c[1][0] == 184. - 40.*I );
//	}
	{  // c = a† b
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., blas::H(a), b, 0., c);
		BOOST_REQUIRE( c[1][0] == 87. - 16.*I );
	}
	{  // c = a† b, iterator interface
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(blas::H(a)), size(blas::H(a)), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == 87. - 16.*I );
	}
	{  // c = a b†
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., a, blas::H(b), 0., c);
		BOOST_REQUIRE( c[1][0] == 189. - 23.*I );
	}
	{  // c = a b†, assigned from the value returned by the 3-argument gemm
		multi::array<complex, 2> c({2, 2});
		c = blas::gemm(1., a, blas::H(b));
		BOOST_REQUIRE( c[1][0] == 189. - 23.*I );
	}
	{  // c = a b†, array copy-initialized from the returned value
		multi::array<complex, 2> c = blas::gemm(1., a, blas::H(b));
		BOOST_REQUIRE( size(c) == 2 );
		BOOST_REQUIRE( c[1][0] == 189. - 23.*I );
	}
	{  // same, with the array type deduced (CTAD)
		multi::array c = blas::gemm(1., a, blas::H(b));
		BOOST_REQUIRE( size(c) == 2 );
		BOOST_REQUIRE( c[1][0] == 189. - 23.*I );
	}
	{  // same, through an explicit conversion to multi::array
		auto c = multi::array<complex, 2>(blas::gemm(1., a, blas::H(b)));
		BOOST_REQUIRE( size(c) == 2 );
		BOOST_REQUIRE( c[1][0] == 189. - 23.*I );
	}
	{  // c = a b†, iterator interface
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(a), size(a), begin(blas::H(b)), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == 189. - 23.*I );
	}
	{  // c = a† b† (repeats an earlier stanza)
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., blas::H(a), blas::H(b), 0., c);
		BOOST_REQUIRE( c[1][0] == 109. - 68.*I );
	}
	{  // c = a† b†, iterator interface (repeats an earlier stanza)
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(blas::H(a)), size(blas::H(a)), begin(blas::H(b)), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == 109. - 68.*I );
	}
	// Disabled in the original, which marks it as a case not implemented in BLAS.
//	{
//		multi::array<complex, 2> c({2, 2});
//		blas::gemm(1., blas::H(a), blas::H(b), 0., ~c); // case no implemented in blas
//		BOOST_REQUIRE( c[0][1] == 109. - 68.*I );
//	}
}
1072
// Products of two 3x1 complex columns: a† b is a 1x1 result, computed with a
// lazily hermitized operand (blas::H) and with a materialized copy of it.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_3x1_3x1) {
	using complex = std::complex<double>;
	complex const I{0, 1};

	multi::array<complex, 2> const a = {
		{1. + 2.*I},
		{9. - 1.*I},
		{1. + 1.*I},
	};
	multi::array<complex, 2> const b = {
		{11. - 2.*I},
		{ 7. - 3.*I},
		{ 8. - 1.*I},
	};

	{  // c = a† b
		multi::array<complex, 2> c({1, 1});
		blas::gemm(1., blas::H(a), b, 0., c);
		BOOST_REQUIRE( c[0][0] == 80. - 53.*I );
	}
	{  // c = a† b, iterator interface
		multi::array<complex, 2> c({1, 1});
		blas::gemm_n(1., begin(blas::H(a)), size(blas::H(a)), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[0][0] == 80. - 53.*I );
	}
	{  // c = a† b (repeats the first stanza)
		multi::array<complex, 2> c({1, 1});
		blas::gemm(1., blas::H(a), b, 0., c);
		BOOST_REQUIRE( c[0][0] == 80. - 53.*I );
	}
	{  // c = a† b, iterator interface (repeats the second stanza)
		multi::array<complex, 2> c({1, 1});
		blas::gemm_n(1., begin(blas::H(a)), size(blas::H(a)), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[0][0] == 80. - 53.*I );
	}
	{  // materialized copy of a† on the left, then b† a into the same result
		multi::array<complex, 2> c({1, 1});
		auto ha = +blas::hermitized(a);
		blas::gemm(1., ha, b, 0., c);
		BOOST_REQUIRE( c[0][0] == 80. - 53.*I );

		blas::gemm(1., blas::H(b), a, 0., c);
		BOOST_REQUIRE( c[0][0] == 80. + 53.*I );
	}
	{  // same two products through the iterator interface
		multi::array<complex, 2> c({1, 1});
		auto ha = +blas::hermitized(a);
		blas::gemm_n(1., begin(ha), size(ha), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[0][0] == 80. - 53.*I );

		blas::gemm_n(1., begin(blas::H(b)), size(blas::H(b)), begin(a), 0., begin(c));
		BOOST_REQUIRE( c[0][0] == 80. + 53.*I );
	}
}
1124
// A 1x3 complex row times a 3x2 complex matrix, giving a 1x2 result; also the
// same shapes with the left operand built as H of a transposed copy of `a`.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_1x3_3x2) {
	using complex = std::complex<double>;
	constexpr complex I{0, 1};

	multi::array<complex, 2> const a = {
		{1. + 2.*I, 9. - 1.*I, 1. + 1.*I},
	};
	multi::array<complex, 2> const b = {
		{11. - 2.*I, 5. + 2.*I},
		{ 7. - 3.*I, 2. + 1.*I},
		{ 8. - 1.*I, 1. + 1.*I},
	};

	{  // c = a b
		multi::array<complex, 2> c({1, 2});
		blas::gemm(1., a, b, 0., c);
		BOOST_REQUIRE( c[0][1] == 20. + 21.*I );
	}
	{  // c = a b, iterator interface
		multi::array<complex, 2> c({1, 2});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[0][1] == 20. + 21.*I );
	}
	{  // c = ar† b, where ar is a materialized copy of ~a
		auto ar = +~a;
		multi::array<complex, 2> c({1, 2});
		blas::gemm(1., blas::H(ar), b, 0., c);
		BOOST_REQUIRE( c[0][1] == 28. + 3.*I );
	}
	{  // c = ar† b, iterator interface
		auto ar = +~a;
		multi::array<complex, 2> c({1, 2});
		blas::gemm_n(1., begin(blas::H(ar)), size(blas::H(ar)), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[0][1] == 28. + 3.*I );
	}
}
1158
// Hermitized 3x1 complex column times a 3x2 complex matrix: c = a† b is 1x2.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_3x1_3x2) {
	using complex = std::complex<double>;
	complex const I{0, 1};

	multi::array<complex, 2> const a = {
		{1. + 2.*I},
		{9. - 1.*I},
		{1. + 1.*I},
	};
	multi::array<complex, 2> const b = {
		{11. - 2.*I, 5. + 2.*I},
		{ 7. - 3.*I, 2. + 1.*I},
		{ 8. - 1.*I, 1. + 1.*I},
	};

	{  // c = a† b
		multi::array<complex, 2> c({1, 2});
		blas::gemm(1., blas::H(a), b, 0., c);
		BOOST_REQUIRE( c[0][1] == 28. + 3.*I );
	}
	{  // c = a† b, iterator interface
		multi::array<complex, 2> c({1, 2});
		blas::gemm_n(1., begin(blas::H(a)), size(blas::H(a)), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[0][1] == 28. + 3.*I );
	}
}
1182
// Hermitized 3x2 complex matrix times a 3x2 complex matrix: c = a† b is 2x2.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_3x2_3x2) {
	using complex = std::complex<double>;
	complex const I{0, 1};

	multi::array<complex, 2> const a = {
		{1. + 2.*I, 5. + 2.*I},
		{9. - 1.*I, 9. + 1.*I},
		{1. + 1.*I, 2. + 2.*I},
	};
	multi::array<complex, 2> const b = {
		{11. - 2.*I, 5. + 2.*I},
		{ 7. - 3.*I, 2. + 1.*I},
		{ 8. - 1.*I, 1. + 1.*I},
	};

	{  // c = a† b
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., blas::H(a), b, 0., c);
		BOOST_REQUIRE( c[1][0] == 125. - 84.*I );
	}
	{  // c = a† b, iterator interface
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(blas::H(a)), size(blas::H(a)), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == 125. - 84.*I );
	}
}
1206
// Hermitized 3x2 complex matrix times a 3x1 complex column: c = a† b is 2x1.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_3x2_3x1) {
	using complex = std::complex<double>;
	complex const I{0, 1};

	multi::array<complex, 2> const a = {
		{1. + 2.*I, 5. + 2.*I},
		{9. - 1.*I, 9. + 1.*I},
		{1. + 1.*I, 2. + 2.*I},
	};
	multi::array<complex, 2> const b = {
		{11. - 2.*I},
		{ 7. - 3.*I},
		{ 8. - 1.*I},
	};

	{  // c = a† b
		multi::array<complex, 2> c({2, 1});
		blas::gemm(1., blas::H(a), b, 0., c);
		BOOST_REQUIRE( c[1][0] == 125. - 84.*I );
	}
	{  // c = a† b, iterator interface
		multi::array<complex, 2> c({2, 1});
		blas::gemm_n(1., begin(blas::H(a)), size(blas::H(a)), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == 125. - 84.*I );
	}
}
1230
// Second 3x1-by-3x1 case: c = a† b reduces two complex columns to a 1x1 result.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_3x1_3x1_bis) {
	using complex = std::complex<double>;
	complex const I{0, 1};

	multi::array<complex, 2> const a = {
		{1. + 2.*I},
		{9. - 1.*I},
		{1. + 1.*I},
	};
	multi::array<complex, 2> const b = {
		{11. - 2.*I},
		{ 7. - 3.*I},
		{ 8. - 1.*I},
	};

	{  // c = a† b
		multi::array<complex, 2> c({1, 1});
		blas::gemm(1., blas::H(a), b, 0., c);
		BOOST_REQUIRE( c[0][0] == 80. - 53.*I );
	}
	{  // c = a† b, iterator interface
		multi::array<complex, 2> c({1, 1});
		blas::gemm_n(1., begin(blas::H(a)), size(blas::H(a)), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[0][0] == 80. - 53.*I );
	}
}
1254
// 2x2 real products with every combination of plain and transposed (blas::T)
// operands; each stanza checks one or two entries of the result.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_square_automatic) {
	multi::array<double, 2> const a = {
		{1., 3.},
		{9., 7.},
	};
	multi::array<double, 2> const b = {
		{11., 12.},
		{ 7., 19.},
	};

	{  // c = a b
		multi::array<double, 2> c({2, 2});
		blas::gemm(1., a, b, 0., c);
		BOOST_REQUIRE( c[1][0] == 148 );
		BOOST_REQUIRE( c[1][1] == 241 );
	}
	{  // c = a b, iterator interface
		multi::array<double, 2> c({2, 2});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == 148 );
		BOOST_REQUIRE( c[1][1] == 241 );
	}
	{  // c = a b⸆
		multi::array<double, 2> c({2, 2});
		blas::gemm(1., a, blas::T(b), 0., c);
		BOOST_REQUIRE( c[1][1] == 196. );
	}
	{  // c = a⸆ b
		multi::array<double, 2> c({2, 2});
		blas::gemm(1., blas::T(a), b, 0., c);
		BOOST_REQUIRE( c[1][1] == 169. );
		BOOST_REQUIRE( c[1][0] == 82. );
	}
	{  // c = a⸆ b⸆
		multi::array<double, 2> c({2, 2});
		blas::gemm(1., blas::T(a), blas::T(b), 0., c);
		BOOST_REQUIRE( c[1][1] == 154. );
	}
}
1290
// 2x2 complex products over plain, transposed (blas::T) and hermitized
// (blas::H) operands. Most combinations are checked through both blas::gemm
// and the iterator-based blas::gemm_n.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_square_automatic) {
	using complex = std::complex<double>;
	complex const I{0, 1};

	multi::array<complex, 2> const a = {
		{1. + 2.*I, 3. - 3.*I},
		{9. + 1.*I, 7. + 4.*I},
	};
	multi::array<complex, 2> const b = {
		{11. + 1.*I, 12. + 1.*I},
		{ 7. + 8.*I, 19. - 2.*I},
	};

	namespace blas = multi::blas;

	{  // c = a b
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., a, b, 0., c);
		BOOST_REQUIRE( c[1][0] == complex(115, 104) );
	}
	{  // c = a b, iterator interface
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == complex(115, 104) );
	}
	{  // c = a b⸆
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., a, blas::T(b), 0., c);
		BOOST_REQUIRE( c[1][0] == complex(178, 75) );
	}
	{  // c = a b⸆, iterator interface
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(a), size(a), begin(blas::T(b)), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == complex(178, 75) );
	}
	{  // c = a⸆ b
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., blas::T(a), b, 0., c);
		BOOST_REQUIRE( c[1][1] == complex(180, 29) );
		BOOST_REQUIRE( c[1][0] == complex(53, 54) );
	}
	{  // c = a⸆ b, iterator interface
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(blas::T(a)), size(blas::T(a)), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[1][1] == complex(180, 29) );
		BOOST_REQUIRE( c[1][0] == complex(53, 54) );
	}
	{  // c = a⸆ b⸆
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., blas::T(a), blas::T(b), 0., c);
		BOOST_REQUIRE( c[1][1] == complex(186, 65) );
		BOOST_REQUIRE( c[1][0] == complex(116, 25) );
	}
	{  // c = a⸆ b⸆, iterator interface
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(blas::T(a)), size(blas::T(a)), begin(blas::T(b)), 0., begin(c));
		BOOST_REQUIRE( c[1][1] == complex(186, 65) );
		BOOST_REQUIRE( c[1][0] == complex(116, 25) );
	}
	{  // c = a b (repeats the first stanza)
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., a, b, 0., c);
		BOOST_REQUIRE( c[1][0] == complex(115, 104) );
	}
	{  // c = a b, iterator interface (repeats the second stanza)
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == complex(115, 104) );
	}
	{  // c = a† b
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., blas::H(a), b, 0., c);
		BOOST_REQUIRE( c[1][0] == complex(111, 64) );
		BOOST_REQUIRE( c[1][1] == complex(158, -51) );
	}
	{  // c = a† b, iterator interface
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(blas::H(a)), size(blas::H(a)), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == complex(111, 64) );
		BOOST_REQUIRE( c[1][1] == complex(158, -51) );
	}
	{  // c = a b†
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., a, blas::H(b), 0., c);
		BOOST_REQUIRE( c[1][0] == complex(188, 43) );
		BOOST_REQUIRE( c[1][1] == complex(196, 25) );
	}
	{  // c = a† b†
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., blas::H(a), blas::H(b), 0., c);
		BOOST_REQUIRE( c[1][0] == complex(116, -25) );
		BOOST_REQUIRE( c[1][1] == complex(186, -65) );
	}
	{  // c = a† b†, iterator interface
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(blas::H(a)), size(blas::H(a)), begin(blas::H(b)), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == complex(116, -25) );
		BOOST_REQUIRE( c[1][1] == complex(186, -65) );
	}
	{  // c = a⸆ b†
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., blas::T(a), blas::H(b), 0., c);
		BOOST_REQUIRE( c[1][0] == complex(118, 5) );
		BOOST_REQUIRE( c[1][1] == complex(122, 45) );
	}
	{  // c = a⸆ b†, iterator interface
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(blas::T(a)), size(blas::T(a)), begin(blas::H(b)), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == complex(118, 5) );
		BOOST_REQUIRE( c[1][1] == complex(122, 45) );
	}
	{  // c = a⸆ b⸆ (same product as above, entries checked in the other order)
		multi::array<complex, 2> c({2, 2});
		blas::gemm(1., blas::T(a), blas::T(b), 0., c);
		BOOST_REQUIRE( c[1][0] == complex(116, 25) );
		BOOST_REQUIRE( c[1][1] == complex(186, 65) );
	}
	{  // c = a⸆ b⸆, iterator interface
		multi::array<complex, 2> c({2, 2});
		blas::gemm_n(1., begin(blas::T(a)), size(blas::T(a)), begin(blas::T(b)), 0., begin(c));
		BOOST_REQUIRE( c[1][0] == complex(116, 25) );
		BOOST_REQUIRE( c[1][1] == complex(186, 65) );
	}
}
1398
// Non-square complex product: a 2x3 matrix times a 3x4 matrix gives 2x4.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_nonsquare_automatic) {
	using complex = std::complex<double>;
	complex const I{0, 1};

	multi::array<complex, 2> const a = {
		{1. + 2.*I, 3. - 3.*I, 1. - 9.*I},
		{9. + 1.*I, 7. + 4.*I, 1. - 8.*I},
	};
	multi::array<complex, 2> const b = {
		{11. + 1.*I, 12. + 1.*I, 4. + 1.*I, 8. - 2.*I},
		{ 7. + 8.*I, 19. - 2.*I, 2. + 1.*I, 7. + 1.*I},
		{ 5. + 1.*I,  3. - 1.*I, 3. + 8.*I, 1. + 1.*I},
	};

	{  // c = a b
		multi::array<complex, 2> c({2, 4});
		blas::gemm(1., a, b, 0., c);
		BOOST_REQUIRE( c[1][2] == complex(112, 12) );
	}
	{  // c = a b, iterator interface
		multi::array<complex, 2> c({2, 4});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[1][2] == complex(112, 12) );
	}
}
1421
// A 2x3 left operand whose entries have zero imaginary part times a 3x4
// complex matrix. The last stanza repeats the product with a genuinely real
// (double) left operand, going through blas::real_doubled on `b` and on `c`.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_realcomplex_complex_nonsquare_automatic) {
	using complex = std::complex<double>;
	complex const I{0, 1};

	multi::array<complex, 2> const a = {
		{1., 3., 1.},
		{9., 7., 1.},
	};
	multi::array<complex, 2> const b = {
		{11. + 1.*I, 12. + 1.*I, 4. + 1.*I, 8. - 2.*I},
		{ 7. + 8.*I, 19. - 2.*I, 2. + 1.*I, 7. + 1.*I},
		{ 5. + 1.*I,  3. - 1.*I, 3. + 8.*I, 1. + 1.*I},
	};

	{  // result array initialized directly from the returned product
		multi::array<complex, 2> c = blas::gemm(1., a, b);
		BOOST_REQUIRE( c[1][2] == complex(53, 24) );
	}
	{  // returned product assigned to a pre-sized array
		multi::array<complex, 2> c({2, 4});
		c = blas::gemm(1., a, b);
		BOOST_REQUIRE( c[1][2] == complex(53, 24) );
	}
	{  // in-place form with explicit alpha and beta
		multi::array<complex, 2> c({2, 4});
		blas::gemm(1., a, b, 0., c);
		BOOST_REQUIRE( c[1][2] == complex(53, 24) );
	}
	{  // iterator interface
		multi::array<complex, 2> c({2, 4});
		blas::gemm_n(1., begin(a), size(a), begin(b), 0., begin(c));
		BOOST_REQUIRE( c[1][2] == complex(53, 24) );
	}
	{  // real left operand, complex operands accessed through real_doubled
		multi::array<double, 2> const a_real = {
			{1., 3., 1.},
			{9., 7., 1.},
		};
		multi::array<complex, 2> c({2, 4});
		blas::real_doubled(c) = blas::gemm(1., a_real, blas::real_doubled(b));

		BOOST_REQUIRE( c[1][2] == complex(53, 24) );
	}
}
1463
// Issue 97 regression scaffold. Only the copy of a sub-block of M is verified
// at present; the gemm checks on it are commented out in the original and are
// kept here in that state.
BOOST_AUTO_TEST_CASE(submatrix_result_issue_97) {
	using complex = std::complex<double>;
	constexpr complex I{0, 1};

	multi::array<complex, 2> M = {
		{2. + 3.*I, 2. + 1.*I, 1. + 2.*I},
		{4. + 2.*I, 2. + 4.*I, 3. + 1.*I},
		{7. + 1.*I, 1. + 5.*I, 0. + 3.*I},
	};

	// Materialize the first column (rows 0..2) and compare it with the view.
	auto M2 = +M({0, 3}, {0, 1});
	BOOST_REQUIRE( M2 == M({0, 3}, {0, 1}) );

//	multi::array<complex, 2> V = {
//		{1. + 2.*I},
//		{2. + 1.*I},
//		{9. + 2.*I}
//	};

//	BOOST_REQUIRE( (+blas::gemm(1., blas::H(M2               ), V))[0][0] == 83. + 6.*I );
//	BOOST_REQUIRE( (+blas::gemm(1., blas::H(M({0, 3}, {0, 1})), V))[0][0] == 83. + 6.*I );

//	using namespace multi::blas::operators;
//	BOOST_REQUIRE( (+(blas::H(M)*V))[0][0] == 83. + 6.*I );
}
1488
1489
// Fills a 30x40 and a 40x50 complex matrix with pseudo-random values drawn
// from default-constructed std::normal_distribution / std::mt19937 objects.
// All checks that would consume the matrices are commented out in the
// original and are kept that way here.
BOOST_AUTO_TEST_CASE(blas_context_gemm) {
	using complex = std::complex<double>;
	static constexpr complex I{0, 1};  // static so the lambda can use it without a capture

	// The two draws are separate statements so the real part always receives
	// the first draw and the imaginary part the second. Written as
	// `d(g) + d(g)*I`, the order of the two calls is unspecified and the
	// generated data could differ between compilers.
	auto rand = [d = std::normal_distribution<>{}, g = std::mt19937{}]() mutable {  // NOLINT(cert-msc32-c, cert-msc51-cpp): test purposes
		auto const re = d(g);
		auto const im = d(g);
		return re + im*I;
	};

	multi::array<complex, 2> A({30, 40});
	multi::array<complex, 2> B({40, 50});

	std::generate(A.elements().begin(), A.elements().end(), rand);
	std::generate(B.elements().begin(), B.elements().end(), rand);

//	auto C = +blas::gemm(1., A, B);

//	using namespace multi::blas::operators;

//	{
//		auto sum = 0.;
//		for(auto i : extension(~C))
//			sum += blas::nrm2((~C)[i] - blas::gemv(A, (~B)[i]))();

//		BOOST_REQUIRE(sum == 0, boost::test_tools::tolerance(1e-12));
//	}

//	BOOST_REQUIRE( std::inner_product(
//		begin(~C), end(~C), begin(~B), 0., std::plus<>{}, [&A](auto const& Ccol, auto const& Bcol){
//			return multi::blas::nrm2( Ccol - multi::blas::gemv(A, Bcol) );
//	}) == 0. , boost::test_tools::tolerance(1e-12) );

//	BOOST_REQUIRE( std::equal(
//		begin(~C), end(~C), begin(~B), [&A](auto const& Ccol, auto const& Bcol){
//			return multi::blas::nrm2( Ccol - multi::blas::gemv(A, Bcol) ) < 1e-12;
//		}
//	) );

//	blas::context ctxt;
//	auto C2 = +blas::gemm(&ctxt, 1., A, B);

//	BOOST_REQUIRE( std::equal(
//		begin(C), end(C), begin(C2), [](auto const& crow, auto const& c2row){return ((crow - c2row)^2) < 1e-13;}
//	) );
}
1531
// The value returned by the 3-argument gemm (a 2x3 zero matrix times the
// hermitized 4x3 zero matrix) is assigned into an existing 2x4 array and used
// to construct a new one; both results must be zero.
// The original ended with a stanza that only declared an unused array
// shadowing `a`; it exercised nothing and has been removed.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_nonsquare_hermitized_second_gemm_range) {
	multi::array<double, 2> const a({2, 3}, 0.);
	multi::array<double, 2> const b({4, 3}, 0.);

	{  // assign into the elements of a pre-sized array
		multi::array<double, 2> c({2, 4});
		c() = blas::gemm(0.1, a, blas::H(b));
		BOOST_REQUIRE_CLOSE( c[1][2], 0., 0.00001 );
	}
	{  // construct the result array from the returned product
		multi::array<double, 2> c = blas::gemm(0.1, a, blas::H(b));
		BOOST_REQUIRE( c[1][2] == 0. );
	}
}
1552
// Complex zero matrices: a (2x3) times hermitized b (4x3). The first stanza
// only requires that the 999 fill value was overwritten; the second requires
// an exact zero.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_nonsquare_hermitized_second_gemm_range) {
	using complex = std::complex<double>;

	multi::array<complex, 2> const a({2, 3}, 0.);
	multi::array<complex, 2> const b({4, 3}, 0.);

	{  // iterator interface writing over a pre-filled result
		multi::array<complex, 2> c({2, 4}, 999.);
		blas::gemm_n(1., begin(a), size(a), begin(blas::H(b)), 0., begin(c));
		BOOST_REQUIRE( c[1][2] != 999. );
	}
	{  // result array constructed from the returned product
		multi::array<complex, 2> c = blas::gemm(1., a, blas::H(b));
		BOOST_REQUIRE( c[1][2] == 0. );
	}
}
1567
// Real 2x3 matrix times a hermitized 4x3 matrix (c = a b†, 2x4), with scale
// factors 1 and 0.1, through gemm, gemm_n and the value-returning gemm.
// Changes from the original: the `c() = gemm(...)` stanza now checks its
// result (it previously computed and asserted nothing), and the local
// `namespace blas` alias, which repeated the file-scope one, is gone.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_nonsquare_hermitized_second) {
	multi::array<double, 2> const a = {
		{1, 3, 1},
		{9, 7, 1},
	};
	multi::array<double, 2> const b = {
		{11,  7, 5},
		{12, 19, 3},
		{ 4,  2, 3},
		{ 8,  7, 1},
	};

	{  // c = a b†
		multi::array<double, 2> c({2, 4});
		blas::gemm(1., a, blas::H(b), 0., c);
		BOOST_REQUIRE( c[1][2] == 53. );
	}
	{  // c = a b†, iterator interface
		multi::array<double, 2> c({2, 4});
		blas::gemm_n(1., begin(a), size(a), begin(blas::H(b)), 0., begin(c));
		BOOST_REQUIRE( c[1][2] == 53. );
	}
	{  // c = 0.1 a b†; compared with a tolerance because of the 0.1 factor
		multi::array<double, 2> c({2, 4});
		blas::gemm(0.1, a, blas::H(b), 0., c);
		BOOST_REQUIRE_CLOSE( c[1][2], 5.3, 0.00001 );
	}
	{  // c = 0.1 a b†, iterator interface
		multi::array<double, 2> c({2, 4});
		blas::gemm_n(0.1, begin(a), size(a), begin(blas::H(b)), 0., begin(c));
		BOOST_REQUIRE_CLOSE( c[1][2], 5.3, 0.00001 );
	}
	{  // returned product assigned into the elements of a pre-sized array
		multi::array<double, 2> c({2, 4});
		c() = blas::gemm(0.1, a, blas::H(b));
		BOOST_REQUIRE_CLOSE( c[1][2], 5.3, 0.00001 );
	}
	{  // result array constructed from the returned product
		multi::array<double, 2> c = blas::gemm(0.1, a, blas::H(b));
		BOOST_REQUIRE_CLOSE( c[1][2], 5.3, 0.00001 );
	}
}
1609
// Complex arrays holding purely real values: c = a b† (2x3 times hermitized
// 4x3, giving 2x4), with scale factors 1 and 0.1.
// Changes from the original: the `c() = gemm(...)` stanza now checks its
// result (it previously computed and asserted nothing), and the local
// `namespace blas` alias, which repeated the file-scope one, is gone.
BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_real_nonsquare_hermitized_second) {
	using complex = std::complex<double>;

	multi::array<complex, 2> const a = {
		{1., 3., 1.},
		{9., 7., 1.},
	};
	multi::array<complex, 2> const b = {
		{11.,  7., 5.},
		{12., 19., 3.},
		{ 4.,  2., 3.},
		{ 8.,  7., 1.},
	};

	{  // c = a b†
		multi::array<complex, 2> c({2, 4});
		blas::gemm(1., a, blas::H(b), 0., c);
		BOOST_REQUIRE( c[1][2] == 53. );
	}
	{  // c = a b†, iterator interface
		multi::array<complex, 2> c({2, 4});
		blas::gemm_n(1., begin(a), size(a), begin(blas::H(b)), 0., begin(c));
		BOOST_REQUIRE( c[1][2] == 53. );
	}
	{  // c = 0.1 a b†; only the real part is compared, with a tolerance
		multi::array<complex, 2> c({2, 4});
		blas::gemm(0.1, a, blas::H(b), 0., c);
		BOOST_REQUIRE_CLOSE( real(c[1][2]), 5.3, 0.00001 );
	}
	{  // c = 0.1 a b†, iterator interface
		multi::array<complex, 2> c({2, 4});
		blas::gemm_n(0.1, begin(a), size(a), begin(blas::H(b)), 0., begin(c));
		BOOST_REQUIRE_CLOSE( real(c[1][2]), 5.3, 0.00001 );
	}
	{  // returned product assigned into the elements of a pre-sized array
		multi::array<complex, 2> c({2, 4});
		c() = blas::gemm(0.1, a, blas::H(b));
		BOOST_REQUIRE_CLOSE( real(c[1][2]), 5.3, 0.00001 );
	}
	{  // result array constructed from the returned product
		multi::array<complex, 2> c = blas::gemm(0.1, a, blas::H(b));
		BOOST_REQUIRE_CLOSE( real(c[1][2]), 5.3, 0.00001 );
	}
}
1652
// A 1x100 row of ones times the hermitized 1x100 row of ones: the single
// output element, pre-filled with 999, must become 100.
BOOST_AUTO_TEST_CASE(blas_gemm_1xn_complex) {
	using complex = std::complex<double>;

	multi::array<complex, 2> const a({1, 100}, 1.);
	multi::array<complex, 2> const b({1, 100}, 1.);
	multi::array<complex, 2> c({1, 1}, 999.);

	blas::gemm_n(1., begin(a), size(a), begin(blas::H(b)), 0., begin(c));

	BOOST_REQUIRE( c[0][0] == 100. );
}
1662
// A 100x1 column of twos times the hermitized 1x1 matrix holding 3: the first
// two rows of the 100x1 result are checked to equal 6.
BOOST_AUTO_TEST_CASE(blas_gemm_nx1_times_1x1_complex_inq_hydrogen_case) {
	using complex = std::complex<double>;

	multi::array<complex, 2> const a({100, 1}, 2.);
	multi::array<complex, 2> const b({1, 1}, 3.);
	multi::array<complex, 2> c({100, 1}, 999.);

	blas::gemm_n(1., begin(a), size(a), begin(blas::H(b)), 0., begin(c));

	BOOST_REQUIRE( c[0][0] == 6. );
	BOOST_REQUIRE( c[1][0] == 6. );
}
1673
// Smallest case: 1x1 times 1x1 through the value-returning gemm, assigned over
// a result that was pre-filled with 999.
BOOST_AUTO_TEST_CASE(blas_gemm_nx1_times_1x1_1x1_complex_inq_hydrogen_case) {
	using complex = std::complex<double>;

	multi::array<complex, 2> const a({1, 1}, 2.);
	multi::array<complex, 2> const b({1, 1}, 3.);
	multi::array<complex, 2> c({1, 1}, 999.);

	c = blas::gemm(1., a, b);

	BOOST_REQUIRE( c[0][0] == 6. );
}
1683
// Regression test for https://gitlab.com/correaa/boost-multi/-/issues/97
// A hermitized sub-block view of `mat` used as the left gemm operand must give
// the same [0][0] entry as a hermitized materialized copy of that sub-block.
BOOST_AUTO_TEST_CASE(blas_gemm_inq_case) {
	using complex = std::complex<double>;

	multi::array<complex, 2> mat({10, 2}, 1.0);
	multi::array<complex, 2> vec({10, 1}, -2.0);

	// Overwrite the second column of mat with vec.
	mat({0, 10}, {1, 2}) = vec;

	namespace blas = multi::blas;

	{  // whole matrix, first-column view, and copy of the first column
		auto olap1 = +blas::gemm(1., blas::H(mat), vec);
		auto olap2 = +blas::gemm(1., blas::H(mat({0, 10}, {0, 1})), vec);

		multi::array<complex, 2> mat2 = mat({0, 10}, {0, 1});
		auto olap3 = +blas::gemm(1., blas::H(mat2), vec);

		BOOST_REQUIRE( olap1[0][0] == olap2[0][0] );
		BOOST_REQUIRE( olap3[0][0] == olap2[0][0] );
	}
	{  // three-row slice: copy versus view
		// NOTE(review): the sliced operand has 3 rows while vec has 10, so the
		// inner extents of these products appear to differ; confirm this is
		// intended. olap3 is computed but never inspected.
		multi::array<complex, 2> mat2 = mat({0, 3}, {0, 1});
		auto olap3 = +blas::gemm(1., blas::H(mat({0, 3}, {0, 1})), vec);
		BOOST_REQUIRE(
			(+blas::gemm(1., blas::H(mat2), vec))[0][0]
			== (+blas::gemm(1., blas::H(mat({0, 3}, {0, 1})), vec))[0][0]
		);
	}
}
1711
1712 #if 0
1713 BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_nonsquare_hermitized_second_gpu){//, *utf::tolerance(0.00001)){
1714 namespace cuda = multi::cuda;
1715 namespace blas = multi::blas;
1716 cuda::array<double, 2> const a = {
1717 {1, 3, 1},
1718 {9, 7, 1},
1719 };
1720 cuda::array<double, 2> const b = {
1721 {11, 7, 5},
1722 {12, 19, 3},
1723 { 4, 2, 3},
1724 { 8, 7, 1}
1725 };
1726 using multi::blas::gemm;using multi::blas::hermitized;
1727 {
1728 cuda::array<double, 2> c({2, 4});
1729 blas::gemm(1., a, blas::H(b), 0., c); // c=ab, c⸆=b⸆a⸆
1730 BOOST_REQUIRE( c[1][2] == 53 );
1731 }
1732 {
1733 cuda::array<double, 2> c({2, 4});
1734 blas::gemm(0.1, a, blas::H(b), 0., c); // c=ab, c⸆=b⸆a⸆
1735 BOOST_REQUIRE( c[1][2] == 5.3 );
1736 }
1737 {
1738 cuda::array<double, 2> c({2, 4});
1739 blas::gemm(0.1, a, blas::H(b), 0., c); // c=ab, c⸆=b⸆a⸆
1740 BOOST_REQUIRE( c[1][2] == 5.3 );
1741 }
1742 {
1743 cuda::array<double, 2> c({2, 4});
1744 auto c_copy = blas::gemm(0.1, a, blas::H(b), 0., c); // c=ab, c⸆=b⸆a⸆
1745 BOOST_REQUIRE( c_copy[1][2] == 5.3 );
1746 }
1747 {
1748 multi::cuda::array<double, 2> c({2, 4});
1749 auto c_copy = blas::gemm(0.1, a, blas::H(b), 0., c); // c=ab, c⸆=b⸆a⸆
1750 BOOST_REQUIRE( c_copy[1][2] == 5.3 );
1751 }
1752 {
1753 // auto f = [](auto&& a, auto&& b){return blas::gemm(0.1, a, blas::H(b));};
1754 // auto c = f(a, b);
1755 // BOOST_REQUIRE( c[1][2] == 5.3 );
1756 }
1757 {
1758 // auto f = [](auto&& a, auto&& b){return blas::gemm(0.1, a, blas::H(b));};
1759 // cuda::array<double, 2> c;
1760 // c = f(a, b);
1761 // BOOST_REQUIRE( c[1][2] == 5.3 );
1762 }
1763 {
1764 // auto c = blas::gemm(0.1, a, blas::H(b)); // c=ab, c⸆=b⸆a⸆
1765 // BOOST_REQUIRE( c[1][2] == 5.3 );
1766 }
1767 }
1768
1769 BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_real_nonsquare_hermitized_second_managed){//, *utf::tolerance(0.00001)){
1770 namespace cuda = multi::cuda;
1771 namespace blas = multi::blas;
1772 cuda::managed::array<double, 2> const a = {
1773 {1, 3, 1},
1774 {9, 7, 1},
1775 };
1776 cuda::managed::array<double, 2> const b = {
1777 {11, 7, 5},
1778 {12, 19, 3},
1779 { 4, 2, 3},
1780 { 8, 7, 1}
1781 };
1782 {
1783 cuda::managed::array<double, 2> c({2, 4});
1784 blas::gemm(1., a, blas::H(b), 0., c); // c=ab, c⸆=b⸆a⸆
1785 BOOST_REQUIRE( c[1][2] == 53 );
1786 }
1787 {
1788 cuda::managed::array<double, 2> c({2, 4});
1789 blas::gemm(0.1, a, blas::H(b), 0., c); // c=ab, c⸆=b⸆a⸆
1790 BOOST_REQUIRE( c[1][2] == 5.3 );
1791 }
1792 {
1793 cuda::managed::array<double, 2> c({2, 4});
1794 blas::gemm(0.1, a, blas::H(b), 0., c); // c=ab, c⸆=b⸆a⸆
1795 BOOST_REQUIRE( c[1][2] == 5.3 );
1796 }
1797 {
1798 cuda::managed::array<double, 2> c({2, 4});
1799 auto c_copy = blas::gemm(0.1, a, blas::H(b), 0., c); // c=ab, c⸆=b⸆a⸆
1800 BOOST_REQUIRE( c_copy[1][2] == 5.3 );
1801 }
1802 {
1803 multi::cuda::managed::array<double, 2> c({2, 4});
1804 auto c_copy = blas::gemm(0.1, a, blas::H(b), 0., c); // c=ab, c⸆=b⸆a⸆
1805 BOOST_REQUIRE( c_copy[1][2] == 5.3 );
1806 }
1807 {
1808 // auto f = [](auto&& a, auto&& b){return blas::gemm(0.1, a, blas::H(b));};
1809 // auto c = f(a, b);
1810 // BOOST_REQUIRE( c[1][2] == 5.3 );
1811 }
1812 {
1813 // auto f = [](auto&& a, auto&& b){return blas::gemm(0.1, a, blas::H(b));};
1814 // cuda::managed::array<double, 2> c;
1815 // c = f(a, b);
1816 // BOOST_REQUIRE( c[1][2] == 5.3 );
1817 }
1818 {
1819 auto f = [](auto&& a, auto&& b){return blas::gemm(0.1, a, blas::H(b));};
1820 multi::cuda::managed::array<double, 2> c = a; BOOST_REQUIRE(size(c) == 2 and size(rotated(c)) == 3);
1821 c = f(a, b);
1822 BOOST_REQUIRE( c[1][2] == 5.3 ); BOOST_REQUIRE(size(c) == 2 and size(rotated(c)) == 4);
1823 }
1824 {
1825 auto c = blas::gemm(0.1, a, blas::H(b)); // c=ab, c⸆=b⸆a⸆
1826 BOOST_REQUIRE( c[1][2] == 5.3 );
1827 }
1828 }
1829
// File-scope shorthands used by the definitions that follow:
// `complex` names std::complex<double>, and `I` is the imaginary unit.
using complex = std::complex<double>;
complex const I{0, 1};
1831
1832 BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_nonsquare_automatic){
1833 namespace cuda = multi::cuda;
1834 namespace blas = multi::blas;
1835 multi::array<complex, 2> const a = {
1836 { 1. + 2.*I, 3. - 3.*I, 1.-9.*I},
1837 { 9. + 1.*I, 7. + 4.*I, 1.-8.*I},
1838 };
1839 multi::array<complex, 2> const b = {
1840 { 11.+1.*I, 12.+1.*I, 4.+1.*I, 8.-2.*I},
1841 { 7.+8.*I, 19.-2.*I, 2.+1.*I, 7.+1.*I},
1842 { 5.+1.*I, 3.-1.*I, 3.+8.*I, 1.+1.*I}
1843 };
1844 {
1845 multi::array<complex, 2> c({2, 4});
1846 blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
1847 BOOST_REQUIRE( c[1][2] == complex(112, 12) );
1848 }
1849 {
1850 cuda::array<complex, 2> const acu = a;
1851 cuda::array<complex, 2> const bcu = b;
1852 cuda::array<complex, 2> ccu({2, 4});
1853 blas::gemm(complex(1.), acu, bcu, complex(0.), ccu);
1854 BOOST_REQUIRE( ccu[1][2] == complex(112, 12) );
1855 }
1856 {
1857 cuda::managed::array<complex, 2> const amcu = a;
1858 cuda::managed::array<complex, 2> const bmcu = b;
1859 cuda::managed::array<complex, 2> cmcu({2, 4});
1860 blas::gemm(1., amcu, bmcu, 0., cmcu);
1861 BOOST_REQUIRE( cmcu[1][2] == complex(112, 12) );
1862 }
1863 }
1864
1865 struct multiplies_bind1st{
1866 multiplies_bind1st(multi::cuda::managed::array<complex, 2>&& m) : m_(std::move(m)){}
1867 template<class A>
1868 auto operator()(A const& a) const{
1869 using multi::blas::gemm;
1870 return gemm(m_, a);
1871 }
1872 private:
1873 multi::cuda::managed::array<complex, 2> m_;
1874 };
1875
1876 BOOST_AUTO_TEST_CASE(multi_constructors_inqnvcc_bug){
1877 namespace cuda = multi::cuda;
1878 namespace blas = multi::blas;
1879 cuda::managed::array<complex, 2> m = {
1880 { 1. + 2.*I, 3. - 3.*I, 1.-9.*I},
1881 { 9. + 1.*I, 7. + 4.*I, 1.-8.*I},
1882 };
1883 cuda::managed::array<complex, 2> const b = {
1884 { 11.+1.*I, 12.+1.*I, 4.+1.*I, 8.-2.*I},
1885 { 7.+8.*I, 19.-2.*I, 2.+1.*I, 7.+1.*I},
1886 { 5.+1.*I, 3.-1.*I, 3.+8.*I, 1.+1.*I}
1887 };
1888 auto c = blas::gemm(m, b);
1889 BOOST_REQUIRE( c[1][2] == complex(112, 12) );
1890 BOOST_REQUIRE( b[1][2] == 2.+1.*I );
1891
1892 auto m_as_operator2 = [&](auto const& B){return blas::gemm(m, B);};
1893 auto c2 = m_as_operator2(b);
1894 BOOST_REQUIRE( c == c2 );
1895
1896 auto m_as_operator3 = [=](auto const& B){return blas::gemm(m, B);};
1897 auto c3 = m_as_operator3(b);
1898 BOOST_REQUIRE( c == c3 );
1899
1900 multiplies_bind1st m_as_operator4(std::move(m));
1901 auto c4 = m_as_operator4(b);
1902 BOOST_REQUIRE( c == c4 );
1903 BOOST_REQUIRE( is_empty(m) );
1904 }
1905
1906 BOOST_AUTO_TEST_CASE(multi_blas_gemm_elongated){
1907 namespace blas = multi::blas;
1908 multi::array<complex, 2> const a = {
1909 {1.-2.*I, 9.-1.*I}
1910 };
1911 BOOST_REQUIRE( size(a) == 1 and size(a[0]) == 2 );
1912 multi::array<complex, 2> const b = {
1913 {2. + 3.*I},
1914 {19.+11.*I}
1915 };
1916 BOOST_REQUIRE( size(b) == 2 and size(b[0]) == 1 );
1917 {
1918 multi::array<complex, 2> c({1, 1});
1919 blas::gemm(1., a, b, 0., c); // c=ab, c⸆=b⸆a⸆
1920 BOOST_REQUIRE( c[0][0] == a[0][0]*b[0][0] + a[0][1]*b[1][0] );
1921 }
1922 {
1923 multi::array<complex, 2> c({2, 2});
1924 blas::gemm(1., b, a, 0., c); // c=ba, c⸆=a⸆b⸆
1925 BOOST_REQUIRE( c[0][0] == b[0][0]*a[0][0] );
1926 BOOST_REQUIRE( c[1][1] == b[1][0]*a[0][1] );
1927 auto const c_copy = blas::gemm(1., b, a);
1928 BOOST_REQUIRE( c_copy == c );
1929 }
1930 {
1931 multi::array<complex, 2> c({1, 1});
1932 blas::gemm(1., a, blas::H(a), 0., c); // c=ab, c⸆=b⸆a⸆
1933 BOOST_REQUIRE( c[0][0] == a[0][0]*conj(a[0][0]) + a[0][1]*conj(a[0][1]) );
1934 auto const c_copy = blas::gemm(1., a, blas::H(a));
1935 BOOST_REQUIRE( c_copy == c );
1936 }
1937 {
1938 multi::array<complex, 2> c({2, 2});
1939 blas::gemm(1., blas::H(a), a, 0., c); // c=ab, c⸆=b⸆a⸆
1940 BOOST_REQUIRE( c[0][0] == a[0][0]*conj(a[0][0]) );
1941 auto const c_copy = blas::gemm(blas::H(a), a);
1942 BOOST_REQUIRE( c_copy == c );
1943 }
1944 {
1945 multi::array<complex, 2> const a = {
1946 {1.-2.*I, 9.-1.*I}
1947 };
1948 multi::array<complex, 2> const b = {
1949 {2.+3.*I},
1950 {19.+11.*I}
1951 };
1952 multi::array<complex, 2> c({1, 1});
1953 blas::gemm(1., a, b, 0., c);
1954 }
1955 {
1956 multi::array<double, 2> const a = {{2., 3.}};
1957 multi::array<double, 2> const b = {{4., 5.}};
1958 multi::array<double, 2> c({1, 1});
1959 blas::gemm(1., a, blas::T(b), 0., c); // blas error
1960 }
1961 {
1962 multi::array<double, 2> const a = {
1963 {2.},
1964 {3.},
1965 {5.}
1966 };
1967 multi::array<double, 2> const b = {
1968 {4.},
1969 {5.},
1970 {6.}
1971 };
1972 multi::array<double, 2> c1({1, 1}), c2({1, 1});
1973 auto ra = rotated(a).decay();
1974 blas::gemm(1., ra, b, 0., c1); // ok
1975 BOOST_REQUIRE( c1[0][0] == a[0][0]*b[0][0] + a[1][0]*b[1][0] + a[2][0]*b[2][0] );
1976
1977 // gemm(1., rotated(a), b, 0., c2); // was blas error
1978 // BOOST_REQUIRE(c1 == c2); // not reached
1979 }
1980 {
1981 multi::array<double, 2> const a = {
1982 {2.},
1983 {3.},
1984 {5.}
1985 };
1986 multi::array<double, 2> const b = {
1987 {4., 2.},
1988 {5., 1.},
1989 {6., 2.}
1990 };
1991 multi::array<double, 2> c1({1, 2}), c2({1, 2});
1992 auto ra = rotated(a).decay();
1993
1994 blas::gemm(1., ra, b, 0., c1); // ok
1995 BOOST_REQUIRE( c1[0][0] == a[0][0]*b[0][0] + a[1][0]*b[1][0] + a[2][0]*b[2][0] );
1996
1997 blas::gemm(1., blas::T(a), b, 0., c2);
1998 BOOST_REQUIRE(c1 == c2);
1999 }
2000 if(0){
2001 multi::array<double, 2> const a = {
2002 {2.},
2003 {3.},
2004 {5.}
2005 };
2006 multi::array<double, 2> const b = {
2007 {4.},
2008 {5.},
2009 {6.}
2010 };
2011 multi::array<double, 2> c1({1, 1}), c2({1, 1});
2012 auto ra = rotated(a).decay();
2013 blas::gemm(1., ra, b, 0., c1); // ok
2014 BOOST_REQUIRE( c1[0][0] == a[0][0]*b[0][0] + a[1][0]*b[1][0] + a[2][0]*b[2][0] );
2015
2016 blas::gemm(1., blas::T(a), b, 0., c2);
2017 BOOST_REQUIRE(c1 == c2);
2018 }
2019 if(0){
2020 multi::array<complex, 2> const a = {
2021 {2. + 1.*I},
2022 {3. + 2.*I}
2023 };
2024 multi::array<complex, 2> const b = {
2025 {4. + 3.*I},
2026 {5. + 4.*I}
2027 };
2028 multi::array<complex, 2> c1({1, 1}), c2({1, 1});
2029 auto ha = blas::hermitized(a).decay();
2030 blas::gemm(1., ha, b, 0., c1); // ok
2031 blas::gemm(1., blas::H(a), b, 0., c2); // was blas error
2032 print(c1);
2033 print(c2);
2034 // BOOST_REQUIRE(c1 == c2);
2035 }
2036 {
2037 multi::array<complex, 2> const a = {
2038 {1.-2.*I},
2039 {9.-1.*I}
2040 };
2041 multi::array<complex, 2> const b = {
2042 {2.+3.*I, 2. + 999.*I},
2043 {19.+11.*I, 1. + 999.*I}
2044 };
2045 multi::array<complex, 2> c1({1, 2}), c2({2, 1}, 999.);
2046 auto ha = blas::hermitized(a).decay();
2047 blas::gemm(1., ha, b, 0., c1);
2048 blas::gemm(1., blas::H(b), a, 0., c2);
2049 // print(c1);
2050 // print(c2);
2051 // std::cout << std::endl;
2052 BOOST_REQUIRE( c1 == blas::hermitized(c2) );
2053 }
2054 {
2055 // multi::array<complex, 2> c({1, 1});
2056 // using multi::blas::gemm;
2057 // using multi::blas::hermitized;
2058 // gemm(1., hermitized(b), b, 0., c);
2059 // BOOST_REQUIRE( c[0][0] == b[0][0]*conj(b[0][0]) + b[1][0]*conj(b[1][0]) );
2060 }
2061 }
2062
// Deleted on purpose: any call to `what(expr)` or `what<Types...>()` fails to compile.
// Presumably a debugging aid for reading deduced types out of the compiler
// diagnostic (the only uses in this file are commented out).
template<class T> void what(T&&) = delete;
template<class... Ts> void what() = delete;
2065
2066 BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_nonsquare_automatic2){
2067 namespace cuda = multi::cuda;
2068 namespace blas = multi::blas;
2069 multi::array<complex, 2> const a = {
2070 {1.-2.*I, 9.-1.*I},
2071 {3.+3.*I, 7.-4.*I},
2072 {1.+9.*I, 1.+8.*I}
2073 };
2074 multi::array<complex, 2> const b = {
2075 { 11.+1.*I, 12.+1.*I, 4.+1.*I, 8.-2.*I},
2076 { 7.+8.*I, 19.-2.*I, 2.+1.*I, 7.+1.*I},
2077 { 5.+1.*I, 3.-1.*I, 3.+8.*I, 1.+1.*I}
2078 };
2079 {
2080 multi::array<complex, 2> c({2, 4});
2081 blas::gemm(1., blas::H(a), b, 0., c); // c=ab, c⸆=b⸆a⸆
2082 BOOST_REQUIRE( c[1][2] == complex(112, 12) );
2083
2084 multi::array<complex, 2> const c_copy = blas::gemm(1., blas::H(a), b);
2085 multi::array<complex, 2> const c_copy2 = blas::gemm(blas::H(a), b);
2086 BOOST_REQUIRE(( c == c_copy and c == c_copy2 ));
2087 }
2088 {
2089 cuda::array<complex, 2> const acu = a;
2090 cuda::array<complex, 2> const bcu = b;
2091 cuda::array<complex, 2> ccu({2, 4}, acu.get_allocator());
2092 blas::gemm(1., blas::H(acu), bcu, 0., ccu);
2093 BOOST_REQUIRE( ccu[1][2] == complex(112, 12) );
2094
2095 // what(base(acu));
2096 // what(blas::H(acu).get_allocator());
2097 // what<decltype(blas::H(acu))::decay_type, decltype(acu)::decay_type>();
2098 // cuda::array<complex, 2> const ccu_copy = blas::gemm(1., blas::H(acu), bcu);
2099 // cuda::array<complex, 2> const ccu_copy2 = blas::gemm(blas::H(acu), bcu);
2100 // BOOST_REQUIRE(( ccu_copy == ccu and ccu_copy2 == ccu ));
2101 }
2102 #if 0
2103 {
2104 cuda::managed::array<complex, 2> const amcu = a;
2105 cuda::managed::array<complex, 2> const bmcu = b;
2106 cuda::managed::array<complex, 2> cmcu({2, 4});
2107 blas::gemm(1., blas::H(amcu), bmcu, 0., cmcu);
2108 BOOST_REQUIRE( cmcu[1][2] == complex(112, 12) );
2109
2110 // [](void*){}();
2111
2112 cuda::managed::array<complex, 2> const cmcu_copy = blas::gemm(1., blas::H(amcu), bmcu);
2113 cuda::managed::array<complex, 2> const cmcu_copy2 = blas::gemm(blas::H(amcu), bmcu);
2114 BOOST_REQUIRE(( cmcu_copy == cmcu and cmcu_copy2 == cmcu ));
2115 }
2116 #endif
2117 }
2118
2119 BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_nonsquare_automatic3){
2120 namespace cuda = multi::cuda;
2121 namespace blas = multi::blas;
2122 multi::array<complex, 2> const a = {
2123 {1.-2.*I, 9.-1.*I},
2124 {3.+3.*I, 7.-4.*I},
2125 {1.+9.*I, 1.+8.*I}
2126 };
2127 multi::array<complex, 2> const bH = {
2128 { 11.+1.*I, 12.+1.*I, 4.+1.*I, 8.-2.*I},
2129 { 7.+8.*I, 19.-2.*I, 2.+1.*I, 7.+1.*I},
2130 { 5.+1.*I, 3.-1.*I, 3.+8.*I, 1.+1.*I}
2131 };
2132 multi::array<complex, 2> const b = multi::blas::hermitized(bH);
2133 {
2134 multi::array<complex, 2> c({2, 4});
2135 blas::gemm(1., blas::H(a), blas::H(b), 0., c); // c=ab, c⸆=b⸆a⸆
2136 BOOST_REQUIRE( c[1][2] == complex(112, 12) );
2137 }
2138 {
2139 cuda::array<complex, 2> const acu = a;
2140 cuda::array<complex, 2> const bcu = b;
2141 cuda::array<complex, 2> ccu({2, 4});
2142 blas::gemm(1., blas::H(acu), blas::H(bcu), 0., ccu);
2143 BOOST_REQUIRE( ccu[1][2] == complex(112, 12) );
2144 }
2145 {
2146 cuda::managed::array<complex, 2> const amcu = a;
2147 cuda::managed::array<complex, 2> const bmcu = b;
2148 cuda::managed::array<complex, 2> cmcu({2, 4});
2149 blas::gemm(1., blas::H(amcu), blas::H(bmcu), 0., cmcu);
2150 BOOST_REQUIRE( cmcu[1][2] == complex(112, 12) );
2151 }
2152 }
2153
2154 BOOST_AUTO_TEST_CASE(multi_adaptors_blas_gemm_complex_nonsquare_automatic4){
2155 namespace blas = multi::blas;
2156
2157 multi::array<complex, 2> c({12, 12});
2158 {
2159 multi::array<complex, 2> const a({12, 100}, 1.+2.*I);
2160 multi::array<complex, 2> const b({12, 100}, 1.+2.*I);
2161 using multi::blas::hermitized;
2162 using multi::blas::gemm;
2163 blas::gemm(1., a, blas::H(b), 0., c);
2164 BOOST_REQUIRE( real(c[0][0]) > 0);
2165
2166 auto c_copy = blas::gemm(1., a, blas::H(b));
2167 BOOST_REQUIRE( c_copy == c );
2168 }
2169 {
2170 multi::array<complex, 2> const a_block({24, 100}, 1.+2.*I);
2171 multi::array<complex, 2> const b({12, 100}, 1.+2.*I);
2172 multi::array<complex, 2> c2({12, 12});
2173
2174 blas::gemm(1., a_block.strided(2), blas::H(b), 0., c2);
2175
2176 BOOST_REQUIRE( real(c[0][0]) > 0);
2177 BOOST_REQUIRE( c == c2 );
2178
2179 auto c2_copy = blas::gemm(1., a_block.strided(2), blas::H(b));
2180 BOOST_REQUIRE( c2_copy == c2 );
2181 }
2182 }
2183
2184 //BOOST_AUTO_TEST_CASE(multi_blas_gemm_complex_issue68){
2185 // namespace cuda = multi::cuda;
2186 // namespace blas = multi::blas;
2187 // cuda::managed::array<complex, 2> const a = {
2188 // {1.-2.*I},
2189 // {3.+3.*I},
2190 // {1.+9.*I}
2191 // };
2192 // {
2193 // cuda::managed::array<complex, 2> c({1, 1});
2194 // blas::gemm(1., blas::H(a), a, 0., c);
2195 // BOOST_REQUIRE( c[0][0] == 105. + 0.*I );
2196 // }
2197 //#if 0
2198 // {
2199 // auto c = blas::gemm(2., blas::H(a), a);
2200 // BOOST_REQUIRE( c[0][0] == 210. + 0.*I );
2201 // }
2202 // {
2203 // auto c = blas::gemm(blas::H(a), a);
2204
2205 // BOOST_REQUIRE( c[0][0] == 105. + 0.*I );
2206 // }
2207 //#endif
2208 //}
2209 #endif
2210
2211 #if 0
2212 BOOST_AUTO_TEST_CASE(blas_gemm_timing){
2213
2214 multi::array<complex, 2> A({1000, 2000});
2215 multi::array<complex, 2> B({2000, 3000});
2216 multi::array<complex, 2> C({size(A), size(~B)});
2217 multi::array<complex, 2> C2(extensions(C), complex{NAN, NAN});
2218 multi::array<complex, 2> C3(extensions(C), complex{NAN, NAN});
2219 multi::array<complex, 2> C4(extensions(C), complex{NAN, NAN});
2220 A[99][99] = B[11][22] = C[33][44] = 1.0;
2221 std::cerr<< "memory " << (A.num_elements()+ B.num_elements() + C.num_elements())*sizeof(complex)/1e6 <<" MB"<<std::endl;
2222
2223 {
2224 boost::timer::auto_cpu_timer t;
2225 auto rand = [d=std::uniform_real_distribution<>{0., 10.}, g=std::mt19937{}]() mutable{return complex{d(g), d(g)};};
2226 std::generate(A.elements().begin(), A.elements().end(), rand);
2227 std::generate(B.elements().begin(), B.elements().end(), rand);
2228 }
2229 namespace blas = multi::blas;
2230 {
2231 boost::timer::auto_cpu_timer t; // 0.237581s
2232 C = blas::gemm(A, B);
2233 }
2234 // {
2235 // boost::timer::auto_cpu_timer t; // 4.516157s
2236 // for(auto i : extension(~B)) (~C2)[i] = blas::gemv(A, (~B)[i]);
2237 // }
2238 // {
2239 // boost::timer::auto_cpu_timer t; // 4.516157s
2240 // for(auto i : extension(A)) C2[i] = blas::gemv(~B, A[i]);
2241 // }
2242 {
2243 boost::timer::auto_cpu_timer t; // 32.705804s
2244 for(auto i:extension(A)) for(auto j:extension(~B)) C3[i][j] = blas::dot(A[i], (~B)[j]);
2245 }
2246 using namespace blas::operators;
2247
2248 BOOST_REQUIRE( std::equal(
2249 begin(C), end(C), begin(C2), [](auto const& crow, auto const& c2row){
2250 return ((crow - c2row)^2) < 1e-13;
2251 }
2252 ) );
2253
2254 BOOST_REQUIRE( std::equal(
2255 begin(C), end(C), begin(C3), [](auto const& crow, auto const& c3row){
2256 return ((crow - c3row)^2) < 1e-13;
2257 }
2258 ) );
2259
2260 }
2261 #endif
2262
2263