1 //////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License. See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by:
8 // Lawrence Livermore National Laboratory
9 //
10 // File created by:
11 // Miguel A. Morales, moralessilva2@llnl.gov
12 // Lawrence Livermore National Laboratory
13 ////////////////////////////////////////////////////////////////////////////////
14
15 #ifndef AFQMC_BLAS_CUDA_GPU_PTR_HPP
16 #define AFQMC_BLAS_CUDA_GPU_PTR_HPP
17
18 #include <type_traits>
19 #include <cassert>
20 #include <vector>
21 //#include "AFQMC/Memory/CUDA/cuda_gpu_pointer.hpp"
22 #include "AFQMC/Utilities/type_conversion.hpp"
23 #include "AFQMC/Memory/device_pointers.hpp"
24 #include "AFQMC/Memory/arch.hpp"
25 #include "AFQMC/Numerics/detail/CUDA/cublas_wrapper.hpp"
26 //#include "AFQMC/Numerics/detail/CUDA/cublasXt_wrapper.hpp"
27 // hand coded kernels for blas extensions
28 #include "AFQMC/Numerics/detail/CUDA/Kernels/adotpby.cuh"
29 #include "AFQMC/Numerics/detail/CUDA/Kernels/setIdentity.cuh"
30 #include "AFQMC/Numerics/detail/CUDA/Kernels/axty.cuh"
31 #include "AFQMC/Numerics/detail/CUDA/Kernels/sum.cuh"
32 #include "AFQMC/Numerics/detail/CUDA/Kernels/adiagApy.cuh"
33 #include "AFQMC/Numerics/detail/CUDA/Kernels/acAxpbB.cuh"
34 #include "AFQMC/Numerics/detail/CUDA/Kernels/zero_complex_part.cuh"
35 #include "AFQMC/Numerics/detail/CUDA/Kernels/axpyBatched.cuh"
36 #include "AFQMC/Numerics/detail/CUDA/Kernels/get_diagonal.cuh"
37
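// Currently available:
// Lvl-1: copy, scal, dot, axpy
// Lvl-2: gemv
// Lvl-3: gemm, gemmStridedBatched, gemmBatched
// Extensions: geam, set1D, adotpby, strided_adotpby, axty, acAxpbB, adiagApy,
//             zero_complex_part, sum, set_identity, set_identity_strided,
//             axpyBatched, sumGwBatched, copy2D, get_diagonal_strided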
42
43 namespace device
44 {
45 // copy Specializations
46 template<typename T, typename Q>
copy(int n,device_pointer<Q> x,int incx,device_pointer<T> y,int incy)47 inline static void copy(int n, device_pointer<Q> x, int incx, device_pointer<T> y, int incy)
48 {
49 static_assert(std::is_same<typename std::decay<Q>::type, T>::value, "Wrong dispatch.\n");
50 if (CUBLAS_STATUS_SUCCESS !=
51 cublas::cublas_copy(*x.handles.cublas_handle, n, to_address(x), incx, to_address(y), incy))
52 throw std::runtime_error("Error: cublas_copy returned error code.");
53 }
54
55 template<typename T, typename Q>
copy(int n,T const * x,int incx,device_pointer<Q> y,int incy)56 inline static void copy(int n, T const* x, int incx, device_pointer<Q> y, int incy)
57 {
58 static_assert(std::is_same<typename std::decay<Q>::type, T>::value, "Wrong dispatch.\n");
59 arch::memcopy2D(to_address(y), sizeof(Q) * incy, x, sizeof(T) * incx, sizeof(T), n, arch::memcopyH2D,
60 "lapack_cuda_gpu_ptr::copy");
61 }
62
63 template<typename T, typename Q>
copy(int n,device_pointer<Q> x,int incx,T * y,int incy)64 inline static void copy(int n, device_pointer<Q> x, int incx, T* y, int incy)
65 {
66 static_assert(std::is_same<typename std::decay<Q>::type, T>::value, "Wrong dispatch.\n");
67 assert(sizeof(Q) == sizeof(T));
68 arch::memcopy2D(y, sizeof(T) * incy, to_address(x), sizeof(Q) * incx, sizeof(T), n, arch::memcopyD2H,
69 "lapack_cuda_gpu_ptr::copy");
70 }
71
72 // scal Specializations
73 template<typename T, typename Q>
scal(int n,Q alpha,device_pointer<T> x,int incx=1)74 inline static void scal(int n, Q alpha, device_pointer<T> x, int incx = 1)
75 {
76 static_assert(std::is_convertible<typename std::decay<Q>::type, T>::value, "Wrong dispatch.\n");
77 if (CUBLAS_STATUS_SUCCESS != cublas::cublas_scal(*x.handles.cublas_handle, n, T(alpha), to_address(x), incx))
78 throw std::runtime_error("Error: cublas_scal returned error code.");
79 }
80
81 // dot Specializations
82 template<typename T, typename Q>
dot(int const n,device_pointer<Q> x,int const incx,device_pointer<T> y,int const incy)83 inline static auto dot(int const n, device_pointer<Q> x, int const incx, device_pointer<T> y, int const incy)
84 {
85 static_assert(std::is_same<typename std::decay<Q>::type, typename std::decay<T>::type>::value, "Wrong dispatch.\n");
86 return cublas::cublas_dot(*x.handles.cublas_handle, n, to_address(x), incx, to_address(y), incy);
87 }
88
89 // axpy Specializations
90 template<typename T, typename Q>
axpy(int n,T const a,device_pointer<Q> x,int incx,device_pointer<T> y,int incy)91 inline static void axpy(int n, T const a, device_pointer<Q> x, int incx, device_pointer<T> y, int incy)
92 {
93 static_assert(std::is_same<typename std::decay<Q>::type, T>::value, "Wrong dispatch.\n");
94 if (CUBLAS_STATUS_SUCCESS !=
95 cublas::cublas_axpy(*x.handles.cublas_handle, n, a, to_address(x), incx, to_address(y), incy))
96 throw std::runtime_error("Error: cublas_axpy returned error code.");
97 }
98
99 // GEMV Specializations
100 template<typename T, typename T2, typename Q1, typename Q2>
gemv(char Atrans,int M,int N,T2 alpha,device_pointer<Q1> A,int lda,device_pointer<Q2> x,int incx,T2 beta,device_pointer<T> y,int incy)101 inline static void gemv(char Atrans,
102 int M,
103 int N,
104 T2 alpha,
105 device_pointer<Q1> A,
106 int lda,
107 device_pointer<Q2> x,
108 int incx,
109 T2 beta,
110 device_pointer<T> y,
111 int incy)
112 {
113 static_assert(std::is_same<typename std::decay<Q1>::type, T2>::value, "Wrong dispatch.\n");
114 static_assert(std::is_same<typename std::decay<Q2>::type, T>::value, "Wrong dispatch.\n");
115 if (CUBLAS_STATUS_SUCCESS !=
116 cublas::cublas_gemv(*A.handles.cublas_handle, Atrans, M, N, alpha, to_address(A), lda, to_address(x), incx, beta,
117 to_address(y), incy))
118 throw std::runtime_error("Error: cublas_gemv returned error code.");
119 }
120
121 // GEMM Specializations
// Open question: this overload does not work when T is const-qualified; the reason is not yet understood.
123 template<typename T, typename T2, typename Q1, typename Q2>
gemm(char Atrans,char Btrans,int M,int N,int K,T2 alpha,device_pointer<Q1> A,int lda,device_pointer<Q2> B,int ldb,T2 beta,device_pointer<T> C,int ldc)124 inline static void gemm(char Atrans,
125 char Btrans,
126 int M,
127 int N,
128 int K,
129 T2 alpha,
130 device_pointer<Q1> A,
131 int lda,
132 device_pointer<Q2> B,
133 int ldb,
134 T2 beta,
135 device_pointer<T> C,
136 int ldc)
137 {
138 static_assert(std::is_same<typename std::decay<Q1>::type, T>::value, "Wrong dispatch.\n");
139 static_assert(std::is_same<typename std::decay<Q2>::type, T2>::value, "Wrong dispatch.\n");
140 if (CUBLAS_STATUS_SUCCESS !=
141 cublas::cublas_gemm(*A.handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, to_address(A), lda, to_address(B),
142 ldb, beta, to_address(C), ldc))
143 throw std::runtime_error("Error: cublas_gemm returned error code.");
144 }
145
146 // Blas Extensions
147 // geam
148 template<typename T, typename Q1, typename Q2>
geam(char Atrans,char Btrans,int M,int N,T const alpha,device_pointer<Q1> A,int lda,T const beta,device_pointer<Q2> B,int ldb,device_pointer<T> C,int ldc)149 inline static void geam(char Atrans,
150 char Btrans,
151 int M,
152 int N,
153 T const alpha,
154 device_pointer<Q1> A,
155 int lda,
156 T const beta,
157 device_pointer<Q2> B,
158 int ldb,
159 device_pointer<T> C,
160 int ldc)
161 {
162 static_assert(std::is_same<typename std::decay<Q1>::type, T>::value, "Wrong dispatch.\n");
163 static_assert(std::is_same<typename std::decay<Q2>::type, T>::value, "Wrong dispatch.\n");
164 if (CUBLAS_STATUS_SUCCESS !=
165 cublas::cublas_geam(*A.handles.cublas_handle, Atrans, Btrans, M, N, alpha, to_address(A), lda, beta,
166 to_address(B), ldb, to_address(C), ldc))
167 throw std::runtime_error("Error: cublas_geam returned error code.");
168 }
169
170 template<typename T>
171 //inline static void set1D(int n, T const alpha, ptr x, int incx)
set1D(int n,T const alpha,device_pointer<T> x,int incx)172 inline static void set1D(int n, T const alpha, device_pointer<T> x, int incx)
173 {
174 // No set funcion in cuda!!! Avoiding kernels for now
175 //std::vector<T> buff(n,alpha);
176 //if(CUBLAS_STATUS_SUCCESS != cublasSetVector(n,sizeof(T),buff.data(),1,to_address(x),incx))
177 T alpha_(alpha);
178 if (CUBLAS_STATUS_SUCCESS != cublasSetVector(n, sizeof(T), std::addressof(alpha), 1, to_address(x), incx))
179 throw std::runtime_error("Error: cublasSetVector returned error code.");
180 }
181
182 // dot extension
183 template<typename T, typename T1, typename T2, typename Q1, typename Q2>
adotpby(int const n,T1 const alpha,device_pointer<Q1> x,int const incx,device_pointer<Q2> y,int const incy,T2 const beta,T * result)184 inline static void adotpby(int const n,
185 T1 const alpha,
186 device_pointer<Q1> x,
187 int const incx,
188 device_pointer<Q2> y,
189 int const incy,
190 T2 const beta,
191 T* result)
192 {
193 static_assert(std::is_same<typename std::decay<Q1>::type, T1>::value, "Wrong dispatch.\n");
194 static_assert(std::is_same<typename std::decay<Q2>::type, T1>::value, "Wrong dispatch.\n");
195 static_assert(std::is_same<typename std::decay<T2>::type, T>::value, "Wrong dispatch.\n");
196 kernels::adotpby(n, alpha, to_address(x), incx, to_address(y), incy, beta, result);
197 }
198
199 // dot extension
200 template<typename T, typename T1, typename T2, typename Q1, typename Q2>
strided_adotpby(int nk,int const n,T1 const alpha,device_pointer<Q1> A,int const lda,device_pointer<Q2> B,int const ldb,T2 const beta,T * y,int inc)201 inline static void strided_adotpby(int nk,
202 int const n,
203 T1 const alpha,
204 device_pointer<Q1> A,
205 int const lda,
206 device_pointer<Q2> B,
207 int const ldb,
208 T2 const beta,
209 T* y,
210 int inc)
211 {
212 static_assert(std::is_same<typename std::decay<Q1>::type, T1>::value, "Wrong dispatch.\n");
213 static_assert(std::is_same<typename std::decay<Q2>::type, T1>::value, "Wrong dispatch.\n");
214 static_assert(std::is_same<typename std::decay<T2>::type, T>::value, "Wrong dispatch.\n");
215 kernels::strided_adotpby(nk, n, alpha, to_address(A), lda, to_address(B), ldb, beta, y, inc);
216 }
217
218 // axty
219 template<typename T, typename Q>
axty(int n,T const alpha,device_pointer<Q> x,int incx,device_pointer<T> y,int incy)220 inline static void axty(int n, T const alpha, device_pointer<Q> x, int incx, device_pointer<T> y, int incy)
221 {
222 static_assert(std::is_same<typename std::decay<Q>::type, T>::value, "Wrong dispatch.\n");
223 if (incx != 1 || incy != 1)
224 throw std::runtime_error("Error: axty with inc != 1 not implemented.");
225 kernels::axty(n, alpha, to_address(x), to_address(y));
226 }
227
228 // acAxpbB
229 template<typename T, typename Q1, typename Q2>
acAxpbB(int m,int n,T const alpha,device_pointer<Q1> A,int lda,device_pointer<Q2> x,int incx,T const beta,device_pointer<T> B,int ldb)230 inline static void acAxpbB(int m,
231 int n,
232 T const alpha,
233 device_pointer<Q1> A,
234 int lda,
235 device_pointer<Q2> x,
236 int incx,
237 T const beta,
238 device_pointer<T> B,
239 int ldb)
240 {
241 static_assert(std::is_same<typename std::decay<Q1>::type, T>::value, "Wrong dispatch.\n");
242 static_assert(std::is_same<typename std::decay<Q2>::type, T>::value, "Wrong dispatch.\n");
243 kernels::acAxpbB(m, n, alpha, to_address(A), lda, to_address(x), incx, beta, to_address(B), ldb);
244 }
245
246 // adiagApy
247 template<typename T, typename Q1>
adiagApy(int n,T const alpha,device_pointer<Q1> A,int lda,device_pointer<T> y,int incy)248 inline static void adiagApy(int n, T const alpha, device_pointer<Q1> A, int lda, device_pointer<T> y, int incy)
249 {
250 static_assert(std::is_same<typename std::decay<Q1>::type, T>::value, "Wrong dispatch.\n");
251 kernels::adiagApy(n, alpha, to_address(A), lda, to_address(y), incy);
252 }
253
254 template<typename T>
zero_complex_part(int n,device_pointer<T> x)255 inline static void zero_complex_part(int n, device_pointer<T> x)
256 {
257 kernels::zero_complex_part(n, to_address(x));
258 }
259
260 template<typename T>
sum(int n,device_pointer<T> x,int incx)261 inline static auto sum(int n, device_pointer<T> x, int incx)
262 {
263 return kernels::sum(n, to_address(x), incx);
264 }
265
266 template<typename T>
sum(int m,int n,device_pointer<T> A,int lda)267 inline static auto sum(int m, int n, device_pointer<T> A, int lda)
268 {
269 return kernels::sum(m, n, to_address(A), lda);
270 }
271
272 template<typename T>
set_identity(int m,int n,device_pointer<T> A,int lda)273 void set_identity(int m, int n, device_pointer<T> A, int lda)
274 {
275 kernels::set_identity(m, n, to_address(A), lda);
276 }
277
278 template<typename T>
set_identity_strided(int nbatch,int stride,int m,int n,device_pointer<T> A,int lda)279 void set_identity_strided(int nbatch, int stride, int m, int n, device_pointer<T> A, int lda)
280 {
281 kernels::set_identity_strided(nbatch, stride, m, n, to_address(A), lda);
282 }
283
284 template<typename T, typename Q1, typename Q2>
gemmStridedBatched(char Atrans,char Btrans,int M,int N,int K,T const alpha,device_pointer<Q1> A,int lda,int strideA,device_pointer<Q2> B,int ldb,int strideB,T beta,device_pointer<T> C,int ldc,int strideC,int batchSize)285 inline static void gemmStridedBatched(char Atrans,
286 char Btrans,
287 int M,
288 int N,
289 int K,
290 T const alpha,
291 device_pointer<Q1> A,
292 int lda,
293 int strideA,
294 device_pointer<Q2> B,
295 int ldb,
296 int strideB,
297 T beta,
298 device_pointer<T> C,
299 int ldc,
300 int strideC,
301 int batchSize)
302 {
303 static_assert(std::is_same<typename std::decay<Q1>::type, T>::value, "Wrong dispatch.\n");
304 static_assert(std::is_same<typename std::decay<Q2>::type, T>::value, "Wrong dispatch.\n");
305 cublas::cublas_gemmStridedBatched(*A.handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, to_address(A), lda,
306 strideA, to_address(B), ldb, strideB, beta, to_address(C), ldc, strideC, batchSize);
307 }
308
309 template<typename T,
310 typename Q1,
311 typename Q2,
312 typename = typename std::enable_if_t<std::is_same<typename std::decay<Q1>::type, T>::value>,
313 typename = typename std::enable_if_t<std::is_same<typename std::decay<Q2>::type, T>::value>>
gemmBatched(char Atrans,char Btrans,int M,int N,int K,T const alpha,device_pointer<Q1> * A,int lda,device_pointer<Q2> * B,int ldb,T const beta,device_pointer<T> * C,int ldc,int batchSize)314 inline static void gemmBatched(char Atrans,
315 char Btrans,
316 int M,
317 int N,
318 int K,
319 T const alpha,
320 device_pointer<Q1>* A,
321 int lda,
322 device_pointer<Q2>* B,
323 int ldb,
324 T const beta,
325 device_pointer<T>* C,
326 int ldc,
327 int batchSize)
328 {
329 static_assert(std::is_same<typename std::decay<Q1>::type, T>::value, "Wrong dispatch.\n");
330 static_assert(std::is_same<typename std::decay<Q2>::type, T>::value, "Wrong dispatch.\n");
331 // replace with single call to arch::malloc and arch::memcopy
332 T **A_d, **B_d, **C_d;
333 Q1** A_h;
334 Q2** B_h;
335 T** C_h;
336 A_h = new Q1*[batchSize];
337 B_h = new Q2*[batchSize];
338 C_h = new T*[batchSize];
339 for (int i = 0; i < batchSize; i++)
340 {
341 A_h[i] = to_address(A[i]);
342 B_h[i] = to_address(B[i]);
343 C_h[i] = to_address(C[i]);
344 }
345 arch::malloc((void**)&A_d, batchSize * sizeof(*A_h));
346 arch::malloc((void**)&B_d, batchSize * sizeof(*B_h));
347 arch::malloc((void**)&C_d, batchSize * sizeof(*C_h));
348 arch::memcopy(A_d, A_h, batchSize * sizeof(*A_h), arch::memcopyH2D);
349 arch::memcopy(B_d, B_h, batchSize * sizeof(*B_h), arch::memcopyH2D);
350 arch::memcopy(C_d, C_h, batchSize * sizeof(*C_h), arch::memcopyH2D);
351 cublas::cublas_gemmBatched(*(A[0]).handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, A_d, lda, B_d, ldb, beta,
352 C_d, ldc, batchSize);
353 arch::free(A_d);
354 arch::free(B_d);
355 arch::free(C_d);
356 delete[] A_h;
357 delete[] B_h;
358 delete[] C_h;
359 }
360
361 template<typename T,
362 typename Q1,
363 typename Q2,
364 typename T2,
365 typename = typename std::enable_if_t<std::is_same<typename std::decay<Q1>::type, T2>::value>,
366 typename = typename std::enable_if_t<std::is_same<typename std::decay<Q2>::type, T>::value>,
367 typename = typename std::enable_if_t<std::is_same<std::complex<T>, T2>::value>>
gemmBatched(char Atrans,char Btrans,int M,int N,int K,T const alpha,device_pointer<Q1> * A,int lda,device_pointer<Q2> * B,int ldb,T const beta,device_pointer<T2> * C,int ldc,int batchSize)368 inline static void gemmBatched(char Atrans,
369 char Btrans,
370 int M,
371 int N,
372 int K,
373 T const alpha,
374 device_pointer<Q1>* A,
375 int lda,
376 device_pointer<Q2>* B,
377 int ldb,
378 T const beta,
379 device_pointer<T2>* C,
380 int ldc,
381 int batchSize)
382 {
383 // check that remove_complex<T2> == T ???
384 static_assert(std::is_same<typename std::decay<Q1>::type, T2>::value, "Wrong dispatch.\n");
385 static_assert(std::is_same<typename std::decay<Q2>::type, T>::value, "Wrong dispatch.\n");
386 assert(Atrans == 'N' || Atrans == 'n');
387 // replace with single call to arch::malloc and arch::memcopy
388 T2** A_d;
389 T** B_d;
390 T2** C_d;
391 Q1** A_h;
392 Q2** B_h;
393 T2** C_h;
394 A_h = new Q1*[batchSize];
395 B_h = new Q2*[batchSize];
396 C_h = new T2*[batchSize];
397 for (int i = 0; i < batchSize; i++)
398 {
399 A_h[i] = to_address(A[i]);
400 B_h[i] = to_address(B[i]);
401 C_h[i] = to_address(C[i]);
402 }
403 arch::malloc((void**)&A_d, batchSize * sizeof(*A_h));
404 arch::malloc((void**)&B_d, batchSize * sizeof(*B_h));
405 arch::malloc((void**)&C_d, batchSize * sizeof(*C_h));
406 arch::memcopy(A_d, A_h, batchSize * sizeof(*A_h), arch::memcopyH2D);
407 arch::memcopy(B_d, B_h, batchSize * sizeof(*B_h), arch::memcopyH2D);
408 arch::memcopy(C_d, C_h, batchSize * sizeof(*C_h), arch::memcopyH2D);
409 cublas::cublas_gemmBatched(*(A[0]).handles.cublas_handle, Atrans, Btrans, M, N, K, alpha, A_d, lda, B_d, ldb, beta,
410 C_d, ldc, batchSize);
411 arch::free(A_d);
412 arch::free(B_d);
413 arch::free(C_d);
414 delete[] A_h;
415 delete[] B_h;
416 delete[] C_h;
417 }
418
419 template<typename T1, typename T2, typename T3>
axpyBatched(int n,T1 * x,device_pointer<T2> * a,int inca,device_pointer<T3> * b,int incb,int batchSize)420 inline static void axpyBatched(int n,
421 T1* x,
422 device_pointer<T2>* a,
423 int inca,
424 device_pointer<T3>* b,
425 int incb,
426 int batchSize)
427 {
428 T2 const** a_ = new T2 const*[batchSize];
429 T3** b_ = new T3*[batchSize];
430 for (int i = 0; i < batchSize; i++)
431 {
432 a_[i] = to_address(a[i]);
433 b_[i] = to_address(b[i]);
434 }
435 kernels::axpy_batched_gpu(n, x, a_, inca, b_, incb, batchSize);
436 delete[] a_;
437 delete[] b_;
438 }
439
440 template<typename T1, typename T2, typename T3>
sumGwBatched(int n,T1 * x,device_pointer<T2> * a,int inca,device_pointer<T3> * b,int incb,int b0,int nw,int batchSize)441 inline static void sumGwBatched(int n,
442 T1* x,
443 device_pointer<T2>* a,
444 int inca,
445 device_pointer<T3>* b,
446 int incb,
447 int b0,
448 int nw,
449 int batchSize)
450 {
451 T2 const** a_ = new T2 const*[batchSize];
452 T3** b_ = new T3*[batchSize];
453 for (int i = 0; i < batchSize; i++)
454 {
455 a_[i] = to_address(a[i]);
456 b_[i] = to_address(b[i]);
457 }
458 kernels::sumGw_batched_gpu(n, x, a_, inca, b_, incb, b0, nw, batchSize);
459 delete[] a_;
460 delete[] b_;
461 }
462
463 template<typename T, typename T2>
copy2D(int N,int M,device_pointer<T> src,int lda,device_pointer<T2> dst,int ldb)464 inline static void copy2D(int N, int M, device_pointer<T> src, int lda, device_pointer<T2> dst, int ldb)
465 {
466 static_assert(std::is_same<typename std::decay<T>::type, T2>::value, "Wrong dispatch.\n");
467 arch::memcopy2D(to_address(dst), sizeof(T2) * ldb, to_address(src), sizeof(T) * lda, M * sizeof(T), N,
468 arch::memcopyD2D, "blas_cuda_gpu_ptr::copy2D");
469 }
470
471 template<typename T, typename T2>
copy2D(int N,int M,T const * src,int lda,device_pointer<T2> dst,int ldb)472 inline static void copy2D(int N, int M, T const* src, int lda, device_pointer<T2> dst, int ldb)
473 {
474 static_assert(std::is_same<typename std::decay<T>::type, T2>::value, "Wrong dispatch.\n");
475 arch::memcopy2D(to_address(dst), sizeof(T2) * ldb, src, sizeof(T) * lda, M * sizeof(T), N, arch::memcopyH2D,
476 "blas_cuda_gpu_ptr::copy2D");
477 }
478
479 template<typename T, typename T2>
copy2D(int N,int M,device_pointer<T> src,int lda,T2 * dst,int ldb)480 inline static void copy2D(int N, int M, device_pointer<T> src, int lda, T2* dst, int ldb)
481 {
482 static_assert(std::is_same<typename std::decay<T>::type, T2>::value, "Wrong dispatch.\n");
483 arch::memcopy2D(dst, sizeof(T2) * ldb, to_address(src), sizeof(T) * lda, M * sizeof(T), N, arch::memcopyD2H,
484 "blas_cuda_gpu_ptr::copy2D");
485 }
486
487 template<typename T, typename T2>
get_diagonal_strided(int nk,int ni,device_pointer<T> A,int lda,int stride,device_pointer<T2> B,int ldb)488 inline static void get_diagonal_strided(int nk,
489 int ni,
490 device_pointer<T> A,
491 int lda,
492 int stride,
493 device_pointer<T2> B,
494 int ldb)
495 {
496 kernels::get_diagonal_strided(nk, ni, to_address(A), lda, stride, to_address(B), ldb);
497 }
498 } // namespace device
499
500 #endif
501