1 //////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License. See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by:
8 // Lawrence Livermore National Laboratory
9 //
10 // File created by:
11 // Miguel A. Morales, moralessilva2@llnl.gov
12 // Lawrence Livermore National Laboratory
13 ////////////////////////////////////////////////////////////////////////////////
14
#include <cassert>
#include <complex>
#include <stdexcept>
#include <cuda.h>
#include <thrust/complex.h>
#include <cuda_runtime.h>
#include <thrust/system/cuda/detail/core/util.h>
#include "AFQMC/Numerics/detail/CUDA/Kernels/cuda_settings.h"
#define ENABLE_CUDA 1
#include "AFQMC/Memory/CUDA/cuda_utilities.h"
24
25 namespace kernels
26 {
// C[u][w] = alpha * sum_a A[u][w][a] * B[u][a]
// Launch layout: u spans blocks along x (blockDim.x threads each), w = threadIdx.y.
template<typename T>
__global__ void kernel_Auwn_Bun_Cuw(int nu,
                                    int nw,
                                    int na,
                                    thrust::complex<T> const alpha,
                                    thrust::complex<T> const* A,
                                    thrust::complex<T> const* B,
                                    thrust::complex<T>* C)
{
  int const u = blockIdx.x * blockDim.x + threadIdx.x;
  int const w = threadIdx.y;

  if (u >= nu || w >= nw)
    return;

  // Dot product over the fastest index a.
  thrust::complex<T> const* Au = A + (u * nw + w) * na;
  thrust::complex<T> const* Bu = B + u * na;
  thrust::complex<T> acc(0);
  for (int a = 0; a < na; ++a)
    acc += Au[a] * Bu[a];
  C[u * nw + w] = alpha * acc;
}
51
// C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]  (overload with real-valued B)
// Launch layout: u spans blocks along x (blockDim.x threads each), w = threadIdx.y.
// NOT OPTIMAL: poor memory access pattern
template<typename T>
__global__ void kernel_Awiu_Biu_Cuw(int nu,
                                    int nw,
                                    int ni,
                                    thrust::complex<T> const alpha,
                                    thrust::complex<T> const* A,
                                    T const* B,
                                    int ldb,
                                    thrust::complex<T>* C,
                                    int ldc)
{
  int nu_per_block = blockDim.x;
  int u = blockIdx.x * nu_per_block + threadIdx.x;
  int w = threadIdx.y;

  if ((u < nu) && (w < nw))
  {
    thrust::complex<T> Cuw = 0;
    // A[w][i][u]: consecutive i are nu apart; B[i][u]: consecutive i are ldb apart.
    thrust::complex<T> const* A_(A + w * ni * nu + u);
    T const* B_(B + u);
    for (int i = 0; i < ni; ++i, A_ += nu, B_ += ldb)
      Cuw += (*A_) * (*B_);
    C[u * ldc + w] = alpha * Cuw;
  }
}
78
// C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]
// Launch layout: u spans blocks along x (blockDim.x threads each), w = threadIdx.y.
template<typename T>
__global__ void kernel_Awiu_Biu_Cuw(int nu,
                                    int nw,
                                    int ni,
                                    thrust::complex<T> const alpha,
                                    thrust::complex<T> const* A,
                                    thrust::complex<T> const* B,
                                    int ldb,
                                    thrust::complex<T>* C,
                                    int ldc)
{
  int const u = blockIdx.x * blockDim.x + threadIdx.x;
  int const w = threadIdx.y;

  if (u >= nu || w >= nw)
    return;

  // A[w][i][u] at A[(w*ni + i)*nu + u]; B[i][u] at B[i*ldb + u].
  thrust::complex<T> acc(0);
  for (int i = 0; i < ni; ++i)
    acc += A[(w * ni + i) * nu + u] * B[i * ldb + u];
  C[u * ldc + w] = alpha * acc;
}
105
// Cik = sum_j Aijk * Bkj ( nthreads per block 32 )
// Launch layout: grid = (nk, ni), blockDim.x == 32 (matches the cache size).
// Each block computes the j-sum for one (i,k) pair and ACCUMULATES it into
// C[i][k] (note the +=, C must be pre-initialized by the caller).
template<typename T, typename T1>
__global__ void kernel_Aijk_Bkj_Cik(int ni,
                                    int nj,
                                    int nk,
                                    thrust::complex<T> const* A,  // A[i][j][k]: element at A[i*stride + j*lda + k]
                                    int lda,
                                    int stride,
                                    T1 const* B,                  // B[k][j]: element at B[k*ldb + j]
                                    int ldb,
                                    thrust::complex<T>* C,        // C[i][k]: element at C[i*ldc + k]
                                    int ldc)
{
  // Per-block scratch for 32 partial sums; uninitialized_array avoids running
  // a constructor for thrust::complex in shared memory.
  __shared__ thrust::cuda_cub::core::uninitialized_array<thrust::complex<T>, 32> cache;
  int k = blockIdx.x;
  int i = blockIdx.y;
  if ((i < ni) && (k < nk))
  {
    // NOTE(review): this guard depends only on blockIdx, so it is uniform over
    // the block and the __syncthreads() calls below are reached by all threads.
    cache[threadIdx.x] = thrust::complex<T>(0.0);
    // Thread t accumulates j = t, t+32, t+64, ...
    auto A_(A + i * stride + k + threadIdx.x * lda);
    auto B_(B + k * ldb + threadIdx.x);
    auto Bn_(B + k * ldb + nj); // one past the last j for this k
    while (B_ < Bn_)
    {
      cache[threadIdx.x] += (*A_) * static_cast<thrust::complex<T>>(*B_);
      A_ += blockDim.x * lda;
      B_ += blockDim.x;
    }

    __syncthreads();
    // Binary tree reduction of the 32 partials down into cache[0].
    int j = 16;
    while (j > 0)
    {
      if (threadIdx.x < j)
        cache[threadIdx.x] += cache[threadIdx.x + j];
      __syncthreads();
      j /= 2; //not sure bitwise operations are actually faster
    }
    if (threadIdx.x == 0)
      *(C + i * ldc + k) += cache[0];
  }
}
148
// A[w][i][j] = B[i][w][j]  (transposes the first two indices; may narrow/widen
// precision via the T -> T1 cast)
// Launch layout: grid = (nw, iN - i0); threads cooperatively copy rows of length ni.
template<typename T, typename T1>
__global__ void kernel_viwj_vwij(int nw, int ni, int i0, int iN, thrust::complex<T> const* B, thrust::complex<T1>* A)
{
  int const w = blockIdx.x;
  int const i = blockIdx.y + i0;
  if (w >= nw || i >= ni)
    return;
  auto dst = A + (w * ni + i) * ni;
  auto src = B + (i * nw + w) * ni;
  for (int j = threadIdx.x; j < ni; j += blockDim.x)
    dst[j] = static_cast<thrust::complex<T1>>(src[j]);
}
167
// element-wise C[k][i][j] = A[i][j] * B[j][k]
// Real-valued A overload; transA is accepted for signature parity with the
// complex overload but is not used (conjugation is a no-op on real A).
// Launch layout: grid = (ni, ceil(nj/blockDim.x), nk).
template<typename T>
__global__ void kernel_element_wise_Aij_Bjk_Ckij(char transA,
                                                 int ni,
                                                 int nj,
                                                 int nk,
                                                 T const* A,
                                                 int lda,
                                                 thrust::complex<T> const* B,
                                                 int ldb,
                                                 thrust::complex<T>* C,
                                                 int ldc1,
                                                 int ldc2)
{
  int const i = blockIdx.x;
  int const j = blockIdx.y * blockDim.x + threadIdx.x;
  int const k = blockIdx.z;

  if (i >= ni || j >= nj || k >= nk)
    return;
  C[(k * ldc1 + i) * ldc2 + j] = A[i * lda + j] * B[j * ldb + k];
}
189
// element-wise C[k][i][j] = op(A[i][j]) * B[j][k], where op is identity for
// transA == 'N' and complex conjugation for transA == 'C'.
// Launch layout: grid = (ni, ceil(nj/blockDim.x), nk).
template<typename T>
__global__ void kernel_element_wise_Aij_Bjk_Ckij(char transA,
                                                 int ni,
                                                 int nj,
                                                 int nk,
                                                 thrust::complex<T> const* A,
                                                 int lda,
                                                 thrust::complex<T> const* B,
                                                 int ldb,
                                                 thrust::complex<T>* C,
                                                 int ldc1,
                                                 int ldc2)
{
  int const i = blockIdx.x;
  int const j = blockIdx.y * blockDim.x + threadIdx.x;
  int const k = blockIdx.z;

  if (i >= ni || j >= nj || k >= nk)
    return;

  thrust::complex<T> const b = B[j * ldb + k];
  if (transA == 'N')
    C[(k * ldc1 + i) * ldc2 + j] = A[i * lda + j] * b;
  else if (transA == 'C')
    C[(k * ldc1 + i) * ldc2 + j] = conj(A[i * lda + j]) * b;
  // any other transA: C is intentionally left untouched (matches original behavior)
}
215
// Ckji = Aij * Bjk
// Element-wise C[k][j][i] = A[i][j] * B[j][k], transposing (i,j) through a
// 32x32 shared-memory tile so reads of A and writes of C each stay contiguous
// in their fast index.
// Launch layout: grid = (ceil(nj/32), ceil(ni/32), nk), block = (32, by, 1)
// with by <= 32 (host wrappers use by = 8).
template<typename T, typename T2>
__global__ void kernel_element_wise_Aij_Bjk_Ckji(int ni,
                                                 int nj,
                                                 int nk,
                                                 T2 const* A,                  // A[i][j] at A[i*lda + j]
                                                 int lda,
                                                 thrust::complex<T> const* B,  // B[j][k] at B[j*ldb + k]
                                                 int ldb,
                                                 thrust::complex<T>* C,        // C[k][j][i] at C[k*stride + j*ldc + i]
                                                 int ldc,
                                                 int stride)
{
  // hard-coded to TILE_DIM=32
  int TILE_DIM = 32;
  __shared__ thrust::cuda_cub::core::uninitialized_array<T2, 32 * 32> Acache;
  __shared__ thrust::cuda_cub::core::uninitialized_array<thrust::complex<T>, 32> Bcache;

  int k = blockIdx.z;
  int j = blockIdx.x * TILE_DIM + threadIdx.x;
  int i = blockIdx.y * TILE_DIM + threadIdx.y;

  // Phase 1: stage the tile. Acache row n holds A[i_base+n][j] for this
  // block's tile; groups of blockDim.y threads sweep down the tile rows.
  // Thread row 0 also stages the B factor for each tile column j.
  if ((k < nk) && (j < nj))
  {
    int n(threadIdx.y);
    while ((i < ni) && (n < TILE_DIM))
    {
      Acache[n * 32 + threadIdx.x] = A[i * lda + j];
      n += blockDim.y;
      i += blockDim.y;
    }
    if (threadIdx.y == 0)
      Bcache[threadIdx.x] = B[j * ldb + k];
  }

  // Barrier is at block scope (outside the divergent guard): safe for all threads.
  __syncthreads();

  // subtle interchange of threadIdx.x/threadIdx.y
  // Phase 2: threadIdx.x now indexes i (the fast output index) so writes to C
  // are contiguous; n tracks the tile column (j offset). The guards mirror
  // phase 1, so only cache entries written above are read here.
  i = blockIdx.y * TILE_DIM + threadIdx.x;
  j = blockIdx.x * TILE_DIM + threadIdx.y;

  if ((k < nk) && (i < ni))
  {
    int n(threadIdx.y);
    while ((j < nj) && (n < TILE_DIM))
    {
      C[k * stride + j * ldc + i] = static_cast<thrust::complex<T>>(Acache[threadIdx.x * 32 + n]) * Bcache[n];
      n += blockDim.y;
      j += blockDim.y;
    }
  }
}
268
269 // C[u][w] = alpha * sum_a A[u][w][a] * B[u][a]
// C[u][w] = alpha * sum_a A[u][w][a] * B[u][a]  (double precision)
void Auwn_Bun_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<double> alpha,
                  std::complex<double> const* A,
                  std::complex<double> const* B,
                  std::complex<double>* C)
{
  // w is mapped to threadIdx.y, so nw must fit in one block dimension.
  // BUG FIX: a bare `throw;` with no active exception calls std::terminate();
  // raise a catchable exception with a diagnostic instead.
  if (size_t(nw) > MAX_THREADS_PER_DIM)
    throw std::runtime_error("Auwn_Bun_Cuw: nw exceeds MAX_THREADS_PER_DIM");
  size_t nthr = std::max(size_t(1), MAX_THREADS_PER_DIM / size_t(nw)); // u-threads per block
  size_t nbks = (nu + nthr - 1) / nthr;                                // ceil(nu / nthr)
  dim3 grid_dim(nbks, 1, 1);
  dim3 block_dim(nthr, nw, 1);
  kernel_Auwn_Bun_Cuw<<<grid_dim, block_dim>>>(nu, nw, na, static_cast<thrust::complex<double> const>(alpha),
                                               reinterpret_cast<thrust::complex<double> const*>(A),
                                               reinterpret_cast<thrust::complex<double> const*>(B),
                                               reinterpret_cast<thrust::complex<double>*>(C));
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
291
292
// C[u][w] = alpha * sum_a A[u][w][a] * B[u][a]  (single precision)
void Auwn_Bun_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<float> alpha,
                  std::complex<float> const* A,
                  std::complex<float> const* B,
                  std::complex<float>* C)
{
  // w is mapped to threadIdx.y, so nw must fit in one block dimension.
  // BUG FIX: a bare `throw;` with no active exception calls std::terminate();
  // raise a catchable exception with a diagnostic instead.
  if (size_t(nw) > MAX_THREADS_PER_DIM)
    throw std::runtime_error("Auwn_Bun_Cuw: nw exceeds MAX_THREADS_PER_DIM");
  size_t nthr = std::max(size_t(1), MAX_THREADS_PER_DIM / size_t(nw)); // u-threads per block
  size_t nbks = (nu + nthr - 1) / nthr;                                // ceil(nu / nthr)
  dim3 grid_dim(nbks, 1, 1);
  dim3 block_dim(nthr, nw, 1);
  kernel_Auwn_Bun_Cuw<<<grid_dim, block_dim>>>(nu, nw, na, static_cast<thrust::complex<float> const>(alpha),
                                               reinterpret_cast<thrust::complex<float> const*>(A),
                                               reinterpret_cast<thrust::complex<float> const*>(B),
                                               reinterpret_cast<thrust::complex<float>*>(C));
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
314
315 // C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]
// C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]  (double A, real-valued B)
// Note: `na` here is the summed index count (the kernel's `ni`).
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<double> alpha,
                  std::complex<double> const* A,
                  double const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc)
{
  // w is mapped to threadIdx.y, so nw must fit in one block dimension.
  // BUG FIX: a bare `throw;` with no active exception calls std::terminate();
  // raise a catchable exception with a diagnostic instead.
  if (size_t(nw) > MAX_THREADS_PER_DIM)
    throw std::runtime_error("Awiu_Biu_Cuw: nw exceeds MAX_THREADS_PER_DIM");
  size_t nthr = std::max(size_t(1), MAX_THREADS_PER_DIM / size_t(nw)); // u-threads per block
  size_t nbks = (nu + nthr - 1) / nthr;                                // ceil(nu / nthr)
  dim3 grid_dim(nbks, 1, 1);
  dim3 block_dim(nthr, nw, 1);
  kernel_Awiu_Biu_Cuw<<<grid_dim, block_dim>>>(nu, nw, na, static_cast<thrust::complex<double> const>(alpha),
                                               reinterpret_cast<thrust::complex<double> const*>(A), B, ldb,
                                               reinterpret_cast<thrust::complex<double>*>(C), ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
338
339
// C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]  (single A, real-valued B)
// Note: `na` here is the summed index count (the kernel's `ni`).
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<float> alpha,
                  std::complex<float> const* A,
                  float const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc)
{
  // w is mapped to threadIdx.y, so nw must fit in one block dimension.
  // BUG FIX: a bare `throw;` with no active exception calls std::terminate();
  // raise a catchable exception with a diagnostic instead.
  if (size_t(nw) > MAX_THREADS_PER_DIM)
    throw std::runtime_error("Awiu_Biu_Cuw: nw exceeds MAX_THREADS_PER_DIM");
  size_t nthr = std::max(size_t(1), MAX_THREADS_PER_DIM / size_t(nw)); // u-threads per block
  size_t nbks = (nu + nthr - 1) / nthr;                                // ceil(nu / nthr)
  dim3 grid_dim(nbks, 1, 1);
  dim3 block_dim(nthr, nw, 1);
  kernel_Awiu_Biu_Cuw<<<grid_dim, block_dim>>>(nu, nw, na, static_cast<thrust::complex<float> const>(alpha),
                                               reinterpret_cast<thrust::complex<float> const*>(A), B, ldb,
                                               reinterpret_cast<thrust::complex<float>*>(C), ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
362
// C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]  (double A, complex B)
// Note: `na` here is the summed index count (the kernel's `ni`).
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<double> alpha,
                  std::complex<double> const* A,
                  std::complex<double> const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc)
{
  // w is mapped to threadIdx.y, so nw must fit in one block dimension.
  // BUG FIX: a bare `throw;` with no active exception calls std::terminate();
  // raise a catchable exception with a diagnostic instead.
  if (size_t(nw) > MAX_THREADS_PER_DIM)
    throw std::runtime_error("Awiu_Biu_Cuw: nw exceeds MAX_THREADS_PER_DIM");
  size_t nthr = std::max(size_t(1), MAX_THREADS_PER_DIM / size_t(nw)); // u-threads per block
  size_t nbks = (nu + nthr - 1) / nthr;                                // ceil(nu / nthr)
  dim3 grid_dim(nbks, 1, 1);
  dim3 block_dim(nthr, nw, 1);
  kernel_Awiu_Biu_Cuw<<<grid_dim, block_dim>>>(nu, nw, na, static_cast<thrust::complex<double> const>(alpha),
                                               reinterpret_cast<thrust::complex<double> const*>(A),
                                               reinterpret_cast<thrust::complex<double> const*>(B), ldb,
                                               reinterpret_cast<thrust::complex<double>*>(C), ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
386
387
// C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]  (single A, complex B)
// Note: `na` here is the summed index count (the kernel's `ni`).
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<float> alpha,
                  std::complex<float> const* A,
                  std::complex<float> const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc)
{
  // w is mapped to threadIdx.y, so nw must fit in one block dimension.
  // BUG FIX: a bare `throw;` with no active exception calls std::terminate();
  // raise a catchable exception with a diagnostic instead.
  if (size_t(nw) > MAX_THREADS_PER_DIM)
    throw std::runtime_error("Awiu_Biu_Cuw: nw exceeds MAX_THREADS_PER_DIM");
  size_t nthr = std::max(size_t(1), MAX_THREADS_PER_DIM / size_t(nw)); // u-threads per block
  size_t nbks = (nu + nthr - 1) / nthr;                                // ceil(nu / nthr)
  dim3 grid_dim(nbks, 1, 1);
  dim3 block_dim(nthr, nw, 1);
  kernel_Awiu_Biu_Cuw<<<grid_dim, block_dim>>>(nu, nw, na, static_cast<thrust::complex<float> const>(alpha),
                                               reinterpret_cast<thrust::complex<float> const*>(A),
                                               reinterpret_cast<thrust::complex<float> const*>(B), ldb,
                                               reinterpret_cast<thrust::complex<float>*>(C), ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
411
412
// C[i][k] = sum_j A[i][j][k] * B[k][j]
// C[i][k] += sum_j A[i][j][k] * B[k][j]  (double A, complex B)
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<double> const* A,
                  int lda,
                  int stride,
                  std::complex<double> const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc)
{
  // One 32-thread block reduces over j for each (k,i) pair; expect nk >> ni,nj.
  auto const* A_d = reinterpret_cast<thrust::complex<double> const*>(A);
  auto const* B_d = reinterpret_cast<thrust::complex<double> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<double>*>(C);
  dim3 grid_dim(nk, ni, 1);
  kernel_Aijk_Bkj_Cik<<<grid_dim, 32>>>(ni, nj, nk, A_d, lda, stride, B_d, ldb, C_d, ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
433
// C[i][k] += sum_j A[i][j][k] * B[k][j]  (double A, real-valued B)
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<double> const* A,
                  int lda,
                  int stride,
                  double const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc)
{
  // One 32-thread block reduces over j for each (k,i) pair; expect nk >> ni,nj.
  auto const* A_d = reinterpret_cast<thrust::complex<double> const*>(A);
  auto* C_d       = reinterpret_cast<thrust::complex<double>*>(C);
  dim3 grid_dim(nk, ni, 1);
  kernel_Aijk_Bkj_Cik<<<grid_dim, 32>>>(ni, nj, nk, A_d, lda, stride, B, ldb, C_d, ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
452
// C[i][k] += sum_j A[i][j][k] * B[k][j]  (single A, complex B)
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<float> const* A,
                  int lda,
                  int stride,
                  std::complex<float> const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc)
{
  // One 32-thread block reduces over j for each (k,i) pair; expect nk >> ni,nj.
  auto const* A_d = reinterpret_cast<thrust::complex<float> const*>(A);
  auto const* B_d = reinterpret_cast<thrust::complex<float> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<float>*>(C);
  dim3 grid_dim(nk, ni, 1);
  kernel_Aijk_Bkj_Cik<<<grid_dim, 32>>>(ni, nj, nk, A_d, lda, stride, B_d, ldb, C_d, ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
472
// C[i][k] += sum_j A[i][j][k] * B[k][j]  (single A, real-valued B)
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<float> const* A,
                  int lda,
                  int stride,
                  float const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc)
{
  // One 32-thread block reduces over j for each (k,i) pair; expect nk >> ni,nj.
  auto const* A_d = reinterpret_cast<thrust::complex<float> const*>(A);
  auto* C_d       = reinterpret_cast<thrust::complex<float>*>(C);
  dim3 grid_dim(nk, ni, 1);
  kernel_Aijk_Bkj_Cik<<<grid_dim, 32>>>(ni, nj, nk, A_d, lda, stride, B, ldb, C_d, ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
491
492 // v[w][i][j] = v[i][w][j]
// A[w][i][j] = B[i][w][j] for i in [i0, iN)  (double -> double)
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<double> const* B, std::complex<double>* A)
{
  // One block per (w, i - i0) pair; expect ni > nw.
  auto const* B_d = reinterpret_cast<thrust::complex<double> const*>(B);
  auto* A_d       = reinterpret_cast<thrust::complex<double>*>(A);
  dim3 grid_dim(nw, (iN - i0), 1);
  kernel_viwj_vwij<<<grid_dim, MAX_THREADS_PER_DIM>>>(nw, ni, i0, iN, B_d, A_d);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
503
// A[w][i][j] = B[i][w][j] for i in [i0, iN)  (double -> float, narrowing)
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<double> const* B, std::complex<float>* A)
{
  // One block per (w, i - i0) pair; expect ni > nw.
  auto const* B_d = reinterpret_cast<thrust::complex<double> const*>(B);
  auto* A_d       = reinterpret_cast<thrust::complex<float>*>(A);
  dim3 grid_dim(nw, (iN - i0), 1);
  kernel_viwj_vwij<<<grid_dim, MAX_THREADS_PER_DIM>>>(nw, ni, i0, iN, B_d, A_d);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
514
// A[w][i][j] = B[i][w][j] for i in [i0, iN)  (float -> double, widening)
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<float> const* B, std::complex<double>* A)
{
  // One block per (w, i - i0) pair; expect ni > nw.
  auto const* B_d = reinterpret_cast<thrust::complex<float> const*>(B);
  auto* A_d       = reinterpret_cast<thrust::complex<double>*>(A);
  dim3 grid_dim(nw, (iN - i0), 1);
  kernel_viwj_vwij<<<grid_dim, MAX_THREADS_PER_DIM>>>(nw, ni, i0, iN, B_d, A_d);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
525
// A[w][i][j] = B[i][w][j] for i in [i0, iN)  (float -> float)
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<float> const* B, std::complex<float>* A)
{
  // One block per (w, i - i0) pair; expect ni > nw.
  auto const* B_d = reinterpret_cast<thrust::complex<float> const*>(B);
  auto* A_d       = reinterpret_cast<thrust::complex<float>*>(A);
  dim3 grid_dim(nw, (iN - i0), 1);
  kernel_viwj_vwij<<<grid_dim, MAX_THREADS_PER_DIM>>>(nw, ni, i0, iN, B_d, A_d);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
536
537 // element-wise C[k][i][j] = A[i][j] * B[j][k]
// element-wise C[k][i][j] = A[i][j] * B[j][k]  (real double A)
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               double const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc1,
                               int ldc2)
{
  // Grid: i along x, tiles of j along y, k along z; tuned for nj >> ni,nk.
  size_t const jblocks = (nj + MAX_THREADS_PER_DIM - 1) / MAX_THREADS_PER_DIM;
  auto const* B_d = reinterpret_cast<thrust::complex<double> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<double>*>(C);
  dim3 grid_dim(ni, jblocks, nk);
  kernel_element_wise_Aij_Bjk_Ckij<<<grid_dim, MAX_THREADS_PER_DIM>>>(transA, ni, nj, nk, A, lda, B_d, ldb, C_d, ldc1,
                                                                      ldc2);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
562
// element-wise C[k][i][j] = A[i][j] * B[j][k]  (real float A)
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               float const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc1,
                               int ldc2)
{
  // Grid: i along x, tiles of j along y, k along z; tuned for nj >> ni,nk.
  size_t const jblocks = (nj + MAX_THREADS_PER_DIM - 1) / MAX_THREADS_PER_DIM;
  auto const* B_d = reinterpret_cast<thrust::complex<float> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<float>*>(C);
  dim3 grid_dim(ni, jblocks, nk);
  kernel_element_wise_Aij_Bjk_Ckij<<<grid_dim, MAX_THREADS_PER_DIM>>>(transA, ni, nj, nk, A, lda, B_d, ldb, C_d, ldc1,
                                                                      ldc2);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
586
// element-wise C[k][i][j] = op(A[i][j]) * B[j][k]  (complex double A;
// transA selects identity 'N' or conjugation 'C')
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               std::complex<double> const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc1,
                               int ldc2)
{
  // Grid: i along x, tiles of j along y, k along z; tuned for nj >> ni,nk.
  size_t const jblocks = (nj + MAX_THREADS_PER_DIM - 1) / MAX_THREADS_PER_DIM;
  auto const* A_d = reinterpret_cast<thrust::complex<double> const*>(A);
  auto const* B_d = reinterpret_cast<thrust::complex<double> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<double>*>(C);
  dim3 grid_dim(ni, jblocks, nk);
  kernel_element_wise_Aij_Bjk_Ckij<<<grid_dim, MAX_THREADS_PER_DIM>>>(transA, ni, nj, nk, A_d, lda, B_d, ldb, C_d,
                                                                      ldc1, ldc2);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
614
// element-wise C[k][i][j] = op(A[i][j]) * B[j][k]  (complex float A;
// transA selects identity 'N' or conjugation 'C')
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               std::complex<float> const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc1,
                               int ldc2)
{
  // Grid: i along x, tiles of j along y, k along z; tuned for nj >> ni,nk.
  size_t const jblocks = (nj + MAX_THREADS_PER_DIM - 1) / MAX_THREADS_PER_DIM;
  auto const* A_d = reinterpret_cast<thrust::complex<float> const*>(A);
  auto const* B_d = reinterpret_cast<thrust::complex<float> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<float>*>(C);
  dim3 grid_dim(ni, jblocks, nk);
  kernel_element_wise_Aij_Bjk_Ckij<<<grid_dim, MAX_THREADS_PER_DIM>>>(transA, ni, nj, nk, A_d, lda, B_d, ldb, C_d,
                                                                      ldc1, ldc2);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
641
642
643 // element-wise C[k][j][i] = A[i][j] * B[j][k]
// element-wise C[k][j][i] = A[i][j] * B[j][k]  (real double A)
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               double const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc,
                               int stride)
{
  // 32x32 output tiles with 8 thread rows each; tuned for nj >> ni,nk.
  size_t const tile  = 32;
  size_t const trows = 8;
  size_t const iblocks = (ni + tile - 1) / tile;
  size_t const jblocks = (nj + tile - 1) / tile;
  auto const* B_d = reinterpret_cast<thrust::complex<double> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<double>*>(C);
  dim3 grid_dim(jblocks, iblocks, nk);
  dim3 block_dim(tile, trows, 1);
  kernel_element_wise_Aij_Bjk_Ckji<<<grid_dim, block_dim>>>(ni, nj, nk, A, lda, B_d, ldb, C_d, ldc, stride);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
668
// element-wise C[k][j][i] = A[i][j] * B[j][k]  (complex double A)
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               std::complex<double> const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc,
                               int stride)
{
  // 32x32 output tiles with 8 thread rows each; tuned for nj >> ni,nk.
  // j tiles go along grid x since j is the fast index in Aij, which gives
  // better memory access patterns.
  size_t const tile  = 32;
  size_t const trows = 8;
  size_t const iblocks = (ni + tile - 1) / tile;
  size_t const jblocks = (nj + tile - 1) / tile;
  auto const* A_d = reinterpret_cast<thrust::complex<double> const*>(A);
  auto const* B_d = reinterpret_cast<thrust::complex<double> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<double>*>(C);
  dim3 grid_dim(jblocks, iblocks, nk);
  dim3 block_dim(tile, trows, 1);
  kernel_element_wise_Aij_Bjk_Ckji<<<grid_dim, block_dim>>>(ni, nj, nk, A_d, lda, B_d, ldb, C_d, ldc, stride);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
695
// element-wise C[k][j][i] = A[i][j] * B[j][k]  (real float A)
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               float const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc,
                               int stride)
{
  // 32x32 output tiles with 8 thread rows each; tuned for nj >> ni,nk.
  size_t const tile  = 32;
  size_t const trows = 8;
  size_t const iblocks = (ni + tile - 1) / tile;
  size_t const jblocks = (nj + tile - 1) / tile;
  auto const* B_d = reinterpret_cast<thrust::complex<float> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<float>*>(C);
  dim3 grid_dim(jblocks, iblocks, nk);
  dim3 block_dim(tile, trows, 1);
  kernel_element_wise_Aij_Bjk_Ckji<<<grid_dim, block_dim>>>(ni, nj, nk, A, lda, B_d, ldb, C_d, ldc, stride);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
720
// element-wise C[k][j][i] = A[i][j] * B[j][k]  (complex float A)
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               std::complex<float> const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc,
                               int stride)
{
  // 32x32 output tiles with 8 thread rows each; tuned for nj >> ni,nk.
  size_t const tile  = 32;
  size_t const trows = 8;
  size_t const iblocks = (ni + tile - 1) / tile;
  size_t const jblocks = (nj + tile - 1) / tile;
  auto const* A_d = reinterpret_cast<thrust::complex<float> const*>(A);
  auto const* B_d = reinterpret_cast<thrust::complex<float> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<float>*>(C);
  dim3 grid_dim(jblocks, iblocks, nk);
  dim3 block_dim(tile, trows, 1);
  kernel_element_wise_Aij_Bjk_Ckji<<<grid_dim, block_dim>>>(ni, nj, nk, A_d, lda, B_d, ldb, C_d, ldc, stride);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
746
747
748 } // namespace kernels
749