1 //////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License.  See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by:
8 //    Lawrence Livermore National Laboratory
9 //
10 // File created by:
11 // Miguel A. Morales, moralessilva2@llnl.gov
12 //    Lawrence Livermore National Laboratory
13 ////////////////////////////////////////////////////////////////////////////////
14 
15 #ifndef AFQMC_BLAS_CUDA_GPU_PTR_HPP
16 #define AFQMC_BLAS_CUDA_GPU_PTR_HPP
17 
#include <cassert>
#include <stdexcept>
#include <type_traits>
#include <vector>
21 //#include "AFQMC/Memory/CUDA/cuda_gpu_pointer.hpp"
22 #include "AFQMC/Utilities/type_conversion.hpp"
23 #include "AFQMC/Memory/device_pointers.hpp"
24 #include "AFQMC/Memory/arch.hpp"
25 #include "AFQMC/Numerics/detail/CUDA/cublas_wrapper.hpp"
26 //#include "AFQMC/Numerics/detail/CUDA/cublasXt_wrapper.hpp"
27 // hand coded kernels for blas extensions
28 #include "AFQMC/Numerics/detail/CUDA/Kernels/adotpby.cuh"
29 #include "AFQMC/Numerics/detail/CUDA/Kernels/setIdentity.cuh"
30 #include "AFQMC/Numerics/detail/CUDA/Kernels/axty.cuh"
31 #include "AFQMC/Numerics/detail/CUDA/Kernels/sum.cuh"
32 #include "AFQMC/Numerics/detail/CUDA/Kernels/adiagApy.cuh"
33 #include "AFQMC/Numerics/detail/CUDA/Kernels/acAxpbB.cuh"
34 #include "AFQMC/Numerics/detail/CUDA/Kernels/zero_complex_part.cuh"
35 #include "AFQMC/Numerics/detail/CUDA/Kernels/axpyBatched.cuh"
36 #include "AFQMC/Numerics/detail/CUDA/Kernels/get_diagonal.cuh"
37 
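// Currently available:
// Lvl-1: copy, scal, dot, axpy
// Lvl-2: gemv
// Lvl-3: gemm, gemmStridedBatched, gemmBatched
// Extensions: geam, set1D, adotpby, strided_adotpby, axty, acAxpbB, adiagApy, zero_complex_part, sum,
//             set_identity, set_identity_strided, axpyBatched, sumGwBatched, copy2D, get_diagonal_strided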
42 
43 namespace device
44 {
45 // copy Specializations
46 template<typename T, typename Q>
copy(int n,device_pointer<Q> x,int incx,device_pointer<T> y,int incy)47 inline static void copy(int n, device_pointer<Q> x, int incx, device_pointer<T> y, int incy)
48 {
49   static_assert(std::is_same<typename std::decay<Q>::type, T>::value, "Wrong dispatch.\n");
50   if (CUBLAS_STATUS_SUCCESS !=
51       cublas::cublas_copy(*x.handles.cublas_handle, n, to_address(x), incx, to_address(y), incy))
52     throw std::runtime_error("Error: cublas_copy returned error code.");
53 }
54 
55 template<typename T, typename Q>
copy(int n,T const * x,int incx,device_pointer<Q> y,int incy)56 inline static void copy(int n, T const* x, int incx, device_pointer<Q> y, int incy)
57 {
58   static_assert(std::is_same<typename std::decay<Q>::type, T>::value, "Wrong dispatch.\n");
59   arch::memcopy2D(to_address(y), sizeof(Q) * incy, x, sizeof(T) * incx, sizeof(T), n, arch::memcopyH2D,
60                   "lapack_cuda_gpu_ptr::copy");
61 }
62 
63 template<typename T, typename Q>
copy(int n,device_pointer<Q> x,int incx,T * y,int incy)64 inline static void copy(int n, device_pointer<Q> x, int incx, T* y, int incy)
65 {
66   static_assert(std::is_same<typename std::decay<Q>::type, T>::value, "Wrong dispatch.\n");
67   assert(sizeof(Q) == sizeof(T));
68   arch::memcopy2D(y, sizeof(T) * incy, to_address(x), sizeof(Q) * incx, sizeof(T), n, arch::memcopyD2H,
69                   "lapack_cuda_gpu_ptr::copy");
70 }
71 
72 // scal Specializations
73 template<typename T, typename Q>
scal(int n,Q alpha,device_pointer<T> x,int incx=1)74 inline static void scal(int n, Q alpha, device_pointer<T> x, int incx = 1)
75 {
76   static_assert(std::is_convertible<typename std::decay<Q>::type, T>::value, "Wrong dispatch.\n");
77   if (CUBLAS_STATUS_SUCCESS != cublas::cublas_scal(*x.handles.cublas_handle, n, T(alpha), to_address(x), incx))
78     throw std::runtime_error("Error: cublas_scal returned error code.");
79 }
80 
81 // dot Specializations
82 template<typename T, typename Q>
dot(int const n,device_pointer<Q> x,int const incx,device_pointer<T> y,int const incy)83 inline static auto dot(int const n, device_pointer<Q> x, int const incx, device_pointer<T> y, int const incy)
84 {
85   static_assert(std::is_same<typename std::decay<Q>::type, typename std::decay<T>::type>::value, "Wrong dispatch.\n");
86   return cublas::cublas_dot(*x.handles.cublas_handle, n, to_address(x), incx, to_address(y), incy);
87 }
88 
89 // axpy Specializations
90 template<typename T, typename Q>
axpy(int n,T const a,device_pointer<Q> x,int incx,device_pointer<T> y,int incy)91 inline static void axpy(int n, T const a, device_pointer<Q> x, int incx, device_pointer<T> y, int incy)
92 {
93   static_assert(std::is_same<typename std::decay<Q>::type, T>::value, "Wrong dispatch.\n");
94   if (CUBLAS_STATUS_SUCCESS !=
95       cublas::cublas_axpy(*x.handles.cublas_handle, n, a, to_address(x), incx, to_address(y), incy))
96     throw std::runtime_error("Error: cublas_axpy returned error code.");
97 }
98 
99 // GEMV Specializations
100 template<typename T, typename T2, typename Q1, typename Q2>
gemv(char Atrans,int M,int N,T2 alpha,device_pointer<Q1> A,int lda,device_pointer<Q2> x,int incx,T2 beta,device_pointer<T> y,int incy)101 inline static void gemv(char Atrans,
102                         int M,
103                         int N,
104                         T2 alpha,
105                         device_pointer<Q1> A,
106                         int lda,
107                         device_pointer<Q2> x,
108                         int incx,
109                         T2 beta,
110                         device_pointer<T> y,
111                         int incy)
112 {
113   static_assert(std::is_same<typename std::decay<Q1>::type, T2>::value, "Wrong dispatch.\n");
114   static_assert(std::is_same<typename std::decay<Q2>::type, T>::value, "Wrong dispatch.\n");
115   if (CUBLAS_STATUS_SUCCESS !=
116       cublas::cublas_gemv(*A.handles.cublas_handle, Atrans, M, N, alpha, to_address(A), lda, to_address(x), incx, beta,
117                           to_address(y), incy))
118     throw std::runtime_error("Error: cublas_gemv returned error code.");
119 }
120 
121 // GEMM Specializations
122 // why is this not working with T const????
123 template<typename T, typename T2, typename Q1, typename Q2>
gemm(char Atrans,char Btrans,int M,int N,int K,T2 alpha,device_pointer<Q1> A,int lda,device_pointer<Q2> B,int ldb,T2 beta,device_pointer<T> C,int ldc)124 inline static void gemm(char Atrans,
125                         char Btrans,
126                         int M,
127                         int N,
128                         int K,
129                         T2 alpha,
130                         device_pointer<Q1> A,
131                         int lda,
132                         device_pointer<Q2> B,
133                         int ldb,
134                         T2 beta,
135                         device_pointer<T> C,
136                         int ldc)
137 {
138   static_assert(std::is_same<typename std::decay<Q1>::type, T>::value, "Wrong dispatch.\n");
139   static_assert(std::is_same<typename std::decay<Q2>::type, T2>::value, "Wrong dispatch.\n");
140   if (CUBLAS_STATUS_SUCCESS !=
141       cublas::cublas_gemm(*A.handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, to_address(A), lda, to_address(B),
142                           ldb, beta, to_address(C), ldc))
143     throw std::runtime_error("Error: cublas_gemm returned error code.");
144 }
145 
146 // Blas Extensions
147 // geam
148 template<typename T, typename Q1, typename Q2>
geam(char Atrans,char Btrans,int M,int N,T const alpha,device_pointer<Q1> A,int lda,T const beta,device_pointer<Q2> B,int ldb,device_pointer<T> C,int ldc)149 inline static void geam(char Atrans,
150                         char Btrans,
151                         int M,
152                         int N,
153                         T const alpha,
154                         device_pointer<Q1> A,
155                         int lda,
156                         T const beta,
157                         device_pointer<Q2> B,
158                         int ldb,
159                         device_pointer<T> C,
160                         int ldc)
161 {
162   static_assert(std::is_same<typename std::decay<Q1>::type, T>::value, "Wrong dispatch.\n");
163   static_assert(std::is_same<typename std::decay<Q2>::type, T>::value, "Wrong dispatch.\n");
164   if (CUBLAS_STATUS_SUCCESS !=
165       cublas::cublas_geam(*A.handles.cublas_handle, Atrans, Btrans, M, N, alpha, to_address(A), lda, beta,
166                           to_address(B), ldb, to_address(C), ldc))
167     throw std::runtime_error("Error: cublas_geam returned error code.");
168 }
169 
170 template<typename T>
171 //inline static void set1D(int n, T const alpha, ptr x, int incx)
set1D(int n,T const alpha,device_pointer<T> x,int incx)172 inline static void set1D(int n, T const alpha, device_pointer<T> x, int incx)
173 {
174   // No set funcion in cuda!!! Avoiding kernels for now
175   //std::vector<T> buff(n,alpha);
176   //if(CUBLAS_STATUS_SUCCESS != cublasSetVector(n,sizeof(T),buff.data(),1,to_address(x),incx))
177   T alpha_(alpha);
178   if (CUBLAS_STATUS_SUCCESS != cublasSetVector(n, sizeof(T), std::addressof(alpha), 1, to_address(x), incx))
179     throw std::runtime_error("Error: cublasSetVector returned error code.");
180 }
181 
182 // dot extension
183 template<typename T, typename T1, typename T2, typename Q1, typename Q2>
adotpby(int const n,T1 const alpha,device_pointer<Q1> x,int const incx,device_pointer<Q2> y,int const incy,T2 const beta,T * result)184 inline static void adotpby(int const n,
185                            T1 const alpha,
186                            device_pointer<Q1> x,
187                            int const incx,
188                            device_pointer<Q2> y,
189                            int const incy,
190                            T2 const beta,
191                            T* result)
192 {
193   static_assert(std::is_same<typename std::decay<Q1>::type, T1>::value, "Wrong dispatch.\n");
194   static_assert(std::is_same<typename std::decay<Q2>::type, T1>::value, "Wrong dispatch.\n");
195   static_assert(std::is_same<typename std::decay<T2>::type, T>::value, "Wrong dispatch.\n");
196   kernels::adotpby(n, alpha, to_address(x), incx, to_address(y), incy, beta, result);
197 }
198 
199 // dot extension
200 template<typename T, typename T1, typename T2, typename Q1, typename Q2>
strided_adotpby(int nk,int const n,T1 const alpha,device_pointer<Q1> A,int const lda,device_pointer<Q2> B,int const ldb,T2 const beta,T * y,int inc)201 inline static void strided_adotpby(int nk,
202                                    int const n,
203                                    T1 const alpha,
204                                    device_pointer<Q1> A,
205                                    int const lda,
206                                    device_pointer<Q2> B,
207                                    int const ldb,
208                                    T2 const beta,
209                                    T* y,
210                                    int inc)
211 {
212   static_assert(std::is_same<typename std::decay<Q1>::type, T1>::value, "Wrong dispatch.\n");
213   static_assert(std::is_same<typename std::decay<Q2>::type, T1>::value, "Wrong dispatch.\n");
214   static_assert(std::is_same<typename std::decay<T2>::type, T>::value, "Wrong dispatch.\n");
215   kernels::strided_adotpby(nk, n, alpha, to_address(A), lda, to_address(B), ldb, beta, y, inc);
216 }
217 
218 // axty
219 template<typename T, typename Q>
axty(int n,T const alpha,device_pointer<Q> x,int incx,device_pointer<T> y,int incy)220 inline static void axty(int n, T const alpha, device_pointer<Q> x, int incx, device_pointer<T> y, int incy)
221 {
222   static_assert(std::is_same<typename std::decay<Q>::type, T>::value, "Wrong dispatch.\n");
223   if (incx != 1 || incy != 1)
224     throw std::runtime_error("Error: axty with inc != 1 not implemented.");
225   kernels::axty(n, alpha, to_address(x), to_address(y));
226 }
227 
228 // acAxpbB
229 template<typename T, typename Q1, typename Q2>
acAxpbB(int m,int n,T const alpha,device_pointer<Q1> A,int lda,device_pointer<Q2> x,int incx,T const beta,device_pointer<T> B,int ldb)230 inline static void acAxpbB(int m,
231                            int n,
232                            T const alpha,
233                            device_pointer<Q1> A,
234                            int lda,
235                            device_pointer<Q2> x,
236                            int incx,
237                            T const beta,
238                            device_pointer<T> B,
239                            int ldb)
240 {
241   static_assert(std::is_same<typename std::decay<Q1>::type, T>::value, "Wrong dispatch.\n");
242   static_assert(std::is_same<typename std::decay<Q2>::type, T>::value, "Wrong dispatch.\n");
243   kernels::acAxpbB(m, n, alpha, to_address(A), lda, to_address(x), incx, beta, to_address(B), ldb);
244 }
245 
246 // adiagApy
247 template<typename T, typename Q1>
adiagApy(int n,T const alpha,device_pointer<Q1> A,int lda,device_pointer<T> y,int incy)248 inline static void adiagApy(int n, T const alpha, device_pointer<Q1> A, int lda, device_pointer<T> y, int incy)
249 {
250   static_assert(std::is_same<typename std::decay<Q1>::type, T>::value, "Wrong dispatch.\n");
251   kernels::adiagApy(n, alpha, to_address(A), lda, to_address(y), incy);
252 }
253 
254 template<typename T>
zero_complex_part(int n,device_pointer<T> x)255 inline static void zero_complex_part(int n, device_pointer<T> x)
256 {
257   kernels::zero_complex_part(n, to_address(x));
258 }
259 
260 template<typename T>
sum(int n,device_pointer<T> x,int incx)261 inline static auto sum(int n, device_pointer<T> x, int incx)
262 {
263   return kernels::sum(n, to_address(x), incx);
264 }
265 
266 template<typename T>
sum(int m,int n,device_pointer<T> A,int lda)267 inline static auto sum(int m, int n, device_pointer<T> A, int lda)
268 {
269   return kernels::sum(m, n, to_address(A), lda);
270 }
271 
272 template<typename T>
set_identity(int m,int n,device_pointer<T> A,int lda)273 void set_identity(int m, int n, device_pointer<T> A, int lda)
274 {
275   kernels::set_identity(m, n, to_address(A), lda);
276 }
277 
278 template<typename T>
set_identity_strided(int nbatch,int stride,int m,int n,device_pointer<T> A,int lda)279 void set_identity_strided(int nbatch, int stride, int m, int n, device_pointer<T> A, int lda)
280 {
281   kernels::set_identity_strided(nbatch, stride, m, n, to_address(A), lda);
282 }
283 
284 template<typename T, typename Q1, typename Q2>
gemmStridedBatched(char Atrans,char Btrans,int M,int N,int K,T const alpha,device_pointer<Q1> A,int lda,int strideA,device_pointer<Q2> B,int ldb,int strideB,T beta,device_pointer<T> C,int ldc,int strideC,int batchSize)285 inline static void gemmStridedBatched(char Atrans,
286                                       char Btrans,
287                                       int M,
288                                       int N,
289                                       int K,
290                                       T const alpha,
291                                       device_pointer<Q1> A,
292                                       int lda,
293                                       int strideA,
294                                       device_pointer<Q2> B,
295                                       int ldb,
296                                       int strideB,
297                                       T beta,
298                                       device_pointer<T> C,
299                                       int ldc,
300                                       int strideC,
301                                       int batchSize)
302 {
303   static_assert(std::is_same<typename std::decay<Q1>::type, T>::value, "Wrong dispatch.\n");
304   static_assert(std::is_same<typename std::decay<Q2>::type, T>::value, "Wrong dispatch.\n");
305   cublas::cublas_gemmStridedBatched(*A.handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, to_address(A), lda,
306                                     strideA, to_address(B), ldb, strideB, beta, to_address(C), ldc, strideC, batchSize);
307 }
308 
309 template<typename T,
310          typename Q1,
311          typename Q2,
312          typename = typename std::enable_if_t<std::is_same<typename std::decay<Q1>::type, T>::value>,
313          typename = typename std::enable_if_t<std::is_same<typename std::decay<Q2>::type, T>::value>>
gemmBatched(char Atrans,char Btrans,int M,int N,int K,T const alpha,device_pointer<Q1> * A,int lda,device_pointer<Q2> * B,int ldb,T const beta,device_pointer<T> * C,int ldc,int batchSize)314 inline static void gemmBatched(char Atrans,
315                                char Btrans,
316                                int M,
317                                int N,
318                                int K,
319                                T const alpha,
320                                device_pointer<Q1>* A,
321                                int lda,
322                                device_pointer<Q2>* B,
323                                int ldb,
324                                T const beta,
325                                device_pointer<T>* C,
326                                int ldc,
327                                int batchSize)
328 {
329   static_assert(std::is_same<typename std::decay<Q1>::type, T>::value, "Wrong dispatch.\n");
330   static_assert(std::is_same<typename std::decay<Q2>::type, T>::value, "Wrong dispatch.\n");
331   // replace with single call to arch::malloc and arch::memcopy
332   T **A_d, **B_d, **C_d;
333   Q1** A_h;
334   Q2** B_h;
335   T** C_h;
336   A_h = new Q1*[batchSize];
337   B_h = new Q2*[batchSize];
338   C_h = new T*[batchSize];
339   for (int i = 0; i < batchSize; i++)
340   {
341     A_h[i] = to_address(A[i]);
342     B_h[i] = to_address(B[i]);
343     C_h[i] = to_address(C[i]);
344   }
345   arch::malloc((void**)&A_d, batchSize * sizeof(*A_h));
346   arch::malloc((void**)&B_d, batchSize * sizeof(*B_h));
347   arch::malloc((void**)&C_d, batchSize * sizeof(*C_h));
348   arch::memcopy(A_d, A_h, batchSize * sizeof(*A_h), arch::memcopyH2D);
349   arch::memcopy(B_d, B_h, batchSize * sizeof(*B_h), arch::memcopyH2D);
350   arch::memcopy(C_d, C_h, batchSize * sizeof(*C_h), arch::memcopyH2D);
351   cublas::cublas_gemmBatched(*(A[0]).handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, A_d, lda, B_d, ldb, beta,
352                              C_d, ldc, batchSize);
353   arch::free(A_d);
354   arch::free(B_d);
355   arch::free(C_d);
356   delete[] A_h;
357   delete[] B_h;
358   delete[] C_h;
359 }
360 
361 template<typename T,
362          typename Q1,
363          typename Q2,
364          typename T2,
365          typename = typename std::enable_if_t<std::is_same<typename std::decay<Q1>::type, T2>::value>,
366          typename = typename std::enable_if_t<std::is_same<typename std::decay<Q2>::type, T>::value>,
367          typename = typename std::enable_if_t<std::is_same<std::complex<T>, T2>::value>>
gemmBatched(char Atrans,char Btrans,int M,int N,int K,T const alpha,device_pointer<Q1> * A,int lda,device_pointer<Q2> * B,int ldb,T const beta,device_pointer<T2> * C,int ldc,int batchSize)368 inline static void gemmBatched(char Atrans,
369                                char Btrans,
370                                int M,
371                                int N,
372                                int K,
373                                T const alpha,
374                                device_pointer<Q1>* A,
375                                int lda,
376                                device_pointer<Q2>* B,
377                                int ldb,
378                                T const beta,
379                                device_pointer<T2>* C,
380                                int ldc,
381                                int batchSize)
382 {
383   // check that remove_complex<T2> == T ???
384   static_assert(std::is_same<typename std::decay<Q1>::type, T2>::value, "Wrong dispatch.\n");
385   static_assert(std::is_same<typename std::decay<Q2>::type, T>::value, "Wrong dispatch.\n");
386   assert(Atrans == 'N' || Atrans == 'n');
387   // replace with single call to arch::malloc and arch::memcopy
388   T2** A_d;
389   T** B_d;
390   T2** C_d;
391   Q1** A_h;
392   Q2** B_h;
393   T2** C_h;
394   A_h = new Q1*[batchSize];
395   B_h = new Q2*[batchSize];
396   C_h = new T2*[batchSize];
397   for (int i = 0; i < batchSize; i++)
398   {
399     A_h[i] = to_address(A[i]);
400     B_h[i] = to_address(B[i]);
401     C_h[i] = to_address(C[i]);
402   }
403   arch::malloc((void**)&A_d, batchSize * sizeof(*A_h));
404   arch::malloc((void**)&B_d, batchSize * sizeof(*B_h));
405   arch::malloc((void**)&C_d, batchSize * sizeof(*C_h));
406   arch::memcopy(A_d, A_h, batchSize * sizeof(*A_h), arch::memcopyH2D);
407   arch::memcopy(B_d, B_h, batchSize * sizeof(*B_h), arch::memcopyH2D);
408   arch::memcopy(C_d, C_h, batchSize * sizeof(*C_h), arch::memcopyH2D);
409   cublas::cublas_gemmBatched(*(A[0]).handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, A_d, lda, B_d, ldb, beta,
410                              C_d, ldc, batchSize);
411   arch::free(A_d);
412   arch::free(B_d);
413   arch::free(C_d);
414   delete[] A_h;
415   delete[] B_h;
416   delete[] C_h;
417 }
418 
419 template<typename T1, typename T2, typename T3>
axpyBatched(int n,T1 * x,device_pointer<T2> * a,int inca,device_pointer<T3> * b,int incb,int batchSize)420 inline static void axpyBatched(int n,
421                                T1* x,
422                                device_pointer<T2>* a,
423                                int inca,
424                                device_pointer<T3>* b,
425                                int incb,
426                                int batchSize)
427 {
428   T2 const** a_ = new T2 const*[batchSize];
429   T3** b_       = new T3*[batchSize];
430   for (int i = 0; i < batchSize; i++)
431   {
432     a_[i] = to_address(a[i]);
433     b_[i] = to_address(b[i]);
434   }
435   kernels::axpy_batched_gpu(n, x, a_, inca, b_, incb, batchSize);
436   delete[] a_;
437   delete[] b_;
438 }
439 
440 template<typename T1, typename T2, typename T3>
sumGwBatched(int n,T1 * x,device_pointer<T2> * a,int inca,device_pointer<T3> * b,int incb,int b0,int nw,int batchSize)441 inline static void sumGwBatched(int n,
442                                 T1* x,
443                                 device_pointer<T2>* a,
444                                 int inca,
445                                 device_pointer<T3>* b,
446                                 int incb,
447                                 int b0,
448                                 int nw,
449                                 int batchSize)
450 {
451   T2 const** a_ = new T2 const*[batchSize];
452   T3** b_       = new T3*[batchSize];
453   for (int i = 0; i < batchSize; i++)
454   {
455     a_[i] = to_address(a[i]);
456     b_[i] = to_address(b[i]);
457   }
458   kernels::sumGw_batched_gpu(n, x, a_, inca, b_, incb, b0, nw, batchSize);
459   delete[] a_;
460   delete[] b_;
461 }
462 
463 template<typename T, typename T2>
copy2D(int N,int M,device_pointer<T> src,int lda,device_pointer<T2> dst,int ldb)464 inline static void copy2D(int N, int M, device_pointer<T> src, int lda, device_pointer<T2> dst, int ldb)
465 {
466   static_assert(std::is_same<typename std::decay<T>::type, T2>::value, "Wrong dispatch.\n");
467   arch::memcopy2D(to_address(dst), sizeof(T2) * ldb, to_address(src), sizeof(T) * lda, M * sizeof(T), N,
468                   arch::memcopyD2D, "blas_cuda_gpu_ptr::copy2D");
469 }
470 
471 template<typename T, typename T2>
copy2D(int N,int M,T const * src,int lda,device_pointer<T2> dst,int ldb)472 inline static void copy2D(int N, int M, T const* src, int lda, device_pointer<T2> dst, int ldb)
473 {
474   static_assert(std::is_same<typename std::decay<T>::type, T2>::value, "Wrong dispatch.\n");
475   arch::memcopy2D(to_address(dst), sizeof(T2) * ldb, src, sizeof(T) * lda, M * sizeof(T), N, arch::memcopyH2D,
476                   "blas_cuda_gpu_ptr::copy2D");
477 }
478 
479 template<typename T, typename T2>
copy2D(int N,int M,device_pointer<T> src,int lda,T2 * dst,int ldb)480 inline static void copy2D(int N, int M, device_pointer<T> src, int lda, T2* dst, int ldb)
481 {
482   static_assert(std::is_same<typename std::decay<T>::type, T2>::value, "Wrong dispatch.\n");
483   arch::memcopy2D(dst, sizeof(T2) * ldb, to_address(src), sizeof(T) * lda, M * sizeof(T), N, arch::memcopyD2H,
484                   "blas_cuda_gpu_ptr::copy2D");
485 }
486 
487 template<typename T, typename T2>
get_diagonal_strided(int nk,int ni,device_pointer<T> A,int lda,int stride,device_pointer<T2> B,int ldb)488 inline static void get_diagonal_strided(int nk,
489                                         int ni,
490                                         device_pointer<T> A,
491                                         int lda,
492                                         int stride,
493                                         device_pointer<T2> B,
494                                         int ldb)
495 {
496   kernels::get_diagonal_strided(nk, ni, to_address(A), lda, stride, to_address(B), ldb);
497 }
498 } // namespace device
499 
500 #endif
501