1 //////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License. See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by:
8 // Lawrence Livermore National Laboratory
9 //
10 // File created by:
11 // Miguel A. Morales, moralessilva2@llnl.gov
12 // Lawrence Livermore National Laboratory
13 ////////////////////////////////////////////////////////////////////////////////
14
15 #ifndef AFQMC_BLAS_CUDA_HPP
16 #define AFQMC_BLAS_CUDA_HPP
17
#include <cassert>
#include <stdexcept>
#include <vector>
#include "AFQMC/Memory/CUDA/cuda_gpu_pointer.hpp"
#include "AFQMC/Numerics/detail/CUDA/cublas_wrapper.hpp"
#include "AFQMC/Numerics/detail/CUDA/cublasXt_wrapper.hpp"
// hand coded kernels for blas extensions
#include "AFQMC/Numerics/detail/CUDA/Kernels/adotpby.cuh"
#include "AFQMC/Numerics/detail/CUDA/Kernels/axty.cuh"
#include "AFQMC/Numerics/detail/CUDA/Kernels/sum.cuh"
#include "AFQMC/Numerics/detail/CUDA/Kernels/adiagApy.cuh"
#include "AFQMC/Numerics/detail/CUDA/Kernels/acAxpbB.cuh"
29
30 // Currently available:
31 // Lvl-1: dot, axpy, scal
32 // Lvl-2: gemv
33 // Lvl-3: gemm
34
35 namespace qmc_cuda
36 {
37 // copy Specializations
38 template<class ptr, typename = typename std::enable_if_t<(ptr::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
copy(int n,ptr x,int incx,ptr y,int incy)39 inline static void copy(int n, ptr x, int incx, ptr y, int incy)
40 {
41 if (CUBLAS_STATUS_SUCCESS !=
42 cublas::cublas_copy(*x.handles.cublas_handle, n, to_address(x), incx, to_address(y), incy))
43 throw std::runtime_error("Error: cublas_copy returned error code.");
44 }
45
// copy specialization for host (out-of-card) pointers: dispatches to the
// host-side ma::copy implementation.
template<class ptr,
         typename = typename std::enable_if_t<(ptr::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
         typename = void>
inline static void copy(int n, ptr x, int incx, ptr y, int incy)
{
  using ma::copy;
  return copy(n, to_address(x), incx, to_address(y), incy);
}
54
55 // scal Specializations
56 template<class T, class ptr, typename = typename std::enable_if_t<(ptr::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
scal(int n,T alpha,ptr x,int incx)57 inline static void scal(int n, T alpha, ptr x, int incx)
58 {
59 if (CUBLAS_STATUS_SUCCESS != cublas::cublas_scal(*x.handles.cublas_handle, n, alpha, to_address(x), incx))
60 throw std::runtime_error("Error: cublas_scal returned error code.");
61 }
62
// scal specialization for host (out-of-card) pointers: dispatches to the
// host-side ma::scal implementation.
template<class T,
         class ptr,
         typename = typename std::enable_if_t<(ptr::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
         typename = void>
inline static void scal(int n, T alpha, ptr x, int incx)
{
  using ma::scal;
  return scal(n, alpha, to_address(x), incx);
}
72
73 // dot Specializations
74 template<class ptrA,
75 class ptrB,
76 typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
77 (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
dot(int const n,ptrA const & x,int const incx,ptrB const & y,int const incy)78 inline static auto dot(int const n, ptrA const& x, int const incx, ptrB const& y, int const incy)
79 {
80 return cublas::cublas_dot(*x.handles.cublas_handle, n, to_address(x), incx, to_address(y), incy);
81 }
82
// dot specialization used when either pointer lives on the host: dispatches
// to the host-side ma::dot implementation.
template<class ptrA,
         class ptrB,
         typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
                                              (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
         typename = void>
inline static auto dot(int const n, ptrA const& x, int const incx, ptrB const& y, int const incy)
{
  using ma::dot;
  return dot(n, to_address(x), incx, to_address(y), incy);
}
93
94 // axpy Specializations
95 template<typename T,
96 class ptrA,
97 class ptrB,
98 typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
99 (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
axpy(int n,T const a,ptrA const & x,int incx,ptrB && y,int incy)100 inline static void axpy(int n, T const a, ptrA const& x, int incx, ptrB&& y, int incy)
101 {
102 if (CUBLAS_STATUS_SUCCESS !=
103 cublas::cublas_axpy(*x.handles.cublas_handle, n, a, to_address(x), incx, to_address(y), incy))
104 throw std::runtime_error("Error: cublas_axpy returned error code.");
105 }
106
// axpy specialization used when either pointer lives on the host: dispatches
// to the host-side ma::axpy implementation.
template<typename T,
         class ptrA,
         class ptrB,
         typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
                                              (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
         typename = void>
inline static void axpy(int n, T const a, ptrA const& x, int incx, ptrB&& y, int incy)
{
  using ma::axpy;
  axpy(n, a, to_address(x), incx, to_address(y), incy);
}
118
119 // GEMV Specializations
120 template<typename T,
121 class ptrA,
122 class ptrB,
123 class ptrC,
124 typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
125 (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
126 (ptrC::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
gemv(char Atrans,int M,int N,T alpha,ptrA const & A,int lda,ptrB const & x,int incx,T beta,ptrC && y,int incy)127 inline static void gemv(char Atrans,
128 int M,
129 int N,
130 T alpha,
131 ptrA const& A,
132 int lda,
133 ptrB const& x,
134 int incx,
135 T beta,
136 ptrC&& y,
137 int incy)
138 {
139 if (CUBLAS_STATUS_SUCCESS !=
140 cublas::cublas_gemv(*A.handles.cublas_handle, Atrans, M, N, alpha, to_address(A), lda, to_address(x), incx, beta,
141 to_address(y), incy))
142 throw std::runtime_error("Error: cublas_gemv returned error code.");
143 }
144
145 template<typename T,
146 class ptrA,
147 class ptrB,
148 class ptrC,
149 typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
150 (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
151 (ptrC::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
152 typename = void>
gemv(char Atrans,int M,int N,T alpha,ptrA const & A,int lda,ptrB const & x,int incx,T beta,ptrC && y,int incy)153 inline static void gemv(char Atrans,
154 int M,
155 int N,
156 T alpha,
157 ptrA const& A,
158 int lda,
159 ptrB const& x,
160 int incx,
161 T beta,
162 ptrC&& y,
163 int incy)
164 {
165 using ma::gemv;
166 gemv(Atrans, M, N, alpha, to_address(A), lda, to_address(x), incx, beta, to_address(y), incy);
167 /*
168 const char Btrans('N');
169 const int one(1);
170 if(CUBLAS_STATUS_SUCCESS != cublas::cublasXt_gemm(*A.handles.cublasXt_handle,Atrans,Btrans,
171 M,one,K,alpha,to_address(A),lda,to_address(x),incx,
172 beta,to_address(y),incy))
173 throw std::runtime_error("Error: cublasXt_gemv (gemm) returned error code.");
174 */
175 }
176
177 // GEMM Specializations
178 template<typename T,
179 class ptrA,
180 class ptrB,
181 class ptrC,
182 typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
183 (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
184 (ptrC::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
gemm(char Atrans,char Btrans,int M,int N,int K,T alpha,ptrA const & A,int lda,ptrB const & B,int ldb,T beta,ptrC && C,int ldc)185 inline static void gemm(char Atrans,
186 char Btrans,
187 int M,
188 int N,
189 int K,
190 T alpha,
191 ptrA const& A,
192 int lda,
193 ptrB const& B,
194 int ldb,
195 T beta,
196 ptrC&& C,
197 int ldc)
198 {
199 if (CUBLAS_STATUS_SUCCESS !=
200 cublas::cublas_gemm(*A.handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, to_address(A), lda, to_address(B),
201 ldb, beta, to_address(C), ldc))
202 throw std::runtime_error("Error: cublas_gemm returned error code.");
203 }
204
205 template<typename T,
206 class ptrA,
207 class ptrB,
208 class ptrC,
209 typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
210 (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
211 (ptrC::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
212 typename = void>
gemm(char Atrans,char Btrans,int M,int N,int K,T alpha,ptrA const & A,int lda,ptrB const & B,int ldb,T beta,ptrC && C,int ldc)213 inline static void gemm(char Atrans,
214 char Btrans,
215 int M,
216 int N,
217 int K,
218 T alpha,
219 ptrA const& A,
220 int lda,
221 ptrB const& B,
222 int ldb,
223 T beta,
224 ptrC&& C,
225 int ldc)
226 {
227 if (CUBLAS_STATUS_SUCCESS !=
228 cublas::cublasXt_gemm(*A.handles.cublasXt_handle, Atrans, Btrans, M, N, K, alpha, to_address(A), lda,
229 to_address(B), ldb, beta, to_address(C), ldc))
230 throw std::runtime_error("Error: cublasXt_gemm returned error code.");
231 }
232
233 // Blas Extensions
234 // geam
235 template<class T,
236 class ptrA,
237 class ptrB,
238 class ptrC,
239 typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
240 (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
241 (ptrC::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
geam(char Atrans,char Btrans,int M,int N,T const alpha,ptrA const & A,int lda,T const beta,ptrB const & B,int ldb,ptrC C,int ldc)242 inline static void geam(char Atrans,
243 char Btrans,
244 int M,
245 int N,
246 T const alpha,
247 ptrA const& A,
248 int lda,
249 T const beta,
250 ptrB const& B,
251 int ldb,
252 ptrC C,
253 int ldc)
254 {
255 if (CUBLAS_STATUS_SUCCESS !=
256 cublas::cublas_geam(*A.handles.cublas_handle, Atrans, Btrans, M, N, alpha, to_address(A), lda, beta,
257 to_address(B), ldb, to_address(C), ldc))
258 throw std::runtime_error("Error: cublas_geam returned error code.");
259 }
260
// geam specialization used when any pointer lives on the host: dispatches to
// the host-side ma::geam implementation.
template<class T,
         class ptrA,
         class ptrB,
         class ptrC,
         typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
                                              (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
                                              (ptrC::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
         typename = void>
inline static void geam(char Atrans,
                        char Btrans,
                        int M,
                        int N,
                        T const alpha,
                        ptrA const& A,
                        int lda,
                        T const beta,
                        ptrB const& B,
                        int ldb,
                        ptrC C,
                        int ldc)
{
  using ma::geam;
  return geam(Atrans, Btrans, M, N, alpha, to_address(A), lda, beta, to_address(B), ldb, to_address(C), ldc);
}
285
286 //template<class T,
287 template<class ptr,
288 typename = typename std::enable_if_t<not(ptr::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
289 typename = void>
290 //inline static void set1D(int n, T const alpha, ptr x, int incx)
set1D(int n,typename ptr::value_type const alpha,ptr x,int incx)291 inline static void set1D(int n, typename ptr::value_type const alpha, ptr x, int incx)
292 {
293 // No set funcion in cuda!!! Avoiding kernels for now
294 std::vector<typename ptr::value_type> buff(n, alpha);
295 if (CUBLAS_STATUS_SUCCESS !=
296 cublasSetVector(n, sizeof(typename ptr::value_type), buff.data(), 1, to_address(x), incx))
297 throw std::runtime_error("Error: cublasSetVector returned error code.");
298 }
299
300 template<class T, class ptr, typename = typename std::enable_if_t<(ptr::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>>
set1D(int n,T const alpha,ptr x,int incx)301 inline static void set1D(int n, T const alpha, ptr x, int incx)
302 {
303 auto y = to_address(x);
304 for (int i = 0; i < n; i++, y += incx)
305 *y = alpha;
306 }
307
308 // dot extension
// adotpby: device specialization, forwards to the hand-coded kernel.
// Presumably computes result = alpha*dot(x,y) + beta*result (kernel source not
// visible here — TODO confirm against Kernels/adotpby.cuh). Note that beta and
// result may use a different precision (Q) than alpha (T).
template<class T,
         class Q,
         class ptrA,
         class ptrB,
         class ptrC,
         typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
                                              (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
                                              (ptrC::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
inline static void adotpby(int const n,
                           T const alpha,
                           ptrA const& x,
                           int const incx,
                           ptrB const& y,
                           int const incy,
                           Q const beta,
                           ptrC result)
{
  kernels::adotpby(n, alpha, to_address(x), incx, to_address(y), incy, beta, to_address(result));
}
328
// adotpby specialization used when any pointer lives on the host: dispatches
// to the host-side ma::adotpby implementation.
template<class T,
         class Q,
         class ptrA,
         class ptrB,
         class ptrC,
         typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
                                              (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
                                              (ptrC::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
         typename = void>
inline static void adotpby(int const n,
                           T const alpha,
                           ptrA const& x,
                           int const incx,
                           ptrB const& y,
                           int const incy,
                           Q const beta,
                           ptrC result)
{
  using ma::adotpby;
  adotpby(n, alpha, to_address(x), incx, to_address(y), incy, beta, to_address(result));
}
350
351
352 // axty
353 template<class T,
354 class ptrA,
355 class ptrB,
356 typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
357 (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
axty(int n,T const alpha,ptrA const x,int incx,ptrB y,int incy)358 inline static void axty(int n, T const alpha, ptrA const x, int incx, ptrB y, int incy)
359 {
360 if (incx != 1 || incy != 1)
361 throw std::runtime_error("Error: axty with inc != 1 not implemented.");
362 kernels::axty(n, alpha, to_address(x), to_address(y));
363 }
364
// axty specialization for host pointers: dispatches to ma::axty.
// NOTE(review): this enable_if uses `and`, so a mixed host/device pointer pair
// matches neither axty overload (compile error). Other fallbacks in this file
// use `or` — confirm whether that asymmetry is intentional.
template<class T,
         class ptrA,
         class ptrB,
         typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
                                              (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
         typename = void>
inline static void axty(int n, T const alpha, ptrA const x, int incx, ptrB y, int incy)
{
  using ma::axty;
  axty(n, alpha, to_address(x), incx, to_address(y), incy);
}
376
377 // acAxpbB
// acAxpbB: device specialization, forwards to the hand-coded kernel.
// From the name, presumably B <- alpha*conj(A)*x + beta*B applied per
// column/row of the m x n matrices — TODO confirm in Kernels/acAxpbB.cuh.
template<class T,
         class ptrA,
         class ptrx,
         class ptrB,
         typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
                                              (ptrx::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
                                              (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
inline static void acAxpbB(int m,
                           int n,
                           T const alpha,
                           ptrA const A,
                           int lda,
                           ptrx const x,
                           int incx,
                           T const beta,
                           ptrB B,
                           int ldb)
{
  kernels::acAxpbB(m, n, alpha, to_address(A), lda, to_address(x), incx, beta, to_address(B), ldb);
}
398
// acAxpbB specialization for host pointers: dispatches to ma::acAxpbB.
// NOTE(review): like axty, this fallback requires ALL pointers on the host
// (`and`), so mixed host/device combinations match no overload — confirm
// whether `or` was intended for consistency with gemv/gemm/geam.
template<class T,
         class ptrA,
         class ptrx,
         class ptrB,
         typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
                                              (ptrx::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
                                              (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
         typename = void>
inline static void acAxpbB(int m,
                           int n,
                           T const alpha,
                           ptrA const A,
                           int lda,
                           ptrx const x,
                           int incx,
                           T const beta,
                           ptrB B,
                           int ldb)
{
  using ma::acAxpbB;
  acAxpbB(m, n, alpha, to_address(A), lda, to_address(x), incx, beta, to_address(B), ldb);
}
421
422 // adiagApy
// adiagApy: device specialization, forwards to the hand-coded kernel.
// From the name, presumably y <- alpha*diag(A) + y over n elements — TODO
// confirm in Kernels/adiagApy.cuh.
template<class T,
         class ptrA,
         class ptrB,
         typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
                                              (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
inline static void adiagApy(int n, T const alpha, ptrA const A, int lda, ptrB y, int incy)
{
  kernels::adiagApy(n, alpha, to_address(A), lda, to_address(y), incy);
}
432
// adiagApy specialization used when either pointer lives on the host:
// dispatches to the host-side ma::adiagApy implementation.
template<class T,
         class ptrA,
         class ptrB,
         typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) or
                                              (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
         typename = void>
inline static void adiagApy(int n, T const alpha, ptrA const A, int lda, ptrB y, int incy)
{
  using ma::adiagApy;
  adiagApy(n, alpha, to_address(A), lda, to_address(y), incy);
}
444
// sum (1D): device specialization, reduces n strided elements of x via the
// hand-coded kernel; return type is whatever kernels::sum yields.
template<class ptr, typename = typename std::enable_if_t<(ptr::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
inline static auto sum(int n, ptr const x, int incx)
{
  return kernels::sum(n, to_address(x), incx);
}
450
// sum (2D): device specialization, reduces an m x n matrix with leading
// dimension lda via the hand-coded kernel.
template<class ptr, typename = typename std::enable_if_t<(ptr::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
inline static auto sum(int m, int n, ptr const A, int lda)
{
  return kernels::sum(m, n, to_address(A), lda);
}
456
// sum (1D) specialization for host pointers: dispatches to ma::sum.
template<class ptr,
         typename = typename std::enable_if_t<(ptr::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
         typename = void>
inline static auto sum(int n, ptr const x, int incx)
{
  using ma::sum;
  return sum(n, to_address(x), incx);
}
465
// sum (2D) specialization for host pointers: dispatches to ma::sum.
template<class ptr,
         typename = typename std::enable_if_t<(ptr::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
         typename = void>
inline static auto sum(int m, int n, ptr const A, int lda)
{
  using ma::sum;
  return sum(m, n, to_address(A), lda);
}
474
475 template<class T,
476 class ptrA,
477 class ptrB,
478 class ptrC,
479 typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
480 (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
481 (ptrC::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
gemmStridedBatched(char Atrans,char Btrans,int M,int N,int K,T const alpha,ptrA const A,int lda,int strideA,ptrB const B,int ldb,int strideB,T beta,ptrC C,int ldc,int strideC,int batchSize)482 inline static void gemmStridedBatched(char Atrans,
483 char Btrans,
484 int M,
485 int N,
486 int K,
487 T const alpha,
488 ptrA const A,
489 int lda,
490 int strideA,
491 ptrB const B,
492 int ldb,
493 int strideB,
494 T beta,
495 ptrC C,
496 int ldc,
497 int strideC,
498 int batchSize)
499 {
500 cublas::cublas_gemmStridedBatched(*A.handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, to_address(A), lda,
501 strideA, to_address(B), ldb, strideB, beta, to_address(C), ldc, strideC, batchSize);
502 }
503
// gemmStridedBatched specialization for host pointers: dispatches to the
// host-side ma::gemmStridedBatched implementation.
// NOTE(review): requires ALL pointers on the host (`and`), so mixed
// host/device combinations match no overload — confirm this is intentional.
template<class T,
         class ptrA,
         class ptrB,
         class ptrC,
         typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
                                              (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
                                              (ptrC::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
         typename = void>
inline static void gemmStridedBatched(char Atrans,
                                      char Btrans,
                                      int M,
                                      int N,
                                      int K,
                                      T const alpha,
                                      ptrA const A,
                                      int lda,
                                      int strideA,
                                      ptrB const B,
                                      int ldb,
                                      int strideB,
                                      T beta,
                                      ptrC C,
                                      int ldc,
                                      int strideC,
                                      int batchSize)
{
  using ma::gemmStridedBatched;
  gemmStridedBatched(Atrans, Btrans, M, N, K, alpha, to_address(A), lda, strideA, to_address(B), ldb, strideB, beta,
                     to_address(C), ldc, strideC, batchSize);
}
534
535 template<class T,
536 class ptrA,
537 class ptrB,
538 class ptrC,
539 typename = typename std::enable_if_t<(ptrA::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
540 (ptrB::memory_type != CPU_OUTOFCARS_POINTER_TYPE) and
541 (ptrC::memory_type != CPU_OUTOFCARS_POINTER_TYPE)>>
gemmBatched(char Atrans,char Btrans,int M,int N,int K,T const alpha,ptrA const * A,int lda,ptrB const * B,int ldb,T beta,ptrC * C,int ldc,int batchSize)542 inline static void gemmBatched(char Atrans,
543 char Btrans,
544 int M,
545 int N,
546 int K,
547 T const alpha,
548 ptrA const* A,
549 int lda,
550 ptrB const* B,
551 int ldb,
552 T beta,
553 ptrC* C,
554 int ldc,
555 int batchSize)
556 {
557 using Q = typename ptrA::value_type;
558 Q **A_d, **B_d, **C_d;
559 Q **A_h, **B_h, **C_h;
560 A_h = new Q*[batchSize];
561 B_h = new Q*[batchSize];
562 C_h = new Q*[batchSize];
563 for (int i = 0; i < batchSize; i++)
564 {
565 A_h[i] = to_address(A[i]);
566 B_h[i] = to_address(B[i]);
567 C_h[i] = to_address(C[i]);
568 }
569 arch::malloc((void**)&A_d, batchSize * sizeof(*A_h));
570 arch::malloc((void**)&B_d, batchSize * sizeof(*B_h));
571 arch::malloc((void**)&C_d, batchSize * sizeof(*C_h));
572 arch::memcopy(A_d, A_h, batchSize * sizeof(*A_h), arch::memcopyH2D);
573 arch::memcopy(B_d, B_h, batchSize * sizeof(*B_h), arch::memcopyH2D);
574 arch::memcopy(C_d, C_h, batchSize * sizeof(*C_h), arch::memcopyH2D);
575 cublas::cublas_gemmBatched(*(A[0]).handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, A_d, lda, B_d, ldb, beta,
576 C_d, ldc, batchSize);
577 arch::free(A_d);
578 arch::free(B_d);
579 arch::free(C_d);
580 delete[] A_h;
581 delete[] B_h;
582 delete[] C_h;
583 }
584
585 template<class T,
586 class ptrA,
587 class ptrB,
588 class ptrC,
589 typename = typename std::enable_if_t<(ptrA::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
590 (ptrB::memory_type == CPU_OUTOFCARS_POINTER_TYPE) and
591 (ptrC::memory_type == CPU_OUTOFCARS_POINTER_TYPE)>,
592 typename = void>
gemmBatched(char Atrans,char Btrans,int M,int N,int K,T const alpha,ptrA const * A,int lda,ptrB const * B,int ldb,T beta,ptrC * C,int ldc,int batchSize)593 inline static void gemmBatched(char Atrans,
594 char Btrans,
595 int M,
596 int N,
597 int K,
598 T const alpha,
599 ptrA const* A,
600 int lda,
601 ptrB const* B,
602 int ldb,
603 T beta,
604 ptrC* C,
605 int ldc,
606 int batchSize)
607 {
608 using Q = typename ptrA::value_type;
609 Q** A_d;
610 Q** B_d;
611 Q** C_d;
612 A_d = new Q*[batchSize];
613 B_d = new Q*[batchSize];
614 C_d = new Q*[batchSize];
615 for (int i = 0; i < batchSize; i++)
616 {
617 A_d[i] = to_address(A[i]);
618 B_d[i] = to_address(B[i]);
619 C_d[i] = to_address(C[i]);
620 }
621 using ma::gemmBatched;
622 gemmBatched(Atrans, Btrans, M, N, K, alpha, A_d, lda, B_d, ldb, beta, C_d, ldc, batchSize);
623 delete[] A_d;
624 delete[] B_d;
625 delete[] C_d;
626 }
627
628 } // namespace qmc_cuda
629
630 #endif
631