1 ///////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source
3 // License. See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2020 QMCPACK developers.
6 //
7 // File developed by: Fionn Malone, malone14@llnl.gov, LLNL
8 //
9 // File created by: Fionn Malone, malone14@llnl.gov, LLNL
10 ////////////////////////////////////////////////////////////////////////////////
11
12
13 #include <cassert>
14 #include <complex>
15 #include <hip/hip_runtime.h>
16 #include <thrust/complex.h>
17 #include <hip/hip_runtime.h>
18 #include "AFQMC/Numerics/detail/HIP/hip_kernel_utils.h"
19
20 namespace kernels
21 {
22 template<typename T>
kernel_axpy_batched(int n,thrust::complex<T> * x,thrust::complex<T> ** a,int inca,thrust::complex<T> ** b,int incb,int batchSize)23 __global__ void kernel_axpy_batched(int n,
24 thrust::complex<T>* x,
25 thrust::complex<T>** a,
26 int inca,
27 thrust::complex<T>** b,
28 int incb,
29 int batchSize)
30 {
31 int batch = blockIdx.x;
32 if (batch >= batchSize)
33 return;
34
35 thrust::complex<T>* a_(a[batch]);
36 thrust::complex<T>* b_(b[batch]);
37 thrust::complex<T> x_(x[batch]);
38
39 int i = threadIdx.x;
40 while (i < n)
41 {
42 b_[i * incb] = b_[i * incb] + x_ * a_[i * inca];
43 i += blockDim.x;
44 }
45 }
46
47 template<typename T>
kernel_sumGw_batched(int n,thrust::complex<T> * x,thrust::complex<T> ** a,int inca,thrust::complex<T> ** b,int incb,int b0,int nw,int batchSize)48 __global__ void kernel_sumGw_batched(int n,
49 thrust::complex<T>* x,
50 thrust::complex<T>** a,
51 int inca,
52 thrust::complex<T>** b,
53 int incb,
54 int b0,
55 int nw,
56 int batchSize)
57 {
58 if (blockIdx.x >= batchSize)
59 return;
60
61 int my_iw = (b0 + blockIdx.x) % nw;
62
63 for (int m = 0; m < batchSize; ++m)
64 {
65 if ((b0 + m) % nw != my_iw)
66 continue;
67
68 thrust::complex<T>* a_(a[m]);
69 thrust::complex<T>* b_(b[m]);
70 thrust::complex<T> x_(x[m]);
71
72 int i = threadIdx.x;
73 while (i < n)
74 {
75 b_[i * incb] = b_[i * incb] + x_ * a_[i * inca];
76 i += blockDim.x;
77 }
78 }
79 }
80
axpy_batched_gpu(int n,std::complex<double> * x,const std::complex<double> ** a,int inca,std::complex<double> ** b,int incb,int batchSize)81 void axpy_batched_gpu(int n,
82 std::complex<double>* x,
83 const std::complex<double>** a,
84 int inca,
85 std::complex<double>** b,
86 int incb,
87 int batchSize)
88 {
89 thrust::complex<double>*x_, **a_, **b_;
90 hipMalloc((void**)&a_, batchSize * sizeof(*a_));
91 hipMalloc((void**)&b_, batchSize * sizeof(*b_));
92 hipMalloc((void**)&x_, batchSize * sizeof(*x_));
93 hipMemcpy(a_, a, batchSize * sizeof(*a), hipMemcpyHostToDevice);
94 hipMemcpy(b_, b, batchSize * sizeof(*b), hipMemcpyHostToDevice);
95 hipMemcpy(x_, x, batchSize * sizeof(*x), hipMemcpyHostToDevice);
96 hipLaunchKernelGGL(kernel_axpy_batched, dim3(batchSize), dim3(128), 0, 0, n, x_, a_, inca, b_, incb, batchSize);
97 qmc_hip::hip_kernel_check(hipGetLastError());
98 qmc_hip::hip_kernel_check(hipDeviceSynchronize());
99 hipFree(a_);
100 hipFree(b_);
101 hipFree(x_);
102 }
103
sumGw_batched_gpu(int n,std::complex<double> * x,const std::complex<double> ** a,int inca,std::complex<double> ** b,int incb,int b0,int nw,int batchSize)104 void sumGw_batched_gpu(int n,
105 std::complex<double>* x,
106 const std::complex<double>** a,
107 int inca,
108 std::complex<double>** b,
109 int incb,
110 int b0,
111 int nw,
112 int batchSize)
113 {
114 thrust::complex<double>*x_, **a_, **b_;
115 hipMalloc((void**)&a_, batchSize * sizeof(*a_));
116 hipMalloc((void**)&b_, batchSize * sizeof(*b_));
117 hipMalloc((void**)&x_, batchSize * sizeof(*x_));
118 hipMemcpy(a_, a, batchSize * sizeof(*a), hipMemcpyHostToDevice);
119 hipMemcpy(b_, b, batchSize * sizeof(*b), hipMemcpyHostToDevice);
120 hipMemcpy(x_, x, batchSize * sizeof(*x), hipMemcpyHostToDevice);
121 int nb_(nw > batchSize ? batchSize : nw);
122 hipLaunchKernelGGL(kernel_sumGw_batched, dim3(nb_), dim3(256), 0, 0, n, x_, a_, inca, b_, incb, b0, nw, batchSize);
123 qmc_hip::hip_kernel_check(hipGetLastError());
124 qmc_hip::hip_kernel_check(hipDeviceSynchronize());
125 hipFree(a_);
126 hipFree(b_);
127 hipFree(x_);
128 }
129
130 } // namespace kernels
131