1 //////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License.  See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
6 //
7 // File developed by:
8 //    Lawrence Livermore National Laboratory
9 //
10 // File created by:
11 // Miguel A. Morales, moralessilva2@llnl.gov
12 //    Lawrence Livermore National Laboratory
13 ////////////////////////////////////////////////////////////////////////////////
14 
#include <algorithm>
#include <cassert>
#include <complex>
#include <stdexcept>
#include <cuda.h>
#include <cuda_runtime.h>
#include <thrust/complex.h>
#include <thrust/system/cuda/detail/core/util.h>
#include "AFQMC/Numerics/detail/CUDA/Kernels/cuda_settings.h"
#define ENABLE_CUDA 1
#include "AFQMC/Memory/CUDA/cuda_utilities.h"
24 
25 namespace kernels
26 {
// C[u][w] = alpha * sum_a A[u][w][a] * B[u][a]
// One thread per (u,w) pair: u spans blockIdx.x * blockDim.x, w spans threadIdx.y.
template<typename T>
__global__ void kernel_Auwn_Bun_Cuw(int nu,
                                    int nw,
                                    int na,
                                    thrust::complex<T> const alpha,
                                    thrust::complex<T> const* A,
                                    thrust::complex<T> const* B,
                                    thrust::complex<T>* C)
{
  int const u = blockIdx.x * blockDim.x + threadIdx.x;
  int const w = threadIdx.y;
  if (u >= nu || w >= nw)
    return;

  // Contract the trailing (fastest) index a of A[u][w][:] against B[u][:].
  thrust::complex<T> const* Auw = A + (u * nw + w) * na;
  thrust::complex<T> const* Bu  = B + u * na;
  thrust::complex<T> acc(0);
  for (int a = 0; a < na; ++a)
    acc += Auw[a] * Bu[a];
  C[u * nw + w] = alpha * acc;
}
51 
// C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]  (real-valued B overload)
// NOT OPTIMAL: poor memory access pattern
template<typename T>
__global__ void kernel_Awiu_Biu_Cuw(int nu,
                                    int nw,
                                    int ni,
                                    thrust::complex<T> const alpha,
                                    thrust::complex<T> const* A,
                                    T const* B,
                                    int ldb,
                                    thrust::complex<T>* C,
                                    int ldc)
{
  int const u = blockIdx.x * blockDim.x + threadIdx.x;
  int const w = threadIdx.y;
  if (u >= nu || w >= nw)
    return;

  // A is indexed [w][i][u] (u fastest); B rows are separated by ldb.
  thrust::complex<T> acc(0);
  for (int i = 0; i < ni; ++i)
    acc += A[(w * ni + i) * nu + u] * B[i * ldb + u];
  C[u * ldc + w] = alpha * acc;
}
78 
// C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]  (complex B overload)
template<typename T>
__global__ void kernel_Awiu_Biu_Cuw(int nu,
                                    int nw,
                                    int ni,
                                    thrust::complex<T> const alpha,
                                    thrust::complex<T> const* A,
                                    thrust::complex<T> const* B,
                                    int ldb,
                                    thrust::complex<T>* C,
                                    int ldc)
{
  int const u = blockIdx.x * blockDim.x + threadIdx.x;
  int const w = threadIdx.y;
  if (u >= nu || w >= nw)
    return;

  // A is indexed [w][i][u] (u fastest); B rows are separated by ldb.
  thrust::complex<T> acc(0);
  for (int i = 0; i < ni; ++i)
    acc += A[(w * ni + i) * nu + u] * B[i * ldb + u];
  C[u * ldc + w] = alpha * acc;
}
105 
// Cik = sum_j Aijk * Bkj  ( nthreads per block 32 )
// Launch layout (see host wrappers below): grid = (nk, ni), one warp-sized
// block of 32 threads per (i,k) pair.  Each thread accumulates a strided
// partial sum over j into shared memory, then a binary-tree reduction folds
// the 32 partials and thread 0 accumulates the result into C.
template<typename T, typename T1>
__global__ void kernel_Aijk_Bkj_Cik(int ni,
                                    int nj,
                                    int nk,
                                    thrust::complex<T> const* A,
                                    int lda,    // j-stride of A: A[i][j][k] read at A[i*stride + j*lda + k]
                                    int stride, // i-stride of A (distance between consecutive i slabs)
                                    T1 const* B,
                                    int ldb,    // row stride of B: B[k][j] read at B[k*ldb + j]
                                    thrust::complex<T>* C,
                                    int ldc)    // row stride of C
{
  // One partial sum per thread of the 32-thread block.
  __shared__ thrust::cuda_cub::core::uninitialized_array<thrust::complex<T>, 32> cache;
  int k = blockIdx.x;
  int i = blockIdx.y;
  // This guard depends only on blockIdx, so it is uniform across the block:
  // the __syncthreads() calls below are reached by all threads of any
  // participating block.
  if ((i < ni) && (k < nk))
  {
    cache[threadIdx.x] = thrust::complex<T>(0.0);
    auto A_(A + i * stride + k + threadIdx.x * lda); // A[i][threadIdx.x][k]
    auto B_(B + k * ldb + threadIdx.x);              // B[k][threadIdx.x]
    auto Bn_(B + k * ldb + nj);                      // one past B[k][nj-1]
    // Strided loop over j in steps of blockDim.x (= 32).
    while (B_ < Bn_)
    {
      cache[threadIdx.x] += (*A_) * static_cast<thrust::complex<T>>(*B_);
      A_ += blockDim.x * lda;
      B_ += blockDim.x;
    }

    __syncthreads();
    // Tree reduction of the 32 partials into cache[0].
    int j = 16;
    while (j > 0)
    {
      if (threadIdx.x < j)
        cache[threadIdx.x] += cache[threadIdx.x + j];
      __syncthreads();
      j /= 2; //not sure bitwise operations are actually faster
    }
    // NOTE: accumulates (+=) into C rather than overwriting it.
    if (threadIdx.x == 0)
      *(C + i * ldc + k) += cache[0];
  }
}
148 
// A[w][i][j] = B[i][w][j]
// Transposes the leading two indices of a (ni x nw x ni) tensor into a
// (nw x ni x ni) tensor; one block per (w, i) pair, threads stride over j.
// iN is used only for grid sizing on the host side and is unused here.
template<typename T, typename T1>
__global__ void kernel_viwj_vwij(int nw, int ni, int i0, int iN, thrust::complex<T> const* B, thrust::complex<T1>* A)
{
  int const w = blockIdx.x;
  int const i = i0 + blockIdx.y;
  if (w >= nw || i >= ni)
    return;

  auto dst = A + (w * ni + i) * ni;
  auto src = B + (i * nw + w) * ni;
  // Threads of the block cooperatively copy the innermost j dimension.
  for (int j = threadIdx.x; j < ni; j += blockDim.x)
    dst[j] = static_cast<thrust::complex<T1>>(src[j]);
}
167 
// element-wise C[k][i][j] = A[i][j] * B[j][k]  (real A overload)
// transA is unused in this overload: conjugation is a no-op for real-valued
// A, so the parameter exists only to mirror the complex overload's interface.
template<typename T>
__global__ void kernel_element_wise_Aij_Bjk_Ckij(char transA,
                                                 int ni,
                                                 int nj,
                                                 int nk,
                                                 T const* A,
                                                 int lda,
                                                 thrust::complex<T> const* B,
                                                 int ldb,
                                                 thrust::complex<T>* C,
                                                 int ldc1,
                                                 int ldc2)
{
  // Launch layout (see host wrappers): grid = (ni, j-chunks, nk), 1-D blocks
  // covering the j dimension.
  int i = blockIdx.x;
  int j = blockIdx.y * blockDim.x + threadIdx.x;
  int k = blockIdx.z;

  if ((i < ni) && (j < nj) && (k < nk))
    C[(k * ldc1 + i) * ldc2 + j] = A[i * lda + j] * B[j * ldb + k];
}
189 
// element-wise C[k][i][j] = A[i][j] * B[j][k]  (complex A overload)
// transA selects plain ('N') or conjugated ('C') A.
template<typename T>
__global__ void kernel_element_wise_Aij_Bjk_Ckij(char transA,
                                                 int ni,
                                                 int nj,
                                                 int nk,
                                                 thrust::complex<T> const* A,
                                                 int lda,
                                                 thrust::complex<T> const* B,
                                                 int ldb,
                                                 thrust::complex<T>* C,
                                                 int ldc1,
                                                 int ldc2)
{
  // Launch layout (see host wrappers): grid = (ni, j-chunks, nk), 1-D blocks
  // covering the j dimension.
  int i = blockIdx.x;
  int j = blockIdx.y * blockDim.x + threadIdx.x;
  int k = blockIdx.z;

  if ((i < ni) && (j < nj) && (k < nk))
  {
    // NOTE(review): any transA value other than 'N' or 'C' silently leaves
    // C untouched — consider validating transA on the host side.
    if (transA == 'N')
      C[(k * ldc1 + i) * ldc2 + j] = A[i * lda + j] * B[j * ldb + k];
    else if (transA == 'C')
      C[(k * ldc1 + i) * ldc2 + j] = conj(A[i * lda + j]) * B[j * ldb + k];
  }
}
215 
// Ckji = Aij * Bjk
// Tiled kernel: each block stages a 32x32 tile of A and the matching slice of
// B[:,k] in shared memory, then writes the element-wise products out with the
// i/j roles of threadIdx.x/threadIdx.y interchanged so that stores to C
// (i fastest) are contiguous across a warp.  The host wrappers in this file
// launch it with block = (32, 8, 1) and grid = (ceil(nj/32), ceil(ni/32), nk).
template<typename T, typename T2>
__global__ void kernel_element_wise_Aij_Bjk_Ckji(int ni,
                                                 int nj,
                                                 int nk,
                                                 T2 const* A,
                                                 int lda,    // row stride of A
                                                 thrust::complex<T> const* B,
                                                 int ldb,    // row stride of B
                                                 thrust::complex<T>* C,
                                                 int ldc,    // j-stride of C
                                                 int stride) // distance between consecutive k slabs of C
{
  // hard-coded to TILE_DIM=32
  int TILE_DIM = 32;
  __shared__ thrust::cuda_cub::core::uninitialized_array<T2, 32 * 32> Acache; // 32x32 tile of A
  __shared__ thrust::cuda_cub::core::uninitialized_array<thrust::complex<T>, 32> Bcache; // B[j][k] for this tile's j range

  int k = blockIdx.z;
  int j = blockIdx.x * TILE_DIM + threadIdx.x;
  int i = blockIdx.y * TILE_DIM + threadIdx.y;

  // Phase 1: cooperative load.  Each thread keeps its j offset (threadIdx.x)
  // fixed and walks down the tile's rows in steps of blockDim.y, so
  // Acache[n * 32 + threadIdx.x] holds A[tile_row n][tile_col threadIdx.x].
  if ((k < nk) && (j < nj))
  {
    int n(threadIdx.y);
    while ((i < ni) && (n < TILE_DIM))
    {
      Acache[n * 32 + threadIdx.x] = A[i * lda + j];
      n += blockDim.y;
      i += blockDim.y;
    }
    if (threadIdx.y == 0)
      Bcache[threadIdx.x] = B[j * ldb + k];
  }

  // The barrier sits outside the guarded region above, so every thread of the
  // block reaches it regardless of the bounds checks.
  __syncthreads();

  // subtle interchange of threadIdx.x/threadIdx.y
  // Phase 2: re-map so threadIdx.x varies over i, making the stores to C
  // (i fastest) contiguous; the tile is read back transposed from Acache.
  i = blockIdx.y * TILE_DIM + threadIdx.x;
  j = blockIdx.x * TILE_DIM + threadIdx.y;

  if ((k < nk) && (i < ni))
  {
    int n(threadIdx.y);
    while ((j < nj) && (n < TILE_DIM))
    {
      C[k * stride + j * ldc + i] = static_cast<thrust::complex<T>>(Acache[threadIdx.x * 32 + n]) * Bcache[n];
      n += blockDim.y;
      j += blockDim.y;
    }
  }
}
268 
// C[u][w] = alpha * sum_a A[u][w][a] * B[u][a]  (double precision)
// Launches one thread per (u,w) pair; w maps to threadIdx.y and must fit in
// a single block.  Throws std::runtime_error when nw is too large.
void Auwn_Bun_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<double> alpha,
                  std::complex<double> const* A,
                  std::complex<double> const* B,
                  std::complex<double>* C)
{
  // A bare `throw;` with no active exception would call std::terminate;
  // raise a catchable exception with a useful message instead.
  if (size_t(nw) > MAX_THREADS_PER_DIM)
    throw std::runtime_error("Auwn_Bun_Cuw: nw exceeds MAX_THREADS_PER_DIM");
  size_t nthr = std::max(size_t(1), MAX_THREADS_PER_DIM / size_t(nw));
  size_t nbks = (nu + nthr - 1) / nthr; // ceil-div over u
  dim3 grid_dim(nbks, 1, 1);
  dim3 block_dim(nthr, nw, 1);
  kernel_Auwn_Bun_Cuw<<<grid_dim, block_dim>>>(nu, nw, na, static_cast<thrust::complex<double> const>(alpha),
                                               reinterpret_cast<thrust::complex<double> const*>(A),
                                               reinterpret_cast<thrust::complex<double> const*>(B),
                                               reinterpret_cast<thrust::complex<double>*>(C));
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
291 
292 
// C[u][w] = alpha * sum_a A[u][w][a] * B[u][a]  (single precision)
// Launches one thread per (u,w) pair; w maps to threadIdx.y and must fit in
// a single block.  Throws std::runtime_error when nw is too large.
void Auwn_Bun_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<float> alpha,
                  std::complex<float> const* A,
                  std::complex<float> const* B,
                  std::complex<float>* C)
{
  // A bare `throw;` with no active exception would call std::terminate;
  // raise a catchable exception with a useful message instead.
  if (size_t(nw) > MAX_THREADS_PER_DIM)
    throw std::runtime_error("Auwn_Bun_Cuw: nw exceeds MAX_THREADS_PER_DIM");
  size_t nthr = std::max(size_t(1), MAX_THREADS_PER_DIM / size_t(nw));
  size_t nbks = (nu + nthr - 1) / nthr; // ceil-div over u
  dim3 grid_dim(nbks, 1, 1);
  dim3 block_dim(nthr, nw, 1);
  kernel_Auwn_Bun_Cuw<<<grid_dim, block_dim>>>(nu, nw, na, static_cast<thrust::complex<float> const>(alpha),
                                               reinterpret_cast<thrust::complex<float> const*>(A),
                                               reinterpret_cast<thrust::complex<float> const*>(B),
                                               reinterpret_cast<thrust::complex<float>*>(C));
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
314 
// C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]  (double precision, real B)
// Launches one thread per (u,w) pair; w maps to threadIdx.y and must fit in
// a single block.  Throws std::runtime_error when nw is too large.
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<double> alpha,
                  std::complex<double> const* A,
                  double const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc)
{
  // A bare `throw;` with no active exception would call std::terminate;
  // raise a catchable exception with a useful message instead.
  if (size_t(nw) > MAX_THREADS_PER_DIM)
    throw std::runtime_error("Awiu_Biu_Cuw: nw exceeds MAX_THREADS_PER_DIM");
  size_t nthr = std::max(size_t(1), MAX_THREADS_PER_DIM / size_t(nw));
  size_t nbks = (nu + nthr - 1) / nthr; // ceil-div over u
  dim3 grid_dim(nbks, 1, 1);
  dim3 block_dim(nthr, nw, 1);
  kernel_Awiu_Biu_Cuw<<<grid_dim, block_dim>>>(nu, nw, na, static_cast<thrust::complex<double> const>(alpha),
                                               reinterpret_cast<thrust::complex<double> const*>(A), B, ldb,
                                               reinterpret_cast<thrust::complex<double>*>(C), ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
338 
339 
// C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]  (single precision, real B)
// Launches one thread per (u,w) pair; w maps to threadIdx.y and must fit in
// a single block.  Throws std::runtime_error when nw is too large.
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<float> alpha,
                  std::complex<float> const* A,
                  float const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc)
{
  // A bare `throw;` with no active exception would call std::terminate;
  // raise a catchable exception with a useful message instead.
  if (size_t(nw) > MAX_THREADS_PER_DIM)
    throw std::runtime_error("Awiu_Biu_Cuw: nw exceeds MAX_THREADS_PER_DIM");
  size_t nthr = std::max(size_t(1), MAX_THREADS_PER_DIM / size_t(nw));
  size_t nbks = (nu + nthr - 1) / nthr; // ceil-div over u
  dim3 grid_dim(nbks, 1, 1);
  dim3 block_dim(nthr, nw, 1);
  kernel_Awiu_Biu_Cuw<<<grid_dim, block_dim>>>(nu, nw, na, static_cast<thrust::complex<float> const>(alpha),
                                               reinterpret_cast<thrust::complex<float> const*>(A), B, ldb,
                                               reinterpret_cast<thrust::complex<float>*>(C), ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
362 
// C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]  (double precision, complex B)
// Launches one thread per (u,w) pair; w maps to threadIdx.y and must fit in
// a single block.  Throws std::runtime_error when nw is too large.
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<double> alpha,
                  std::complex<double> const* A,
                  std::complex<double> const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc)
{
  // A bare `throw;` with no active exception would call std::terminate;
  // raise a catchable exception with a useful message instead.
  if (size_t(nw) > MAX_THREADS_PER_DIM)
    throw std::runtime_error("Awiu_Biu_Cuw: nw exceeds MAX_THREADS_PER_DIM");
  size_t nthr = std::max(size_t(1), MAX_THREADS_PER_DIM / size_t(nw));
  size_t nbks = (nu + nthr - 1) / nthr; // ceil-div over u
  dim3 grid_dim(nbks, 1, 1);
  dim3 block_dim(nthr, nw, 1);
  kernel_Awiu_Biu_Cuw<<<grid_dim, block_dim>>>(nu, nw, na, static_cast<thrust::complex<double> const>(alpha),
                                               reinterpret_cast<thrust::complex<double> const*>(A),
                                               reinterpret_cast<thrust::complex<double> const*>(B), ldb,
                                               reinterpret_cast<thrust::complex<double>*>(C), ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
386 
387 
// C[u][w] = alpha * sum_i A[w][i][u] * B[i][u]  (single precision, complex B)
// Launches one thread per (u,w) pair; w maps to threadIdx.y and must fit in
// a single block.  Throws std::runtime_error when nw is too large.
void Awiu_Biu_Cuw(int nu,
                  int nw,
                  int na,
                  std::complex<float> alpha,
                  std::complex<float> const* A,
                  std::complex<float> const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc)
{
  // A bare `throw;` with no active exception would call std::terminate;
  // raise a catchable exception with a useful message instead.
  if (size_t(nw) > MAX_THREADS_PER_DIM)
    throw std::runtime_error("Awiu_Biu_Cuw: nw exceeds MAX_THREADS_PER_DIM");
  size_t nthr = std::max(size_t(1), MAX_THREADS_PER_DIM / size_t(nw));
  size_t nbks = (nu + nthr - 1) / nthr; // ceil-div over u
  dim3 grid_dim(nbks, 1, 1);
  dim3 block_dim(nthr, nw, 1);
  kernel_Awiu_Biu_Cuw<<<grid_dim, block_dim>>>(nu, nw, na, static_cast<thrust::complex<float> const>(alpha),
                                               reinterpret_cast<thrust::complex<float> const*>(A),
                                               reinterpret_cast<thrust::complex<float> const*>(B), ldb,
                                               reinterpret_cast<thrust::complex<float>*>(C), ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
411 
412 
// C[i][k] = sum_j A[i][j][k] * B[k][j]  (double precision; note: sum over j)
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<double> const* A,
                  int lda,
                  int stride,
                  std::complex<double> const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc)
{
  // one 32-thread block per (i,k) pair; expect nk >> ni,nj
  using cplx = thrust::complex<double>;
  dim3 grid(nk, ni, 1);
  kernel_Aijk_Bkj_Cik<<<grid, 32>>>(ni, nj, nk, reinterpret_cast<cplx const*>(A), lda, stride,
                                    reinterpret_cast<cplx const*>(B), ldb, reinterpret_cast<cplx*>(C), ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
433 
// C[i][k] = sum_j A[i][j][k] * B[k][j]  (double precision, real B)
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<double> const* A,
                  int lda,
                  int stride,
                  double const* B,
                  int ldb,
                  std::complex<double>* C,
                  int ldc)
{
  // one 32-thread block per (i,k) pair; expect nk >> ni,nj
  using cplx = thrust::complex<double>;
  dim3 grid(nk, ni, 1);
  kernel_Aijk_Bkj_Cik<<<grid, 32>>>(ni, nj, nk, reinterpret_cast<cplx const*>(A), lda, stride, B, ldb,
                                    reinterpret_cast<cplx*>(C), ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
452 
// C[i][k] = sum_j A[i][j][k] * B[k][j]  (single precision)
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<float> const* A,
                  int lda,
                  int stride,
                  std::complex<float> const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc)
{
  // one 32-thread block per (i,k) pair; expect nk >> ni,nj
  using cplx = thrust::complex<float>;
  dim3 grid(nk, ni, 1);
  kernel_Aijk_Bkj_Cik<<<grid, 32>>>(ni, nj, nk, reinterpret_cast<cplx const*>(A), lda, stride,
                                    reinterpret_cast<cplx const*>(B), ldb, reinterpret_cast<cplx*>(C), ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
472 
// C[i][k] = sum_j A[i][j][k] * B[k][j]  (single precision, real B)
void Aijk_Bkj_Cik(int ni,
                  int nj,
                  int nk,
                  std::complex<float> const* A,
                  int lda,
                  int stride,
                  float const* B,
                  int ldb,
                  std::complex<float>* C,
                  int ldc)
{
  // one 32-thread block per (i,k) pair; expect nk >> ni,nj
  using cplx = thrust::complex<float>;
  dim3 grid(nk, ni, 1);
  kernel_Aijk_Bkj_Cik<<<grid, 32>>>(ni, nj, nk, reinterpret_cast<cplx const*>(A), lda, stride, B, ldb,
                                    reinterpret_cast<cplx*>(C), ldc);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
491 
// v[w][i][j] = v[i][w][j]  (double -> double)
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<double> const* B, std::complex<double>* A)
{
  // one block per (w, i) pair for i in [i0, iN); expect ni > nw
  dim3 grid(nw, iN - i0, 1);
  kernel_viwj_vwij<<<grid, MAX_THREADS_PER_DIM>>>(nw, ni, i0, iN,
                                                  reinterpret_cast<thrust::complex<double> const*>(B),
                                                  reinterpret_cast<thrust::complex<double>*>(A));
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
503 
// v[w][i][j] = v[i][w][j]  (double -> float, narrowing copy)
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<double> const* B, std::complex<float>* A)
{
  // one block per (w, i) pair for i in [i0, iN); expect ni > nw
  dim3 grid(nw, iN - i0, 1);
  kernel_viwj_vwij<<<grid, MAX_THREADS_PER_DIM>>>(nw, ni, i0, iN,
                                                  reinterpret_cast<thrust::complex<double> const*>(B),
                                                  reinterpret_cast<thrust::complex<float>*>(A));
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
514 
// v[w][i][j] = v[i][w][j]  (float -> double, widening copy)
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<float> const* B, std::complex<double>* A)
{
  // one block per (w, i) pair for i in [i0, iN); expect ni > nw
  dim3 grid(nw, iN - i0, 1);
  kernel_viwj_vwij<<<grid, MAX_THREADS_PER_DIM>>>(nw, ni, i0, iN,
                                                  reinterpret_cast<thrust::complex<float> const*>(B),
                                                  reinterpret_cast<thrust::complex<double>*>(A));
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
525 
// v[w][i][j] = v[i][w][j]  (float -> float)
void viwj_vwij(int nw, int ni, int i0, int iN, std::complex<float> const* B, std::complex<float>* A)
{
  // one block per (w, i) pair for i in [i0, iN); expect ni > nw
  dim3 grid(nw, iN - i0, 1);
  kernel_viwj_vwij<<<grid, MAX_THREADS_PER_DIM>>>(nw, ni, i0, iN,
                                                  reinterpret_cast<thrust::complex<float> const*>(B),
                                                  reinterpret_cast<thrust::complex<float>*>(A));
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
536 
// element-wise C[k][i][j] = A[i][j] * B[j][k]  (double precision, real A)
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               double const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc1,
                               int ldc2)
{
  // grid layout chosen for nj >> ni,nk: j is chunked over 1-D blocks
  using cplx = thrust::complex<double>;
  size_t const jblocks = (nj + MAX_THREADS_PER_DIM - 1) / MAX_THREADS_PER_DIM;
  dim3 grid(ni, jblocks, nk);
  kernel_element_wise_Aij_Bjk_Ckij<<<grid, MAX_THREADS_PER_DIM>>>(transA, ni, nj, nk, A, lda,
                                                                  reinterpret_cast<cplx const*>(B), ldb,
                                                                  reinterpret_cast<cplx*>(C), ldc1, ldc2);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
562 
// element-wise C[k][i][j] = A[i][j] * B[j][k]  (single precision, real A)
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               float const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc1,
                               int ldc2)
{
  // grid layout chosen for nj >> ni,nk: j is chunked over 1-D blocks
  using cplx = thrust::complex<float>;
  size_t const jblocks = (nj + MAX_THREADS_PER_DIM - 1) / MAX_THREADS_PER_DIM;
  dim3 grid(ni, jblocks, nk);
  kernel_element_wise_Aij_Bjk_Ckij<<<grid, MAX_THREADS_PER_DIM>>>(transA, ni, nj, nk, A, lda,
                                                                  reinterpret_cast<cplx const*>(B), ldb,
                                                                  reinterpret_cast<cplx*>(C), ldc1, ldc2);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
586 
// element-wise C[k][i][j] = op(A[i][j]) * B[j][k], op = identity ('N') or conj ('C')
// (double precision, complex A)
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               std::complex<double> const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc1,
                               int ldc2)
{
  // grid layout chosen for nj >> ni,nk: j is chunked over 1-D blocks
  using cplx = thrust::complex<double>;
  size_t const jblocks = (nj + MAX_THREADS_PER_DIM - 1) / MAX_THREADS_PER_DIM;
  dim3 grid(ni, jblocks, nk);
  kernel_element_wise_Aij_Bjk_Ckij<<<grid, MAX_THREADS_PER_DIM>>>(transA, ni, nj, nk,
                                                                  reinterpret_cast<cplx const*>(A), lda,
                                                                  reinterpret_cast<cplx const*>(B), ldb,
                                                                  reinterpret_cast<cplx*>(C), ldc1, ldc2);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
614 
// element-wise C[k][i][j] = op(A[i][j]) * B[j][k], op = identity ('N') or conj ('C')
// (single precision, complex A)
void element_wise_Aij_Bjk_Ckij(char transA,
                               int ni,
                               int nj,
                               int nk,
                               std::complex<float> const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc1,
                               int ldc2)
{
  // grid layout chosen for nj >> ni,nk: j is chunked over 1-D blocks
  using cplx = thrust::complex<float>;
  size_t const jblocks = (nj + MAX_THREADS_PER_DIM - 1) / MAX_THREADS_PER_DIM;
  dim3 grid(ni, jblocks, nk);
  kernel_element_wise_Aij_Bjk_Ckij<<<grid, MAX_THREADS_PER_DIM>>>(transA, ni, nj, nk,
                                                                  reinterpret_cast<cplx const*>(A), lda,
                                                                  reinterpret_cast<cplx const*>(B), ldb,
                                                                  reinterpret_cast<cplx*>(C), ldc1, ldc2);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
641 
642 
643 // element-wise C[k][j][i] = A[i][j] * B[j][k]
// element-wise C[k][j][i] = A[i][j] * B[j][k], with real double-precision A
// and complex double-precision B and C. Blocks until the kernel finishes.
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               double const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc,
                               int stride)
{
  // Geometry tuned for nj >> ni,nk: 32x8 thread blocks; the j tiles run along
  // grid x since j is the fast index in A[i][j], improving memory access.
  size_t const tile_i = 32;
  size_t const tile_j = 8;
  size_t const jb     = (nj + tile_i - 1) / tile_i;
  size_t const ib     = (ni + tile_i - 1) / tile_i;
  dim3 const grid(jb, ib, nk);
  dim3 const block(tile_i, tile_j, 1);
  auto const* B_d = reinterpret_cast<thrust::complex<double> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<double>*>(C);
  kernel_element_wise_Aij_Bjk_Ckji<<<grid, block>>>(ni, nj, nk, A, lda, B_d, ldb, C_d, ldc, stride);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
668 
// element-wise C[k][j][i] = A[i][j] * B[j][k], all operands complex
// double-precision. Blocks until the kernel finishes.
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               std::complex<double> const* A,
                               int lda,
                               std::complex<double> const* B,
                               int ldb,
                               std::complex<double>* C,
                               int ldc,
                               int stride)
{
  // Geometry tuned for nj >> ni,nk: 32x8 thread blocks; j tiles go along
  // grid x since j is the fast index in A[i][j], for better access patterns.
  size_t const tile_i = 32;
  size_t const tile_j = 8;
  size_t const jb     = (nj + tile_i - 1) / tile_i;
  size_t const ib     = (ni + tile_i - 1) / tile_i;
  dim3 const grid(jb, ib, nk);
  dim3 const block(tile_i, tile_j, 1);
  auto const* A_d = reinterpret_cast<thrust::complex<double> const*>(A);
  auto const* B_d = reinterpret_cast<thrust::complex<double> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<double>*>(C);
  kernel_element_wise_Aij_Bjk_Ckji<<<grid, block>>>(ni, nj, nk, A_d, lda, B_d, ldb, C_d, ldc, stride);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
695 
// element-wise C[k][j][i] = A[i][j] * B[j][k], with real single-precision A
// and complex single-precision B and C. Blocks until the kernel finishes.
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               float const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc,
                               int stride)
{
  // Geometry tuned for nj >> ni,nk: 32x8 thread blocks; j tiles go along
  // grid x since j is the fast index in A[i][j], for better access patterns.
  size_t const tile_i = 32;
  size_t const tile_j = 8;
  size_t const jb     = (nj + tile_i - 1) / tile_i;
  size_t const ib     = (ni + tile_i - 1) / tile_i;
  dim3 const grid(jb, ib, nk);
  dim3 const block(tile_i, tile_j, 1);
  auto const* B_d = reinterpret_cast<thrust::complex<float> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<float>*>(C);
  kernel_element_wise_Aij_Bjk_Ckji<<<grid, block>>>(ni, nj, nk, A, lda, B_d, ldb, C_d, ldc, stride);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
720 
// element-wise C[k][j][i] = A[i][j] * B[j][k], all operands complex
// single-precision. Blocks until the kernel finishes.
void element_wise_Aij_Bjk_Ckji(int ni,
                               int nj,
                               int nk,
                               std::complex<float> const* A,
                               int lda,
                               std::complex<float> const* B,
                               int ldb,
                               std::complex<float>* C,
                               int ldc,
                               int stride)
{
  // Geometry tuned for nj >> ni,nk: 32x8 thread blocks; j tiles go along
  // grid x since j is the fast index in A[i][j], for better access patterns.
  size_t const tile_i = 32;
  size_t const tile_j = 8;
  size_t const jb     = (nj + tile_i - 1) / tile_i;
  size_t const ib     = (ni + tile_i - 1) / tile_i;
  dim3 const grid(jb, ib, nk);
  dim3 const block(tile_i, tile_j, 1);
  auto const* A_d = reinterpret_cast<thrust::complex<float> const*>(A);
  auto const* B_d = reinterpret_cast<thrust::complex<float> const*>(B);
  auto* C_d       = reinterpret_cast<thrust::complex<float>*>(C);
  kernel_element_wise_Aij_Bjk_Ckji<<<grid, block>>>(ni, nj, nk, A_d, lda, B_d, ldb, C_d, ldc, stride);
  qmc_cuda::cuda_check(cudaGetLastError());
  qmc_cuda::cuda_check(cudaDeviceSynchronize());
}
746 
747 
748 } // namespace kernels
749