1 //////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License.  See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by:
8 //    Lawrence Livermore National Laboratory
9 //
10 // File created by:
11 // Miguel A. Morales, moralessilva2@llnl.gov
12 //    Lawrence Livermore National Laboratory
13 ////////////////////////////////////////////////////////////////////////////////
14 
15 #ifndef AFQMC_BLAS_CUDA_HPP
16 #define AFQMC_BLAS_CUDA_HPP
17 
18 #include <cassert>
19 #include <vector>
20 #include "AFQMC/Memory/CUDA/cuda_gpu_pointer.hpp"
21 #include "AFQMC/Numerics/detail/CUDA/cublas_wrapper.hpp"
22 #include "AFQMC/Numerics/detail/CUDA/cublasXt_wrapper.hpp"
23 // hand coded kernels for blas extensions
24 #include "AFQMC/Numerics/detail/CUDA/Kernels/adotpby.cuh"
25 #include "AFQMC/Numerics/detail/CUDA/Kernels/axty.cuh"
26 #include "AFQMC/Numerics/detail/CUDA/Kernels/sum.cuh"
27 #include "AFQMC/Numerics/detail/CUDA/Kernels/adiagApy.cuh"
28 #include "AFQMC/Numerics/detail/CUDA/Kernels/acAxpbB.cuh"
29 
// Currently available:
// Lvl-1: copy, scal, dot, axpy
// Lvl-2: gemv
// Lvl-3: gemm, gemmStridedBatched, gemmBatched
// Extensions: geam, set1D, adotpby, axty, acAxpbB, adiagApy, sum
34 
35 namespace qmc_cuda
36 {
37 // copy Specializations
38 template<class ptr, typename = typename std::enable_if_t<(ptr::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
copy(int n,ptr x,int incx,ptr y,int incy)39 inline static void copy(int n, ptr x, int incx, ptr y, int incy)
40 {
41   if (CUBLAS_STATUS_SUCCESS !=
42       cublas::cublas_copy(*x.handles.cublas_handle, n, to_address(x), incx, to_address(y), incy))
43     throw std::runtime_error("Error: cublas_copy returned error code.");
44 }
45 
46 template<class ptr,
47          typename = typename std::enable_if_t<(ptr::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
48          typename = void>
copy(int n,ptr x,int incx,ptr y,int incy)49 inline static void copy(int n, ptr x, int incx, ptr y, int incy)
50 {
51   using ma::copy;
52   return copy(n, to_address(x), incx, to_address(y), incy);
53 }
54 
55 // scal Specializations
56 template<class T, class ptr, typename = typename std::enable_if_t<(ptr::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
scal(int n,T alpha,ptr x,int incx)57 inline static void scal(int n, T alpha, ptr x, int incx)
58 {
59   if (CUBLAS_STATUS_SUCCESS != cublas::cublas_scal(*x.handles.cublas_handle, n, alpha, to_address(x), incx))
60     throw std::runtime_error("Error: cublas_scal returned error code.");
61 }
62 
63 template<class T,
64          class ptr,
65          typename = typename std::enable_if_t<(ptr::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
66          typename = void>
scal(int n,T alpha,ptr x,int incx)67 inline static void scal(int n, T alpha, ptr x, int incx)
68 {
69   using ma::scal;
70   return scal(n, alpha, to_address(x), incx);
71 }
72 
73 // dot Specializations
74 template<class ptrA,
75          class ptrB,
76          typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
77                                               (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
dot(int const n,ptrA const & x,int const incx,ptrB const & y,int const incy)78 inline static auto dot(int const n, ptrA const& x, int const incx, ptrB const& y, int const incy)
79 {
80   return cublas::cublas_dot(*x.handles.cublas_handle, n, to_address(x), incx, to_address(y), incy);
81 }
82 
83 template<class ptrA,
84          class ptrB,
85          typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
86                                               (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
87          typename = void>
dot(int const n,ptrA const & x,int const incx,ptrB const & y,int const incy)88 inline static auto dot(int const n, ptrA const& x, int const incx, ptrB const& y, int const incy)
89 {
90   using ma::dot;
91   return dot(n, to_address(x), incx, to_address(y), incy);
92 }
93 
94 // axpy Specializations
95 template<typename T,
96          class ptrA,
97          class ptrB,
98          typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
99                                               (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
axpy(int n,T const a,ptrA const & x,int incx,ptrB && y,int incy)100 inline static void axpy(int n, T const a, ptrA const& x, int incx, ptrB&& y, int incy)
101 {
102   if (CUBLAS_STATUS_SUCCESS !=
103       cublas::cublas_axpy(*x.handles.cublas_handle, n, a, to_address(x), incx, to_address(y), incy))
104     throw std::runtime_error("Error: cublas_axpy returned error code.");
105 }
106 
107 template<typename T,
108          class ptrA,
109          class ptrB,
110          typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
111                                               (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
112          typename = void>
axpy(int n,T const a,ptrA const & x,int incx,ptrB && y,int incy)113 inline static void axpy(int n, T const a, ptrA const& x, int incx, ptrB&& y, int incy)
114 {
115   using ma::axpy;
116   axpy(n, a, to_address(x), incx, to_address(y), incy);
117 }
118 
119 // GEMV Specializations
120 template<typename T,
121          class ptrA,
122          class ptrB,
123          class ptrC,
124          typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
125                                               (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
126                                               (ptrC::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
gemv(char Atrans,int M,int N,T alpha,ptrA const & A,int lda,ptrB const & x,int incx,T beta,ptrC && y,int incy)127 inline static void gemv(char Atrans,
128                         int M,
129                         int N,
130                         T alpha,
131                         ptrA const& A,
132                         int lda,
133                         ptrB const& x,
134                         int incx,
135                         T beta,
136                         ptrC&& y,
137                         int incy)
138 {
139   if (CUBLAS_STATUS_SUCCESS !=
140       cublas::cublas_gemv(*A.handles.cublas_handle, Atrans, M, N, alpha, to_address(A), lda, to_address(x), incx, beta,
141                           to_address(y), incy))
142     throw std::runtime_error("Error: cublas_gemv returned error code.");
143 }
144 
145 template<typename T,
146          class ptrA,
147          class ptrB,
148          class ptrC,
149          typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
150                                               (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
151                                               (ptrC::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
152          typename = void>
gemv(char Atrans,int M,int N,T alpha,ptrA const & A,int lda,ptrB const & x,int incx,T beta,ptrC && y,int incy)153 inline static void gemv(char Atrans,
154                         int M,
155                         int N,
156                         T alpha,
157                         ptrA const& A,
158                         int lda,
159                         ptrB const& x,
160                         int incx,
161                         T beta,
162                         ptrC&& y,
163                         int incy)
164 {
165   using ma::gemv;
166   gemv(Atrans, M, N, alpha, to_address(A), lda, to_address(x), incx, beta, to_address(y), incy);
167   /*
168     const char Btrans('N');
169     const int one(1);
170     if(CUBLAS_STATUS_SUCCESS != cublas::cublasXt_gemm(*A.handles.cublasXt_handle,Atrans,Btrans,
171                                             M,one,K,alpha,to_address(A),lda,to_address(x),incx,
172                                             beta,to_address(y),incy))
173       throw std::runtime_error("Error: cublasXt_gemv (gemm) returned error code.");
174 */
175 }
176 
177 // GEMM Specializations
178 template<typename T,
179          class ptrA,
180          class ptrB,
181          class ptrC,
182          typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
183                                               (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
184                                               (ptrC::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
gemm(char Atrans,char Btrans,int M,int N,int K,T alpha,ptrA const & A,int lda,ptrB const & B,int ldb,T beta,ptrC && C,int ldc)185 inline static void gemm(char Atrans,
186                         char Btrans,
187                         int M,
188                         int N,
189                         int K,
190                         T alpha,
191                         ptrA const& A,
192                         int lda,
193                         ptrB const& B,
194                         int ldb,
195                         T beta,
196                         ptrC&& C,
197                         int ldc)
198 {
199   if (CUBLAS_STATUS_SUCCESS !=
200       cublas::cublas_gemm(*A.handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, to_address(A), lda, to_address(B),
201                           ldb, beta, to_address(C), ldc))
202     throw std::runtime_error("Error: cublas_gemm returned error code.");
203 }
204 
205 template<typename T,
206          class ptrA,
207          class ptrB,
208          class ptrC,
209          typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
210                                               (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
211                                               (ptrC::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
212          typename = void>
gemm(char Atrans,char Btrans,int M,int N,int K,T alpha,ptrA const & A,int lda,ptrB const & B,int ldb,T beta,ptrC && C,int ldc)213 inline static void gemm(char Atrans,
214                         char Btrans,
215                         int M,
216                         int N,
217                         int K,
218                         T alpha,
219                         ptrA const& A,
220                         int lda,
221                         ptrB const& B,
222                         int ldb,
223                         T beta,
224                         ptrC&& C,
225                         int ldc)
226 {
227   if (CUBLAS_STATUS_SUCCESS !=
228       cublas::cublasXt_gemm(*A.handles.cublasXt_handle, Atrans, Btrans, M, N, K, alpha, to_address(A), lda,
229                             to_address(B), ldb, beta, to_address(C), ldc))
230     throw std::runtime_error("Error: cublasXt_gemm returned error code.");
231 }
232 
233 // Blas Extensions
234 // geam
235 template<class T,
236          class ptrA,
237          class ptrB,
238          class ptrC,
239          typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
240                                               (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
241                                               (ptrC::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
geam(char Atrans,char Btrans,int M,int N,T const alpha,ptrA const & A,int lda,T const beta,ptrB const & B,int ldb,ptrC C,int ldc)242 inline static void geam(char Atrans,
243                         char Btrans,
244                         int M,
245                         int N,
246                         T const alpha,
247                         ptrA const& A,
248                         int lda,
249                         T const beta,
250                         ptrB const& B,
251                         int ldb,
252                         ptrC C,
253                         int ldc)
254 {
255   if (CUBLAS_STATUS_SUCCESS !=
256       cublas::cublas_geam(*A.handles.cublas_handle, Atrans, Btrans, M, N, alpha, to_address(A), lda, beta,
257                           to_address(B), ldb, to_address(C), ldc))
258     throw std::runtime_error("Error: cublas_geam returned error code.");
259 }
260 
261 template<class T,
262          class ptrA,
263          class ptrB,
264          class ptrC,
265          typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
266                                               (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
267                                               (ptrC::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
268          typename = void>
geam(char Atrans,char Btrans,int M,int N,T const alpha,ptrA const & A,int lda,T const beta,ptrB const & B,int ldb,ptrC C,int ldc)269 inline static void geam(char Atrans,
270                         char Btrans,
271                         int M,
272                         int N,
273                         T const alpha,
274                         ptrA const& A,
275                         int lda,
276                         T const beta,
277                         ptrB const& B,
278                         int ldb,
279                         ptrC C,
280                         int ldc)
281 {
282   using ma::geam;
283   return geam(Atrans, Btrans, M, N, alpha, to_address(A), lda, beta, to_address(B), ldb, to_address(C), ldc);
284 }
285 
286 //template<class T,
287 template<class ptr,
288          typename = typename std::enable_if_t<not(ptr::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
289          typename = void>
290 //inline static void set1D(int n, T const alpha, ptr x, int incx)
set1D(int n,typename ptr::value_type const alpha,ptr x,int incx)291 inline static void set1D(int n, typename ptr::value_type const alpha, ptr x, int incx)
292 {
293   // No set funcion in cuda!!! Avoiding kernels for now
294   std::vector<typename ptr::value_type> buff(n, alpha);
295   if (CUBLAS_STATUS_SUCCESS !=
296       cublasSetVector(n, sizeof(typename ptr::value_type), buff.data(), 1, to_address(x), incx))
297     throw std::runtime_error("Error: cublasSetVector returned error code.");
298 }
299 
300 template<class T, class ptr, typename = typename std::enable_if_t<(ptr::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>>
set1D(int n,T const alpha,ptr x,int incx)301 inline static void set1D(int n, T const alpha, ptr x, int incx)
302 {
303   auto y = to_address(x);
304   for (int i = 0; i < n; i++, y += incx)
305     *y = alpha;
306 }
307 
308 // dot extension
309 template<class T,
310          class Q,
311          class ptrA,
312          class ptrB,
313          class ptrC,
314          typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
315                                               (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
316                                               (ptrC::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
adotpby(int const n,T const alpha,ptrA const & x,int const incx,ptrB const & y,int const incy,Q const beta,ptrC result)317 inline static void adotpby(int const n,
318                            T const alpha,
319                            ptrA const& x,
320                            int const incx,
321                            ptrB const& y,
322                            int const incy,
323                            Q const beta,
324                            ptrC result)
325 {
326   kernels::adotpby(n, alpha, to_address(x), incx, to_address(y), incy, beta, to_address(result));
327 }
328 
329 template<class T,
330          class Q,
331          class ptrA,
332          class ptrB,
333          class ptrC,
334          typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
335                                               (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
336                                               (ptrC::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
337          typename = void>
adotpby(int const n,T const alpha,ptrA const & x,int const incx,ptrB const & y,int const incy,Q const beta,ptrC result)338 inline static void adotpby(int const n,
339                            T const alpha,
340                            ptrA const& x,
341                            int const incx,
342                            ptrB const& y,
343                            int const incy,
344                            Q const beta,
345                            ptrC result)
346 {
347   using ma::adotpby;
348   adotpby(n, alpha, to_address(x), incx, to_address(y), incy, beta, to_address(result));
349 }
350 
351 
352 // axty
353 template<class T,
354          class ptrA,
355          class ptrB,
356          typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
357                                               (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
axty(int n,T const alpha,ptrA const x,int incx,ptrB y,int incy)358 inline static void axty(int n, T const alpha, ptrA const x, int incx, ptrB y, int incy)
359 {
360   if (incx != 1 || incy != 1)
361     throw std::runtime_error("Error: axty with inc != 1 not implemented.");
362   kernels::axty(n, alpha, to_address(x), to_address(y));
363 }
364 
365 template<class T,
366          class ptrA,
367          class ptrB,
368          typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
369                                               (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
370          typename = void>
axty(int n,T const alpha,ptrA const x,int incx,ptrB y,int incy)371 inline static void axty(int n, T const alpha, ptrA const x, int incx, ptrB y, int incy)
372 {
373   using ma::axty;
374   axty(n, alpha, to_address(x), incx, to_address(y), incy);
375 }
376 
377 // acAxpbB
378 template<class T,
379          class ptrA,
380          class ptrx,
381          class ptrB,
382          typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
383                                               (ptrx::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
384                                               (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
acAxpbB(int m,int n,T const alpha,ptrA const A,int lda,ptrx const x,int incx,T const beta,ptrB B,int ldb)385 inline static void acAxpbB(int m,
386                            int n,
387                            T const alpha,
388                            ptrA const A,
389                            int lda,
390                            ptrx const x,
391                            int incx,
392                            T const beta,
393                            ptrB B,
394                            int ldb)
395 {
396   kernels::acAxpbB(m, n, alpha, to_address(A), lda, to_address(x), incx, beta, to_address(B), ldb);
397 }
398 
399 template<class T,
400          class ptrA,
401          class ptrx,
402          class ptrB,
403          typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
404                                               (ptrx::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
405                                               (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
406          typename = void>
acAxpbB(int m,int n,T const alpha,ptrA const A,int lda,ptrx const x,int incx,T const beta,ptrB B,int ldb)407 inline static void acAxpbB(int m,
408                            int n,
409                            T const alpha,
410                            ptrA const A,
411                            int lda,
412                            ptrx const x,
413                            int incx,
414                            T const beta,
415                            ptrB B,
416                            int ldb)
417 {
418   using ma::acAxpbB;
419   acAxpbB(m, n, alpha, to_address(A), lda, to_address(x), incx, beta, to_address(B), ldb);
420 }
421 
422 // adiagApy
423 template<class T,
424          class ptrA,
425          class ptrB,
426          typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
427                                               (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
adiagApy(int n,T const alpha,ptrA const A,int lda,ptrB y,int incy)428 inline static void adiagApy(int n, T const alpha, ptrA const A, int lda, ptrB y, int incy)
429 {
430   kernels::adiagApy(n, alpha, to_address(A), lda, to_address(y), incy);
431 }
432 
433 template<class T,
434          class ptrA,
435          class ptrB,
436          typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
437                                               (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
438          typename = void>
adiagApy(int n,T const alpha,ptrA const A,int lda,ptrB y,int incy)439 inline static void adiagApy(int n, T const alpha, ptrA const A, int lda, ptrB y, int incy)
440 {
441   using ma::adiagApy;
442   adiagApy(n, alpha, to_address(A), lda, to_address(y), incy);
443 }
444 
445 template<class ptr, typename = typename std::enable_if_t<(ptr::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
sum(int n,ptr const x,int incx)446 inline static auto sum(int n, ptr const x, int incx)
447 {
448   return kernels::sum(n, to_address(x), incx);
449 }
450 
451 template<class ptr, typename = typename std::enable_if_t<(ptr::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
sum(int m,int n,ptr const A,int lda)452 inline static auto sum(int m, int n, ptr const A, int lda)
453 {
454   return kernels::sum(m, n, to_address(A), lda);
455 }
456 
457 template<class ptr,
458          typename = typename std::enable_if_t<(ptr::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
459          typename = void>
sum(int n,ptr const x,int incx)460 inline static auto sum(int n, ptr const x, int incx)
461 {
462   using ma::sum;
463   return sum(n, to_address(x), incx);
464 }
465 
466 template<class ptr,
467          typename = typename std::enable_if_t<(ptr::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
468          typename = void>
sum(int m,int n,ptr const A,int lda)469 inline static auto sum(int m, int n, ptr const A, int lda)
470 {
471   using ma::sum;
472   return sum(m, n, to_address(A), lda);
473 }
474 
475 template<class T,
476          class ptrA,
477          class ptrB,
478          class ptrC,
479          typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
480                                               (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
481                                               (ptrC::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
gemmStridedBatched(char Atrans,char Btrans,int M,int N,int K,T const alpha,ptrA const A,int lda,int strideA,ptrB const B,int ldb,int strideB,T beta,ptrC C,int ldc,int strideC,int batchSize)482 inline static void gemmStridedBatched(char Atrans,
483                                       char Btrans,
484                                       int M,
485                                       int N,
486                                       int K,
487                                       T const alpha,
488                                       ptrA const A,
489                                       int lda,
490                                       int strideA,
491                                       ptrB const B,
492                                       int ldb,
493                                       int strideB,
494                                       T beta,
495                                       ptrC C,
496                                       int ldc,
497                                       int strideC,
498                                       int batchSize)
499 {
500   cublas::cublas_gemmStridedBatched(*A.handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, to_address(A), lda,
501                                     strideA, to_address(B), ldb, strideB, beta, to_address(C), ldc, strideC, batchSize);
502 }
503 
504 template<class T,
505          class ptrA,
506          class ptrB,
507          class ptrC,
508          typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
509                                               (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
510                                               (ptrC::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
511          typename = void>
gemmStridedBatched(char Atrans,char Btrans,int M,int N,int K,T const alpha,ptrA const A,int lda,int strideA,ptrB const B,int ldb,int strideB,T beta,ptrC C,int ldc,int strideC,int batchSize)512 inline static void gemmStridedBatched(char Atrans,
513                                       char Btrans,
514                                       int M,
515                                       int N,
516                                       int K,
517                                       T const alpha,
518                                       ptrA const A,
519                                       int lda,
520                                       int strideA,
521                                       ptrB const B,
522                                       int ldb,
523                                       int strideB,
524                                       T beta,
525                                       ptrC C,
526                                       int ldc,
527                                       int strideC,
528                                       int batchSize)
529 {
530   using ma::gemmStridedBatched;
531   gemmStridedBatched(Atrans, Btrans, M, N, K, alpha, to_address(A), lda, strideA, to_address(B), ldb, strideB, beta,
532                      to_address(C), ldc, strideC, batchSize);
533 }
534 
535 template<class T,
536          class ptrA,
537          class ptrB,
538          class ptrC,
539          typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
540                                               (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
541                                               (ptrC::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
gemmBatched(char Atrans,char Btrans,int M,int N,int K,T const alpha,ptrA const * A,int lda,ptrB const * B,int ldb,T beta,ptrC * C,int ldc,int batchSize)542 inline static void gemmBatched(char Atrans,
543                                char Btrans,
544                                int M,
545                                int N,
546                                int K,
547                                T const alpha,
548                                ptrA const* A,
549                                int lda,
550                                ptrB const* B,
551                                int ldb,
552                                T beta,
553                                ptrC* C,
554                                int ldc,
555                                int batchSize)
556 {
557   using Q = typename ptrA::value_type;
558   Q **A_d, **B_d, **C_d;
559   Q **A_h, **B_h, **C_h;
560   A_h = new Q*[batchSize];
561   B_h = new Q*[batchSize];
562   C_h = new Q*[batchSize];
563   for (int i = 0; i < batchSize; i++)
564   {
565     A_h[i] = to_address(A[i]);
566     B_h[i] = to_address(B[i]);
567     C_h[i] = to_address(C[i]);
568   }
569   arch::malloc((void**)&A_d, batchSize * sizeof(*A_h));
570   arch::malloc((void**)&B_d, batchSize * sizeof(*B_h));
571   arch::malloc((void**)&C_d, batchSize * sizeof(*C_h));
572   arch::memcopy(A_d, A_h, batchSize * sizeof(*A_h), arch::memcopyH2D);
573   arch::memcopy(B_d, B_h, batchSize * sizeof(*B_h), arch::memcopyH2D);
574   arch::memcopy(C_d, C_h, batchSize * sizeof(*C_h), arch::memcopyH2D);
575   cublas::cublas_gemmBatched(*(A[0]).handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, A_d, lda, B_d, ldb, beta,
576                              C_d, ldc, batchSize);
577   arch::free(A_d);
578   arch::free(B_d);
579   arch::free(C_d);
580   delete[] A_h;
581   delete[] B_h;
582   delete[] C_h;
583 }
584 
585 template<class T,
586          class ptrA,
587          class ptrB,
588          class ptrC,
589          typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
590                                               (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
591                                               (ptrC::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
592          typename = void>
gemmBatched(char Atrans,char Btrans,int M,int N,int K,T const alpha,ptrA const * A,int lda,ptrB const * B,int ldb,T beta,ptrC * C,int ldc,int batchSize)593 inline static void gemmBatched(char Atrans,
594                                char Btrans,
595                                int M,
596                                int N,
597                                int K,
598                                T const alpha,
599                                ptrA const* A,
600                                int lda,
601                                ptrB const* B,
602                                int ldb,
603                                T beta,
604                                ptrC* C,
605                                int ldc,
606                                int batchSize)
607 {
608   using Q = typename ptrA::value_type;
609   Q** A_d;
610   Q** B_d;
611   Q** C_d;
612   A_d = new Q*[batchSize];
613   B_d = new Q*[batchSize];
614   C_d = new Q*[batchSize];
615   for (int i = 0; i < batchSize; i++)
616   {
617     A_d[i] = to_address(A[i]);
618     B_d[i] = to_address(B[i]);
619     C_d[i] = to_address(C[i]);
620   }
621   using ma::gemmBatched;
622   gemmBatched(Atrans, Btrans, M, N, K, alpha, A_d, lda, B_d, ldb, beta, C_d, ldc, batchSize);
623   delete[] A_d;
624   delete[] B_d;
625   delete[] C_d;
626 }
627 
628 } // namespace qmc_cuda
629 
630 #endif
631