1 ///////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License.  See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2020 QMCPACK developers.
6 //
7 // File developed by: Fionn Malone, malone14@llnl.gov, LLNL
8 //
9 // File created by: Fionn Malone, malone14@llnl.gov, LLNL
10 ////////////////////////////////////////////////////////////////////////////////
11 
12 #include <cassert>
13 #include <complex>
14 #include <hip/hip_runtime.h>
15 #include <thrust/complex.h>
16 //#include "hip_settings.h"
17 //#include "hip_utilities.h"
18 #include "AFQMC/Numerics/detail/HIP/Kernels/hip_settings.h"
19 #include "AFQMC/Numerics/detail/HIP/hip_kernel_utils.h"
20 
21 namespace kernels
22 {
23 // simple
24 // A[k][i] = B[k][i][i]
25 template<typename T>
kernel_get_diagonal_strided(int nk,int ni,thrust::complex<T> const * B,int ldb,int stride,thrust::complex<T> * A,int lda)26 __global__ void kernel_get_diagonal_strided(int nk,
27                                             int ni,
28                                             thrust::complex<T> const* B,
29                                             int ldb,
30                                             int stride,
31                                             thrust::complex<T>* A,
32                                             int lda)
33 {
34   int k = blockIdx.y;
35   int i = blockIdx.x * blockDim.x + threadIdx.x;
36   if ((i < ni) && (k < nk))
37     A[k * lda + i] = B[k * stride + i * ldb + i];
38 }
39 
40 // A[k][i] = B[k][i][i]
get_diagonal_strided(int nk,int ni,std::complex<double> const * B,int ldb,int stride,std::complex<double> * A,int lda)41 void get_diagonal_strided(int nk,
42                           int ni,
43                           std::complex<double> const* B,
44                           int ldb,
45                           int stride,
46                           std::complex<double>* A,
47                           int lda)
48 {
49   size_t nthr = 32;
50   size_t nbks = (ni + nthr - 1) / nthr;
51   dim3 grid_dim(nbks, nk, 1);
52   hipLaunchKernelGGL(kernel_get_diagonal_strided, dim3(grid_dim), dim3(nthr), 0, 0, nk, ni,
53                      reinterpret_cast<thrust::complex<double> const*>(B), ldb, stride,
54                      reinterpret_cast<thrust::complex<double>*>(A), lda);
55   qmc_hip::hip_kernel_check(hipGetLastError());
56   qmc_hip::hip_kernel_check(hipDeviceSynchronize());
57 }
58 
get_diagonal_strided(int nk,int ni,std::complex<float> const * B,int ldb,int stride,std::complex<float> * A,int lda)59 void get_diagonal_strided(int nk,
60                           int ni,
61                           std::complex<float> const* B,
62                           int ldb,
63                           int stride,
64                           std::complex<float>* A,
65                           int lda)
66 {
67   size_t nthr = 32;
68   size_t nbks = (ni + nthr - 1) / nthr;
69   dim3 grid_dim(nbks, nk, 1);
70   hipLaunchKernelGGL(kernel_get_diagonal_strided, dim3(grid_dim), dim3(nthr), 0, 0, nk, ni,
71                      reinterpret_cast<thrust::complex<float> const*>(B), ldb, stride,
72                      reinterpret_cast<thrust::complex<float>*>(A), lda);
73   qmc_hip::hip_kernel_check(hipGetLastError());
74   qmc_hip::hip_kernel_check(hipDeviceSynchronize());
75 }
76 
77 } // namespace kernels
78