///////////////////////////////////////////////////////////////////////////////
// This file is distributed under the University of Illinois/NCSA Open Source
// License.  See LICENSE file in top directory for details.
//
// Copyright (c) 2020 QMCPACK developers.
//
// File developed by: Fionn Malone, malone14@llnl.gov, LLNL
//
// File created by: Fionn Malone, malone14@llnl.gov, LLNL
////////////////////////////////////////////////////////////////////////////////


#include <cassert>
#include <complex>
#include <hip/hip_runtime.h>
#include <thrust/complex.h>
#include "AFQMC/Numerics/detail/HIP/hip_kernel_utils.h"

namespace kernels
{
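// Batched axpy kernel: for each batch k, accumulates
//   b[k][i * incb] += x[k] * a[k][i * inca]   for i in [0, n).
// One thread block handles one batch; its threads stride over the n elements.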
template<typename T>
__global__ void kernel_axpy_batched(int n,
                                    thrust::complex<T>* x,
                                    thrust::complex<T>** a,
                                    int inca,
                                    thrust::complex<T>** b,
                                    int incb,
                                    int batchSize)
{
  int batch = blockIdx.x;
  if (batch >= batchSize)
    return;

  thrust::complex<T>* a_(a[batch]);
  thrust::complex<T>* b_(b[batch]);
  thrust::complex<T> x_(x[batch]);

  int i = threadIdx.x;
  while (i < n)
  {
    b_[i * incb] = b_[i * incb] + x_ * a_[i * inca];
    i += blockDim.x;
  }
}

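// Walker-grouped batched axpy: batch m belongs to walker (b0 + m) % nw, and a
// single block (selected by blockIdx.x) processes all batches of one walker
// serially, presumably so that no two blocks ever accumulate into the same
// output array b at once.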
template<typename T>
__global__ void kernel_sumGw_batched(int n,
                                     thrust::complex<T>* x,
                                     thrust::complex<T>** a,
                                     int inca,
                                     thrust::complex<T>** b,
                                     int incb,
                                     int b0,
                                     int nw,
                                     int batchSize)
{
  if (blockIdx.x >= batchSize)
    return;

  int my_iw = (b0 + blockIdx.x) % nw;

  for (int m = 0; m < batchSize; ++m)
  {
    if ((b0 + m) % nw != my_iw)
      continue;

    thrust::complex<T>* a_(a[m]);
    thrust::complex<T>* b_(b[m]);
    thrust::complex<T> x_(x[m]);

    int i = threadIdx.x;
    while (i < n)
    {
      b_[i * incb] = b_[i * incb] + x_ * a_[i * inca];
      i += blockDim.x;
    }
  }
}

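// Host-side driver: stages the pointer arrays and scalars on the device,
// launches kernel_axpy_batched with one block per batch, and synchronizes
// before freeing the temporaries. A hypothetical call site, given host arrays
// A and B of device pointers and host scalars X:
//   kernels::axpy_batched_gpu(n, X, A, 1, B, 1, batchSize);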
void axpy_batched_gpu(int n,
                      std::complex<double>* x,
                      const std::complex<double>** a,
                      int inca,
                      std::complex<double>** b,
                      int incb,
                      int batchSize)
{
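  // Copy the host arrays of device pointers (a, b) and the scalar
  // coefficients x into device memory for the kernel.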
  thrust::complex<double>*x_, **a_, **b_;
  hipMalloc((void**)&a_, batchSize * sizeof(*a_));
  hipMalloc((void**)&b_, batchSize * sizeof(*b_));
  hipMalloc((void**)&x_, batchSize * sizeof(*x_));
  hipMemcpy(a_, a, batchSize * sizeof(*a), hipMemcpyHostToDevice);
  hipMemcpy(b_, b, batchSize * sizeof(*b), hipMemcpyHostToDevice);
  hipMemcpy(x_, x, batchSize * sizeof(*x), hipMemcpyHostToDevice);
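  // One block per batch; 128 threads per block stride over the n elements.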
  hipLaunchKernelGGL(kernel_axpy_batched, dim3(batchSize), dim3(128), 0, 0, n, x_, a_, inca, b_, incb, batchSize);
  qmc_hip::hip_kernel_check(hipGetLastError());
  qmc_hip::hip_kernel_check(hipDeviceSynchronize());
  hipFree(a_);
  hipFree(b_);
  hipFree(x_);
}

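// Host-side driver for the walker-grouped sum: stages the batch arrays on the
// device and launches one block per walker, i.e. at most min(nw, batchSize)
// blocks.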
void sumGw_batched_gpu(int n,
                       std::complex<double>* x,
                       const std::complex<double>** a,
                       int inca,
                       std::complex<double>** b,
                       int incb,
                       int b0,
                       int nw,
                       int batchSize)
{
  thrust::complex<double>*x_, **a_, **b_;
  hipMalloc((void**)&a_, batchSize * sizeof(*a_));
  hipMalloc((void**)&b_, batchSize * sizeof(*b_));
  hipMalloc((void**)&x_, batchSize * sizeof(*x_));
  hipMemcpy(a_, a, batchSize * sizeof(*a), hipMemcpyHostToDevice);
  hipMemcpy(b_, b, batchSize * sizeof(*b), hipMemcpyHostToDevice);
  hipMemcpy(x_, x, batchSize * sizeof(*x), hipMemcpyHostToDevice);
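  // Number of blocks = number of distinct walkers touched, capped at batchSize.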
  int nb_(nw > batchSize ? batchSize : nw);
  hipLaunchKernelGGL(kernel_sumGw_batched, dim3(nb_), dim3(256), 0, 0, n, x_, a_, inca, b_, incb, b0, nw, batchSize);
  qmc_hip::hip_kernel_check(hipGetLastError());
  qmc_hip::hip_kernel_check(hipDeviceSynchronize());
  hipFree(a_);
  hipFree(b_);
  hipFree(x_);
}

} // namespace kernels