1 //////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License.  See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by:
8 // Miguel A. Morales, moralessilva2@llnl.gov
9 //    Lawrence Livermore National Laboratory
10 // Alfredo Correa, correaa@llnl.gov
11 //    Lawrence Livermore National Laboratory
12 //
13 // File created by:
14 // Miguel A. Morales, moralessilva2@llnl.gov
15 //    Lawrence Livermore National Laboratory
16 ////////////////////////////////////////////////////////////////////////////////
17 
18 #ifndef AFQMC_SPARSE_CPU_HPP
19 #define AFQMC_SPARSE_CPU_HPP
20 
21 #if defined(HAVE_MKL)
22 #include "AFQMC/Numerics/detail/CPU/mkl_spblas.h"
23 #endif
24 #include <cassert>
25 #include <complex>
26 
27 namespace ma
28 {
29 namespace backup_impl
30 {
/**
 * Fallback CSR sparse matrix-vector product:
 *   y <- beta * y + alpha * op(A) * x
 *
 * The matrix is given by its stored values `A`, the column index of every
 * stored value `indx`, and per-row begin/end offsets `pntrb`/`pntre`.  Offsets
 * are used relative to pntrb[0], so `A` and `indx` must point at the first
 * stored entry of the first row processed here.
 *
 * @param transa    'N'/'n': op(A) = A; y holds M entries and x is indexed by column.
 *                  'T'/'t', 'H'/'h', 'C'/'c': op(A) = A^T; y holds K entries and x
 *                  holds M entries.  This overload never conjugates.
 *                  Any other value leaves y untouched.
 * @param M         number of rows traversed.
 * @param K         column bound; stored entries whose (shifted) column index is
 *                  >= K are skipped.
 * @param alpha     scale applied to op(A) * x.
 * @param matdescra descriptor; [0] must be 'G', [3] is 'C' (indices used as-is)
 *                  or 'F' (indices shifted by -1).
 * @param A         stored values.
 * @param indx      column index of each stored value.
 * @param pntrb     per-row offset of the first stored entry.
 * @param pntre     per-row offset one past the last stored entry.
 * @param x         input vector.
 * @param beta      scale applied to the incoming contents of y.
 * @param y         input/output vector.
 */
template<typename T>
void csrmv(const char transa,
           const int M,
           const int K,
           const T alpha,
           const char* matdescra,
           const T* A,
           const int* indx,
           const int* pntrb,
           const int* pntre,
           const T* x,
           const T beta,
           T* y)
{
  assert(matdescra[0] == 'G' && (matdescra[3] == 'C' || matdescra[3] == 'F'));
  const int disp = (matdescra[3] == 'C') ? 0 : -1;
  const int p0   = *pntrb;
  if (transa == 'n' || transa == 'N')
  {
    for (int r = 0; r < M; r++)
    {
      y[r] *= beta;
      for (int i = pntrb[r] - p0; i < pntre[r] - p0; i++)
      {
        const int c = indx[i] + disp;
        if (c >= K)
          continue;
        y[r] += alpha * A[i] * x[c];
      }
    }
  }
  else if (transa == 't' || transa == 'T' || transa == 'h' || transa == 'H' || transa == 'c' || transa == 'C')
  {
    // Transpose and "conjugate" transpose coincide here because no conjugation
    // is applied to the stored values in this overload.
    for (int k = 0; k < K; k++)
      y[k] *= beta;
    for (int r = 0; r < M; r++)
    {
      for (int i = pntrb[r] - p0; i < pntre[r] - p0; i++)
      {
        // The bound must be checked for the entry being scattered, not for the
        // first stored entry of the matrix.
        const int c = indx[i] + disp;
        if (c >= K)
          continue;
        y[c] += alpha * A[i] * x[r];
      }
    }
  }
}
90 
/**
 * Fallback CSR sparse matrix-vector product for complex data:
 *   y <- beta * y + alpha * op(A) * x
 *
 * The matrix is given by its stored values `A`, the column index of every
 * stored value `indx`, and per-row begin/end offsets `pntrb`/`pntre`.  Offsets
 * are used relative to pntrb[0], so `A` and `indx` must point at the first
 * stored entry of the first row processed here.
 *
 * @param transa    'N'/'n': op(A) = A; y holds M entries and x is indexed by column.
 *                  'T'/'t': op(A) = A^T; y holds K entries, x holds M entries.
 *                  'H'/'h'/'C'/'c': as 'T' but each stored value is conjugated.
 *                  Any other value leaves y untouched.
 * @param M         number of rows traversed.
 * @param K         column bound; stored entries whose (shifted) column index is
 *                  >= K are skipped.
 * @param alpha     scale applied to op(A) * x.
 * @param matdescra descriptor; the assertion requires [0] == 'G' and [3] == 'C'.
 * @param A         stored values.
 * @param indx      column index of each stored value.
 * @param pntrb     per-row offset of the first stored entry.
 * @param pntre     per-row offset one past the last stored entry.
 * @param x         input vector.
 * @param beta      scale applied to the incoming contents of y.
 * @param y         input/output vector.
 */
template<typename T>
void csrmv(const char transa,
           const int M,
           const int K,
           const std::complex<T> alpha,
           const char* matdescra,
           const std::complex<T>* A,
           const int* indx,
           const int* pntrb,
           const int* pntre,
           const std::complex<T>* x,
           const std::complex<T> beta,
           std::complex<T>* y)
{
  assert(matdescra[0] == 'G' && (matdescra[3] == 'C')); // || matdescra[3]=='F'));
  const int disp = (matdescra[3] == 'C') ? 0 : -1;
  const int p0   = *pntrb;

  const bool is_trans = (transa == 't' || transa == 'T');
  const bool is_herm  = (transa == 'h' || transa == 'H' || transa == 'c' || transa == 'C');

  if (transa == 'n' || transa == 'N')
  {
    for (int r = 0; r < M; r++)
    {
      y[r] *= beta;
      for (int i = pntrb[r] - p0; i < pntre[r] - p0; i++)
      {
        const int c = indx[i] + disp;
        if (c >= K)
          continue;
        y[r] += alpha * A[i] * x[c];
      }
    }
  }
  else if (is_trans || is_herm)
  {
    for (int k = 0; k < K; k++)
      y[k] *= beta;
    for (int r = 0; r < M; r++)
    {
      for (int i = pntrb[r] - p0; i < pntre[r] - p0; i++)
      {
        // The bound must be checked for the entry being scattered, not for the
        // first stored entry of the matrix.
        const int c = indx[i] + disp;
        if (c >= K)
          continue;
        // std::conj comes from <complex>, which this header includes itself.
        const std::complex<T> a = is_herm ? std::conj(A[i]) : A[i];
        y[c] += alpha * a * x[r];
      }
    }
  }
}
150 
/**
 * Fallback CSR sparse-times-dense product with row-major dense operands:
 *   C <- beta * C + alpha * op(A) * B
 *
 * Offsets in `pntrb`/`pntre` are used relative to pntrb[0]; `ldb` and `ldc`
 * are the distances between consecutive rows of B and C.
 *
 * @param transa    'N'/'n': op(A) = A, so M rows of C are updated from rows of
 *                  B selected by column index.
 *                  'T'/'t', 'H'/'h', 'C'/'c': op(A) = A^T, so K rows of C are
 *                  updated from the M rows of B.  No conjugation is applied in
 *                  this overload.  Any other value leaves C untouched.
 * @param M         number of sparse rows traversed.
 * @param N         number of columns of B and C that are touched.
 * @param K         column bound; stored entries whose (shifted) column index is
 *                  >= K are skipped.
 * @param matdescra descriptor; the assertion requires [0] == 'G' and [3] == 'C'.
 */
template<typename T>
void csrmm(const char transa,
           const int M,
           const int N,
           const int K,
           const T alpha,
           const char* matdescra,
           const T* A,
           const int* indx,
           const int* pntrb,
           const int* pntre,
           const T* B,
           const int ldb,
           const T beta,
           T* C,
           const int ldc)
{
  assert(matdescra[0] == 'G' && (matdescra[3] == 'C')); // || matdescra[3]=='F'));
  const int first_offset = *pntrb;
  const int index_shift  = (matdescra[3] == 'C') ? 0 : -1;

  const bool plain = (transa == 'n' || transa == 'N');
  const bool flipped =
      (transa == 't' || transa == 'T' || transa == 'h' || transa == 'H' || transa == 'c' || transa == 'C');

  if (plain)
  {
    T* out_row = C;
    for (int row = 0; row < M; ++row, out_row += ldc)
    {
      for (int j = 0; j < N; ++j)
        out_row[j] *= beta;
      for (int nz = pntrb[row] - first_offset; nz < pntre[row] - first_offset; ++nz)
      {
        const int col = indx[nz] + index_shift;
        if (col >= K)
          continue;
        // C(row,:) += (alpha * A(row,col)) * B(col,:)
        const T* in_row = B + ldb * col;
        const T scaled  = alpha * A[nz];
        for (int j = 0; j < N; ++j)
          out_row[j] += scaled * in_row[j];
      }
    }
  }
  else if (flipped)
  {
    // Scale the whole K x N output first; rows are then scattered into.
    for (int i = 0; i < K; ++i)
      for (int j = 0; j < N; ++j)
        C[i * ldc + j] *= beta;
    const T* in_row = B;
    for (int row = 0; row < M; ++row, in_row += ldb)
    {
      for (int nz = pntrb[row] - first_offset; nz < pntre[row] - first_offset; ++nz)
      {
        const int col = indx[nz] + index_shift;
        if (col >= K)
          continue;
        // C(col,:) += (alpha * A(row,col)) * B(row,:)
        T* out_row     = C + ldc * col;
        const T scaled = alpha * A[nz];
        for (int j = 0; j < N; ++j)
          out_row[j] += scaled * in_row[j];
      }
    }
  }
}
236 
/**
 * Fallback CSR sparse-times-dense product for complex data, with row-major
 * dense operands:
 *   C <- beta * C + alpha * op(A) * B
 *
 * Offsets in `pntrb`/`pntre` are used relative to pntrb[0]; `ldb` and `ldc`
 * are the distances between consecutive rows of B and C.
 *
 * @param transa    'N'/'n': op(A) = A, so M rows of C are updated from rows of
 *                  B selected by column index.
 *                  'T'/'t': op(A) = A^T, so K rows of C are updated from the M
 *                  rows of B.
 *                  'H'/'h'/'C'/'c': as 'T' but each stored value is conjugated.
 *                  Any other value leaves C untouched.
 * @param M         number of sparse rows traversed.
 * @param N         number of columns of B and C that are touched.
 * @param K         column bound; stored entries whose (shifted) column index is
 *                  >= K are skipped.
 * @param matdescra descriptor; the assertion requires [0] == 'G' and [3] == 'C'.
 */
template<typename T>
void csrmm(const char transa,
           const int M,
           const int N,
           const int K,
           const std::complex<T> alpha,
           const char* matdescra,
           const std::complex<T>* A,
           const int* indx,
           const int* pntrb,
           const int* pntre,
           const std::complex<T>* B,
           const int ldb,
           const std::complex<T> beta,
           std::complex<T>* C,
           const int ldc)
{
  assert(matdescra[0] == 'G' && (matdescra[3] == 'C')); // || matdescra[3]=='F'));
  const int disp = (matdescra[3] == 'C') ? 0 : -1;
  const int p0   = *pntrb;

  const bool is_trans = (transa == 't' || transa == 'T');
  const bool is_herm  = (transa == 'h' || transa == 'H' || transa == 'c' || transa == 'C');

  if (transa == 'n' || transa == 'N')
  {
    for (int nr = 0; nr < M; nr++, pntrb++, pntre++, C += ldc)
    {
      for (int i = 0; i < N; i++)
        C[i] *= beta;
      for (int i = *pntrb - p0; i < *pntre - p0; i++)
      {
        const int c = indx[i] + disp;
        if (c >= K)
          continue;
        // C(r,:) += (alpha * A_rc) * B(c,:)
        const std::complex<T>* Bc = B + ldb * c;
        const std::complex<T> Arc = alpha * A[i];
        for (int k = 0; k < N; k++)
          C[k] += Arc * Bc[k];
      }
    }
  }
  else if (is_trans || is_herm)
  {
    // not optimal, but simple: scale all of C, then scatter row by row
    for (int i = 0; i < K; i++)
      for (int j = 0; j < N; j++)
        C[i * ldc + j] *= beta;
    for (int nr = 0; nr < M; nr++, pntrb++, pntre++, B += ldb)
    {
      for (int i = *pntrb - p0; i < *pntre - p0; i++)
      {
        const int c = indx[i] + disp;
        if (c >= K)
          continue;
        // C(c,:) += (alpha * op(A_rc)) * B(r,:)
        // std::conj comes from <complex>, which this header includes itself.
        std::complex<T>* Cc       = C + ldc * c;
        const std::complex<T> Arc = alpha * (is_herm ? std::conj(A[i]) : A[i]);
        for (int k = 0; k < N; k++)
          Cc[k] += Arc * B[k];
      }
    }
  }
}
322 
323 } // namespace backup_impl
324 
csrmv(const char transa,const int M,const int K,const float alpha,const char * matdescra,const float * A,const int * indx,const int * pntrb,const int * pntre,const float * x,const float beta,float * y)325 inline static void csrmv(const char transa,
326                          const int M,
327                          const int K,
328                          const float alpha,
329                          const char* matdescra,
330                          const float* A,
331                          const int* indx,
332                          const int* pntrb,
333                          const int* pntre,
334                          const float* x,
335                          const float beta,
336                          float* y)
337 {
338 #if defined(HAVE_MKL)
339   mkl_scsrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
340 #else
341   backup_impl::csrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
342 #endif
343 }
344 
csrmv(const char transa,const int M,const int K,const double alpha,const char * matdescra,const double * A,const int * indx,const int * pntrb,const int * pntre,const double * x,const double beta,double * y)345 inline static void csrmv(const char transa,
346                          const int M,
347                          const int K,
348                          const double alpha,
349                          const char* matdescra,
350                          const double* A,
351                          const int* indx,
352                          const int* pntrb,
353                          const int* pntre,
354                          const double* x,
355                          const double beta,
356                          double* y)
357 {
358 #if defined(HAVE_MKL)
359   mkl_dcsrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
360 #else
361   backup_impl::csrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
362 #endif
363 }
364 
csrmv(const char transa,const int M,const int K,const std::complex<float> alpha,const char * matdescra,const std::complex<float> * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<float> * x,const std::complex<float> beta,std::complex<float> * y)365 inline static void csrmv(const char transa,
366                          const int M,
367                          const int K,
368                          const std::complex<float> alpha,
369                          const char* matdescra,
370                          const std::complex<float>* A,
371                          const int* indx,
372                          const int* pntrb,
373                          const int* pntre,
374                          const std::complex<float>* x,
375                          const std::complex<float> beta,
376                          std::complex<float>* y)
377 {
378 #if defined(HAVE_MKL)
379   mkl_ccsrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
380 #else
381   backup_impl::csrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
382 #endif
383 }
384 
csrmv(const char transa,const int M,const int K,const std::complex<double> alpha,const char * matdescra,const std::complex<double> * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<double> * x,const std::complex<double> beta,std::complex<double> * y)385 inline static void csrmv(const char transa,
386                          const int M,
387                          const int K,
388                          const std::complex<double> alpha,
389                          const char* matdescra,
390                          const std::complex<double>* A,
391                          const int* indx,
392                          const int* pntrb,
393                          const int* pntre,
394                          const std::complex<double>* x,
395                          const std::complex<double> beta,
396                          std::complex<double>* y)
397 {
398 #if defined(HAVE_MKL)
399   mkl_zcsrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
400 #else
401   backup_impl::csrmv(transa, M, K, alpha, matdescra, A, indx, pntrb, pntre, x, beta, y);
402 #endif
403 }
404 
csrmm(const char transa,const int M,const int N,const int K,const float alpha,const char * matdescra,const float * A,const int * indx,const int * pntrb,const int * pntre,const float * B,const int ldb,const float beta,float * C,const int ldc)405 inline static void csrmm(const char transa,
406                          const int M,
407                          const int N,
408                          const int K,
409                          const float alpha,
410                          const char* matdescra,
411                          const float* A,
412                          const int* indx,
413                          const int* pntrb,
414                          const int* pntre,
415                          const float* B,
416                          const int ldb,
417                          const float beta,
418                          float* C,
419                          const int ldc)
420 {
421 #if defined(HAVE_MKL)
422   mkl_scsrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
423 #else
424   backup_impl::csrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
425 #endif
426 }
427 
csrmm(const char transa,const int M,const int N,const int K,const std::complex<float> alpha,const char * matdescra,const std::complex<float> * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<float> * B,const int ldb,const std::complex<float> beta,std::complex<float> * C,const int ldc)428 inline static void csrmm(const char transa,
429                          const int M,
430                          const int N,
431                          const int K,
432                          const std::complex<float> alpha,
433                          const char* matdescra,
434                          const std::complex<float>* A,
435                          const int* indx,
436                          const int* pntrb,
437                          const int* pntre,
438                          const std::complex<float>* B,
439                          const int ldb,
440                          const std::complex<float> beta,
441                          std::complex<float>* C,
442                          const int ldc)
443 {
444 #if defined(HAVE_MKL)
445   mkl_ccsrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
446 #else
447   backup_impl::csrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
448 #endif
449 }
450 
csrmm(const char transa,const int M,const int N,const int K,const double alpha,const char * matdescra,const double * A,const int * indx,const int * pntrb,const int * pntre,const double * B,const int ldb,const double beta,double * C,const int ldc)451 inline static void csrmm(const char transa,
452                          const int M,
453                          const int N,
454                          const int K,
455                          const double alpha,
456                          const char* matdescra,
457                          const double* A,
458                          const int* indx,
459                          const int* pntrb,
460                          const int* pntre,
461                          const double* B,
462                          const int ldb,
463                          const double beta,
464                          double* C,
465                          const int ldc)
466 {
467 #if defined(HAVE_MKL)
468   mkl_dcsrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
469 #else
470   backup_impl::csrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
471 #endif
472 }
473 
csrmm(const char transa,const int M,const int N,const int K,const std::complex<double> alpha,const char * matdescra,const std::complex<double> * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<double> * B,const int ldb,const std::complex<double> beta,std::complex<double> * C,const int ldc)474 inline static void csrmm(const char transa,
475                          const int M,
476                          const int N,
477                          const int K,
478                          const std::complex<double> alpha,
479                          const char* matdescra,
480                          const std::complex<double>* A,
481                          const int* indx,
482                          const int* pntrb,
483                          const int* pntre,
484                          const std::complex<double>* B,
485                          const int ldb,
486                          const std::complex<double> beta,
487                          std::complex<double>* C,
488                          const int ldc)
489 {
490 #if defined(HAVE_MKL)
491   mkl_zcsrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
492 #else
493   backup_impl::csrmm(transa, M, N, K, alpha, matdescra, A, indx, pntrb, pntre, B, ldb, beta, C, ldc);
494 #endif
495 }
496 
csrmv(const char transa,const int M,const int K,float alpha,const char * matdescra,float * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<float> * x,const float beta,std::complex<float> * y)497 inline static void csrmv(const char transa,
498                          const int M,
499                          const int K,
500                          float alpha,
501                          const char* matdescra,
502                          float* A,
503                          const int* indx,
504                          const int* pntrb,
505                          const int* pntre,
506                          const std::complex<float>* x,
507                          const float beta,
508                          std::complex<float>* y)
509 {
510   assert(matdescra[0] == 'G' && (matdescra[3] == 'C'));
511   csrmm(transa, M, 2, K, alpha, matdescra, A, indx, pntrb, pntre, reinterpret_cast<float const*>(x), 2, beta,
512         reinterpret_cast<float*>(y), 2);
513 }
514 
csrmv(const char transa,const int M,const int K,double alpha,const char * matdescra,double * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<double> * x,const double beta,std::complex<double> * y)515 inline static void csrmv(const char transa,
516                          const int M,
517                          const int K,
518                          double alpha,
519                          const char* matdescra,
520                          double* A,
521                          const int* indx,
522                          const int* pntrb,
523                          const int* pntre,
524                          const std::complex<double>* x,
525                          const double beta,
526                          std::complex<double>* y)
527 {
528   assert(matdescra[0] == 'G' && (matdescra[3] == 'C'));
529   csrmm(transa, M, 2, K, alpha, matdescra, A, indx, pntrb, pntre, reinterpret_cast<double const*>(x), 2, beta,
530         reinterpret_cast<double*>(y), 2);
531 }
532 
csrmm(const char transa,const int M,const int N,const int K,const double alpha,const char * matdescra,const double * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<double> * B,const int ldb,const double beta,std::complex<double> * C,const int ldc)533 inline static void csrmm(const char transa,
534                          const int M,
535                          const int N,
536                          const int K,
537                          const double alpha,
538                          const char* matdescra,
539                          const double* A,
540                          const int* indx,
541                          const int* pntrb,
542                          const int* pntre,
543                          const std::complex<double>* B,
544                          const int ldb,
545                          const double beta,
546                          std::complex<double>* C,
547                          const int ldc)
548 {
549   assert(matdescra[0] == 'G' && (matdescra[3] == 'C'));
550   csrmm(transa, M, 2 * N, K, alpha, matdescra, A, indx, pntrb, pntre, reinterpret_cast<double const*>(B), 2 * ldb, beta,
551         reinterpret_cast<double*>(C), 2 * ldc);
552 }
553 
csrmm(const char transa,const int M,const int N,const int K,const float alpha,const char * matdescra,const float * A,const int * indx,const int * pntrb,const int * pntre,const std::complex<float> * B,const int ldb,const float beta,std::complex<float> * C,const int ldc)554 inline static void csrmm(const char transa,
555                          const int M,
556                          const int N,
557                          const int K,
558                          const float alpha,
559                          const char* matdescra,
560                          const float* A,
561                          const int* indx,
562                          const int* pntrb,
563                          const int* pntre,
564                          const std::complex<float>* B,
565                          const int ldb,
566                          const float beta,
567                          std::complex<float>* C,
568                          const int ldc)
569 {
570   assert(matdescra[0] == 'G' && (matdescra[3] == 'C'));
571   csrmm(transa, M, 2 * N, K, alpha, matdescra, A, indx, pntrb, pntre, reinterpret_cast<float const*>(B), 2 * ldb, beta,
572         reinterpret_cast<float*>(C), 2 * ldc);
573 }
574 
575 
576 } // namespace ma
577 #endif
578