1 //////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License. See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by:
8 // Miguel A. Morales, moralessilva2@llnl.gov
9 // Lawrence Livermore National Laboratory
10 // Alfredo Correa, correaa@llnl.gov
11 // Lawrence Livermore National Laboratory
12 //
13 // File created by:
14 // Miguel A. Morales, moralessilva2@llnl.gov
15 // Lawrence Livermore National Laboratory
16 ////////////////////////////////////////////////////////////////////////////////
17
18 #ifndef AFQMC_SPARSE_CPU_HPP
19 #define AFQMC_SPARSE_CPU_HPP
20
21 #if defined(HAVE_MKL)
22 #include "AFQMC/Numerics/detail/CPU/mkl_spblas.h"
23 #endif
24 #include <cassert>
25 #include <complex>
26
27 namespace ma
28 {
29 namespace backup_impl
30 {
/** Reference CSR sparse matrix - dense vector product (used when MKL is not available).
 *
 *  Computes y = beta * y + alpha * op(A) * x for a matrix A with M rows stored in CSR form.
 *
 *  @param transa    'N'/'n': op(A) = A. 'T'/'t', 'H'/'h', 'C'/'c': op(A) = A^T (this overload
 *                   applies no conjugation). Any other value leaves y untouched.
 *  @param M         number of rows of A.
 *  @param K         column bound of A; stored entries whose zero-based column is >= K are skipped.
 *  @param alpha     scalar multiplying op(A) * x.
 *  @param matdescra descriptor: [0] must be 'G'; [3] is 'C' for zero-based or 'F' for one-based
 *                   column indices (asserted).
 *  @param A         stored values.
 *  @param indx      column index of every stored value.
 *  @param pntrb     per-row begin offsets; all offsets are taken relative to pntrb[0].
 *  @param pntre     per-row end offsets (one past the last entry of the row).
 *  @param x         input vector, indexed by column for 'N' and by row otherwise.
 *  @param beta      scalar applied to y before accumulation.
 *  @param y         in/out vector; M entries are updated for 'N', K entries otherwise.
 */
template<typename T>
void csrmv(const char transa,
           const int M,
           const int K,
           const T alpha,
           const char* matdescra,
           const T* A,
           const int* indx,
           const int* pntrb,
           const int* pntre,
           const T* x,
           const T beta,
           T* y)
{
  assert(matdescra[0] == 'G' && (matdescra[3] == 'C' || matdescra[3] == 'F'));
  const int disp = (matdescra[3] == 'C') ? 0 : -1; // shift from stored to zero-based column index
  const int p0   = *pntrb;                         // offset of the first stored entry
  if (transa == 'n' || transa == 'N')
  {
    for (int r = 0; r < M; ++r)
    {
      y[r] *= beta;
      for (int i = pntrb[r] - p0; i < pntre[r] - p0; ++i)
      {
        const int col = indx[i] + disp;
        if (col >= K)
          continue;
        y[r] += alpha * A[i] * x[col];
      }
    }
  }
  else if (transa == 't' || transa == 'T' || transa == 'h' || transa == 'H' || transa == 'c' || transa == 'C')
  {
    // Transposed product: the result has K entries, all of which must be scaled up front.
    for (int k = 0; k < K; ++k)
      y[k] *= beta;
    for (int r = 0; r < M; ++r)
    {
      for (int i = pntrb[r] - p0; i < pntre[r] - p0; ++i)
      {
        // The bound check must use the column of the *current* entry; testing only the first
        // stored column either drops valid entries or lets an out-of-range column write past y.
        const int col = indx[i] + disp;
        if (col >= K)
          continue;
        y[col] += alpha * A[i] * x[r];
      }
    }
  }
}
90
/** Reference CSR sparse matrix - dense vector product for complex data (non-MKL fallback).
 *
 *  Computes y = beta * y + alpha * op(A) * x for a matrix A with M rows stored in CSR form.
 *
 *  @param transa    'N'/'n': op(A) = A. 'T'/'t': op(A) = A^T. 'H'/'h'/'C'/'c': op(A) = conj(A)^T.
 *                   Any other value leaves y untouched.
 *  @param M         number of rows of A.
 *  @param K         column bound of A; stored entries whose column is >= K are skipped.
 *  @param alpha     scalar multiplying op(A) * x.
 *  @param matdescra descriptor: [0] must be 'G' and [3] must be 'C' (zero-based indices); asserted.
 *  @param A         stored values.
 *  @param indx      column index of every stored value.
 *  @param pntrb     per-row begin offsets; all offsets are taken relative to pntrb[0].
 *  @param pntre     per-row end offsets (one past the last entry of the row).
 *  @param x         input vector, indexed by column for 'N' and by row otherwise.
 *  @param beta      scalar applied to y before accumulation.
 *  @param y         in/out vector; M entries are updated for 'N', K entries otherwise.
 */
template<typename T>
void csrmv(const char transa,
           const int M,
           const int K,
           const std::complex<T> alpha,
           const char* matdescra,
           const std::complex<T>* A,
           const int* indx,
           const int* pntrb,
           const int* pntre,
           const std::complex<T>* x,
           const std::complex<T> beta,
           std::complex<T>* y)
{
  assert(matdescra[0] == 'G' && (matdescra[3] == 'C')); // one-based ('F') indexing is not accepted here
  const int disp = (matdescra[3] == 'C') ? 0 : -1;
  const int p0   = *pntrb; // offset of the first stored entry
  if (transa == 'n' || transa == 'N')
  {
    for (int r = 0; r < M; ++r)
    {
      y[r] *= beta;
      for (int i = pntrb[r] - p0; i < pntre[r] - p0; ++i)
      {
        const int col = indx[i] + disp;
        if (col >= K)
          continue;
        y[r] += alpha * A[i] * x[col];
      }
    }
    return;
  }

  const bool transpose = (transa == 't' || transa == 'T');
  const bool adjoint   = (transa == 'h' || transa == 'H' || transa == 'c' || transa == 'C');
  if (!transpose && !adjoint)
    return;

  // Transposed / adjoint product: the result has K entries, all scaled up front.
  for (int k = 0; k < K; ++k)
    y[k] *= beta;
  for (int r = 0; r < M; ++r)
  {
    for (int i = pntrb[r] - p0; i < pntre[r] - p0; ++i)
    {
      // Check the column of the *current* entry (not just the first stored one) so that
      // valid entries are never dropped and out-of-range columns never write past y.
      const int col = indx[i] + disp;
      if (col >= K)
        continue;
      // std::conj comes from <complex>, so no externally declared helper is required.
      const std::complex<T> a = adjoint ? std::conj(A[i]) : A[i];
      y[col] += alpha * a * x[r];
    }
  }
}
150
/** Reference CSR sparse matrix - dense matrix product (non-MKL fallback).
 *
 *  Computes C = beta * C + alpha * op(A) * B with row-major dense B and C.
 *
 *  @param transa    'N'/'n': op(A) = A. 'T'/'t', 'H'/'h', 'C'/'c': op(A) = A^T (this overload
 *                   applies no conjugation). Any other value leaves C untouched.
 *  @param M         number of rows of A.
 *  @param N         number of columns of B and C that are processed.
 *  @param K         column bound of A; stored entries whose column is >= K are skipped.
 *  @param alpha     scalar multiplying op(A) * B.
 *  @param matdescra descriptor: [0] must be 'G' and [3] must be 'C' (zero-based indices); asserted.
 *  @param A         stored values.
 *  @param indx      column index of every stored value.
 *  @param pntrb     per-row begin offsets; all offsets are taken relative to pntrb[0].
 *  @param pntre     per-row end offsets (one past the last entry of the row).
 *  @param B         dense input; row c starts at B + c * ldb.
 *  @param ldb       row stride of B.
 *  @param beta      scalar applied to C before accumulation.
 *  @param C         dense in/out matrix; M rows are updated for 'N', K rows otherwise.
 *  @param ldc       row stride of C.
 */
template<typename T>
void csrmm(const char transa,
           const int M,
           const int N,
           const int K,
           const T alpha,
           const char* matdescra,
           const T* A,
           const int* indx,
           const int* pntrb,
           const int* pntre,
           const T* B,
           const int ldb,
           const T beta,
           T* C,
           const int ldc)
{
  assert(matdescra[0] == 'G' && (matdescra[3] == 'C')); // one-based ('F') indexing is not accepted here
  const int base  = pntrb[0];
  const int shift = (matdescra[3] == 'C') ? 0 : -1;

  if (transa == 'n' || transa == 'N')
  {
    // C(r,:) = beta * C(r,:) + sum_c alpha * A(r,c) * B(c,:)
    T* c_row = C;
    for (int row = 0; row < M; ++row, c_row += ldc)
    {
      for (int j = 0; j < N; ++j)
        c_row[j] *= beta;
      for (int nz = pntrb[row] - base; nz < pntre[row] - base; ++nz)
      {
        const int col = indx[nz] + shift;
        if (col >= K)
          continue;
        const T* b_row = B + ldb * col;
        T scaled       = alpha * A[nz];
        for (int j = 0; j < N; ++j)
          c_row[j] += scaled * b_row[j];
      }
    }
    return;
  }

  const bool transposed =
      (transa == 't' || transa == 'T' || transa == 'h' || transa == 'H' || transa == 'c' || transa == 'C');
  if (!transposed)
    return;

  // Simple rather than optimal: scale all K output rows first, then scatter.
  for (int i = 0; i < K; ++i)
    for (int j = 0; j < N; ++j)
      C[i * ldc + j] *= beta;

  // C(c,:) += alpha * A(r,c) * B(r,:)
  const T* b_row = B;
  for (int row = 0; row < M; ++row, b_row += ldb)
  {
    for (int nz = pntrb[row] - base; nz < pntre[row] - base; ++nz)
    {
      const int col = indx[nz] + shift;
      if (col >= K)
        continue;
      T* c_row = C + ldc * col;
      T scaled = alpha * A[nz];
      for (int j = 0; j < N; ++j)
        c_row[j] += scaled * b_row[j];
    }
  }
}
236
/** Reference CSR sparse matrix - dense matrix product for complex data (non-MKL fallback).
 *
 *  Computes C = beta * C + alpha * op(A) * B with row-major dense B and C.
 *
 *  @param transa    'N'/'n': op(A) = A. 'T'/'t': op(A) = A^T. 'H'/'h'/'C'/'c': op(A) = conj(A)^T.
 *                   Any other value leaves C untouched.
 *  @param M         number of rows of A.
 *  @param N         number of columns of B and C that are processed.
 *  @param K         column bound of A; stored entries whose column is >= K are skipped.
 *  @param alpha     scalar multiplying op(A) * B.
 *  @param matdescra descriptor: [0] must be 'G' and [3] must be 'C' (zero-based indices); asserted.
 *  @param A         stored values.
 *  @param indx      column index of every stored value.
 *  @param pntrb     per-row begin offsets; all offsets are taken relative to pntrb[0].
 *  @param pntre     per-row end offsets (one past the last entry of the row).
 *  @param B         dense input; row c starts at B + c * ldb.
 *  @param ldb       row stride of B.
 *  @param beta      scalar applied to C before accumulation.
 *  @param C         dense in/out matrix; M rows are updated for 'N', K rows otherwise.
 *  @param ldc       row stride of C.
 */
template<typename T>
void csrmm(const char transa,
           const int M,
           const int N,
           const int K,
           const std::complex<T> alpha,
           const char* matdescra,
           const std::complex<T>* A,
           const int* indx,
           const int* pntrb,
           const int* pntre,
           const std::complex<T>* B,
           const int ldb,
           const std::complex<T> beta,
           std::complex<T>* C,
           const int ldc)
{
  assert(matdescra[0] == 'G' && (matdescra[3] == 'C')); // one-based ('F') indexing is not accepted here
  const int disp = (matdescra[3] == 'C') ? 0 : -1;
  const int p0   = *pntrb; // offset of the first stored entry

  if (transa == 'n' || transa == 'N')
  {
    // C(r,:) = beta * C(r,:) + sum_c alpha * A(r,c) * B(c,:)
    std::complex<T>* Cr = C;
    for (int r = 0; r < M; ++r, Cr += ldc)
    {
      for (int j = 0; j < N; ++j)
        Cr[j] *= beta;
      for (int i = pntrb[r] - p0; i < pntre[r] - p0; ++i)
      {
        const int col = indx[i] + disp;
        if (col >= K)
          continue;
        const std::complex<T>* Bc = B + ldb * col;
        const std::complex<T> Arc = alpha * A[i];
        for (int j = 0; j < N; ++j)
          Cr[j] += Arc * Bc[j];
      }
    }
    return;
  }

  const bool transpose = (transa == 't' || transa == 'T');
  const bool adjoint   = (transa == 'h' || transa == 'H' || transa == 'c' || transa == 'C');
  if (!transpose && !adjoint)
    return;

  // Simple rather than optimal: scale all K output rows first, then scatter.
  for (int i = 0; i < K; ++i)
    for (int j = 0; j < N; ++j)
      C[i * ldc + j] *= beta;

  // C(c,:) += alpha * op(A(r,c)) * B(r,:)
  const std::complex<T>* Br = B;
  for (int r = 0; r < M; ++r, Br += ldb)
  {
    for (int i = pntrb[r] - p0; i < pntre[r] - p0; ++i)
    {
      const int col = indx[i] + disp;
      if (col >= K)
        continue;
      std::complex<T>* Cc = C + ldc * col;
      // std::conj comes from <complex>, so this header does not rely on a conj helper
      // having been declared by some earlier include.
      const std::complex<T> Arc = alpha * (adjoint ? std::conj(A[i]) : A[i]);
      for (int j = 0; j < N; ++j)
        Cc[j] += Arc * Br[j];
    }
  }
}
322
323 } // namespace backup_impl
324
csrmv(const char transa,const int M,const int K,const float alpha,const char * matdescra,const float * A,const int * indx,const int * pntrb,const int * pntre,const float * x,const float beta,float * y)325 inline static void csrmv(const char transa,
326 const int M,
327 const int K,
328 const float alpha,
329 const char* matdescra,
330 const float* A,
331 const int* indx,
332 const int* pntrb,
333 const int* pntre,
334 const float* x,
335 const float beta,
336 float* y)
337 {
338 #if defined(HAVE_MKL)
339 mkl_scsrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
340 #else
341 backup_impl::csrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
342 #endif
343 }
344
csrmv(const char transa,const int M,const int K,const double alpha,const char * matdescra,const double * A,const int * indx,const int * pntrb,const int * pntre,const double * x,const double beta,double * y)345 inline static void csrmv(const char transa,
346 const int M,
347 const int K,
348 const double alpha,
349 const char* matdescra,
350 const double* A,
351 const int* indx,
352 const int* pntrb,
353 const int* pntre,
354 const double* x,
355 const double beta,
356 double* y)
357 {
358 #if defined(HAVE_MKL)
359 mkl_dcsrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
360 #else
361 backup_impl::csrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
362 #endif
363 }
364
csrmv(const char transa,const int M,const int K,const std::complex<float> alpha,const char * matdescra,const std::complex<float> * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<float> * x,const std::complex<float> beta,std::complex<float> * y)365 inline static void csrmv(const char transa,
366 const int M,
367 const int K,
368 const std::complex<float> alpha,
369 const char* matdescra,
370 const std::complex<float>* A,
371 const int* indx,
372 const int* pntrb,
373 const int* pntre,
374 const std::complex<float>* x,
375 const std::complex<float> beta,
376 std::complex<float>* y)
377 {
378 #if defined(HAVE_MKL)
379 mkl_ccsrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
380 #else
381 backup_impl::csrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
382 #endif
383 }
384
csrmv(const char transa,const int M,const int K,const std::complex<double> alpha,const char * matdescra,const std::complex<double> * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<double> * x,const std::complex<double> beta,std::complex<double> * y)385 inline static void csrmv(const char transa,
386 const int M,
387 const int K,
388 const std::complex<double> alpha,
389 const char* matdescra,
390 const std::complex<double>* A,
391 const int* indx,
392 const int* pntrb,
393 const int* pntre,
394 const std::complex<double>* x,
395 const std::complex<double> beta,
396 std::complex<double>* y)
397 {
398 #if defined(HAVE_MKL)
399 mkl_zcsrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
400 #else
401 backup_impl::csrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
402 #endif
403 }
404
csrmm(const char transa,const int M,const int N,const int K,const float alpha,const char * matdescra,const float * A,const int * indx,const int * pntrb,const int * pntre,const float * B,const int ldb,const float beta,float * C,const int ldc)405 inline static void csrmm(const char transa,
406 const int M,
407 const int N,
408 const int K,
409 const float alpha,
410 const char* matdescra,
411 const float* A,
412 const int* indx,
413 const int* pntrb,
414 const int* pntre,
415 const float* B,
416 const int ldb,
417 const float beta,
418 float* C,
419 const int ldc)
420 {
421 #if defined(HAVE_MKL)
422 mkl_scsrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
423 #else
424 backup_impl::csrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
425 #endif
426 }
427
csrmm(const char transa,const int M,const int N,const int K,const std::complex<float> alpha,const char * matdescra,const std::complex<float> * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<float> * B,const int ldb,const std::complex<float> beta,std::complex<float> * C,const int ldc)428 inline static void csrmm(const char transa,
429 const int M,
430 const int N,
431 const int K,
432 const std::complex<float> alpha,
433 const char* matdescra,
434 const std::complex<float>* A,
435 const int* indx,
436 const int* pntrb,
437 const int* pntre,
438 const std::complex<float>* B,
439 const int ldb,
440 const std::complex<float> beta,
441 std::complex<float>* C,
442 const int ldc)
443 {
444 #if defined(HAVE_MKL)
445 mkl_ccsrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
446 #else
447 backup_impl::csrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
448 #endif
449 }
450
csrmm(const char transa,const int M,const int N,const int K,const double alpha,const char * matdescra,const double * A,const int * indx,const int * pntrb,const int * pntre,const double * B,const int ldb,const double beta,double * C,const int ldc)451 inline static void csrmm(const char transa,
452 const int M,
453 const int N,
454 const int K,
455 const double alpha,
456 const char* matdescra,
457 const double* A,
458 const int* indx,
459 const int* pntrb,
460 const int* pntre,
461 const double* B,
462 const int ldb,
463 const double beta,
464 double* C,
465 const int ldc)
466 {
467 #if defined(HAVE_MKL)
468 mkl_dcsrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
469 #else
470 backup_impl::csrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
471 #endif
472 }
473
csrmm(const char transa,const int M,const int N,const int K,const std::complex<double> alpha,const char * matdescra,const std::complex<double> * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<double> * B,const int ldb,const std::complex<double> beta,std::complex<double> * C,const int ldc)474 inline static void csrmm(const char transa,
475 const int M,
476 const int N,
477 const int K,
478 const std::complex<double> alpha,
479 const char* matdescra,
480 const std::complex<double>* A,
481 const int* indx,
482 const int* pntrb,
483 const int* pntre,
484 const std::complex<double>* B,
485 const int ldb,
486 const std::complex<double> beta,
487 std::complex<double>* C,
488 const int ldc)
489 {
490 #if defined(HAVE_MKL)
491 mkl_zcsrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
492 #else
493 backup_impl::csrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
494 #endif
495 }
496
csrmv(const char transa,const int M,const int K,float alpha,const char * matdescra,float * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<float> * x,const float beta,std::complex<float> * y)497 inline static void csrmv(const char transa,
498 const int M,
499 const int K,
500 float alpha,
501 const char* matdescra,
502 float* A,
503 const int* indx,
504 const int* pntrb,
505 const int* pntre,
506 const std::complex<float>* x,
507 const float beta,
508 std::complex<float>* y)
509 {
510 assert(matdescra[0] == 'G' && (matdescra[3] == 'C'));
511 csrmm(transa, M, 2, K, alpha, matdescra, A, indx, pntrb, pntre, reinterpret_cast<float const*>(x), 2, beta,
512 reinterpret_cast<float*>(y), 2);
513 }
514
csrmv(const char transa,const int M,const int K,double alpha,const char * matdescra,double * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<double> * x,const double beta,std::complex<double> * y)515 inline static void csrmv(const char transa,
516 const int M,
517 const int K,
518 double alpha,
519 const char* matdescra,
520 double* A,
521 const int* indx,
522 const int* pntrb,
523 const int* pntre,
524 const std::complex<double>* x,
525 const double beta,
526 std::complex<double>* y)
527 {
528 assert(matdescra[0] == 'G' && (matdescra[3] == 'C'));
529 csrmm(transa, M, 2, K, alpha, matdescra, A, indx, pntrb, pntre, reinterpret_cast<double const*>(x), 2, beta,
530 reinterpret_cast<double*>(y), 2);
531 }
532
csrmm(const char transa,const int M,const int N,const int K,const double alpha,const char * matdescra,const double * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<double> * B,const int ldb,const double beta,std::complex<double> * C,const int ldc)533 inline static void csrmm(const char transa,
534 const int M,
535 const int N,
536 const int K,
537 const double alpha,
538 const char* matdescra,
539 const double* A,
540 const int* indx,
541 const int* pntrb,
542 const int* pntre,
543 const std::complex<double>* B,
544 const int ldb,
545 const double beta,
546 std::complex<double>* C,
547 const int ldc)
548 {
549 assert(matdescra[0] == 'G' && (matdescra[3] == 'C'));
550 csrmm(transa, M, 2 * N, K, alpha, matdescra, A, indx, pntrb, pntre, reinterpret_cast<double const*>(B), 2 * ldb, beta,
551 reinterpret_cast<double*>(C), 2 * ldc);
552 }
553
csrmm(const char transa,const int M,const int N,const int K,const float alpha,const char * matdescra,const float * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<float> * B,const int ldb,const float beta,std::complex<float> * C,const int ldc)554 inline static void csrmm(const char transa,
555 const int M,
556 const int N,
557 const int K,
558 const float alpha,
559 const char* matdescra,
560 const float* A,
561 const int* indx,
562 const int* pntrb,
563 const int* pntre,
564 const std::complex<float>* B,
565 const int ldb,
566 const float beta,
567 std::complex<float>* C,
568 const int ldc)
569 {
570 assert(matdescra[0] == 'G' && (matdescra[3] == 'C'));
571 csrmm(transa, M, 2 * N, K, alpha, matdescra, A, indx, pntrb, pntre, reinterpret_cast<float const*>(B), 2 * ldb, beta,
572 reinterpret_cast<float*>(C), 2 * ldc);
573 }
574
575
576 } // namespace ma
577 #endif
578