//////////////////////////////////////////////////////////////////////////////////////
// This file is distributed under the University of Illinois/NCSA Open Source License.
// See LICENSE file in top directory for details.
//
// Copyright (c) 2020 QMCPACK developers.
//
// File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
//
// File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
//////////////////////////////////////////////////////////////////////////////////////


#include "cuBLAS_missing_functions.hpp"
#include <stdexcept>
#include <cuComplex.h>
#include <thrust/complex.h>
#include <thrust/system/cuda/detail/core/util.h>

namespace qmcplusplus
{
namespace cuBLAS_MFs
{
using namespace thrust::cuda_cub::core;

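// Batched GEMV kernel for the transposed case (trans = 'T').
// Each blockIdx.x handles one batch entry and each blockIdx.y covers a tile of ROWBS rows
// of the row-major matrix A. Partial dot products are accumulated in shared memory and then
// reduced across the COLBS threads of the block.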
template<typename T, int ROWBS, int COLBS>
__global__ void gemvT_batched_kernel(const int m, // number of columns in row major A
                                     const int n, // number of rows in row major A
                                     const T* __restrict__ alpha,
                                     const T* const A[],
                                     const int lda,
                                     const T* const x[],
                                     const int incx,
                                     const T* __restrict__ beta,
                                     T* const y[],
                                     const int incy)
{
  static_assert(ROWBS <= COLBS, "Row block size must not be larger than column block size!");

  constexpr int SUM_SIZE = ROWBS * COLBS;
  __shared__ uninitialized_array<T, SUM_SIZE> sum;
  __shared__ uninitialized_array<T, COLBS> x_part;

  const int tid = threadIdx.x;
  for (int i = 0; i < ROWBS; i++)
    sum[i * COLBS + tid] = T(0.0);

  const T* __restrict__ A_iw = A[blockIdx.x];
  const T* __restrict__ x_iw = x[blockIdx.x];

  const int row_begin = blockIdx.y * ROWBS;
  const int row_max   = (n - row_begin) < ROWBS ? (n - row_begin) : ROWBS;

  const int num_col_blocks = (m + COLBS - 1) / COLBS;
  for (int ib = 0; ib < num_col_blocks; ib++)
  {
    const int col_id = ib * COLBS + tid;
    if (col_id < m)
      x_part[tid] = x_iw[col_id * incx];
    for (int row_id = row_begin; row_id < row_begin + row_max; row_id++)
      if (col_id < m)
        sum[(row_id - row_begin) * COLBS + tid] += x_part[tid] * A_iw[row_id * lda + col_id];
  }

  for (int iend = COLBS / 2; iend > 0; iend /= 2)
  {
    __syncthreads();
    for (int irow = 0; irow < row_max; irow++)
      if (tid < iend)
        sum[irow * COLBS + tid] += sum[irow * COLBS + tid + iend];
  }

  __syncthreads();
  T* __restrict__ y_iw = y[blockIdx.x];
  if (tid < row_max)
  {
    if (beta[blockIdx.x] == T(0))
      y_iw[(row_begin + tid) * incy] =
          alpha[blockIdx.x] * sum[tid * COLBS]; // avoid reading y_iw (may be NaN) when beta is zero
    else
      y_iw[(row_begin + tid) * incy] =
          alpha[blockIdx.x] * sum[tid * COLBS] + beta[blockIdx.x] * y_iw[(row_begin + tid) * incy];
  }
}

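// Batched GEMV kernel for the non-transposed case (any trans other than 'T').
// Each blockIdx.x handles one batch entry; each thread accumulates one element of y over
// the n rows of the row-major matrix A, so no shared-memory reduction is needed.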
template<typename T, int ROWBS>
__global__ void gemvN_batched_kernel(const int m, // number of columns in row major A
                                     const int n, // number of rows in row major A
                                     const T* __restrict__ alpha,
                                     const T* const A[],
                                     const int lda,
                                     const T* const x[],
                                     const int incx,
                                     const T* __restrict__ beta,
                                     T* const y[],
                                     const int incy)
{
  const T* __restrict__ A_iw = A[blockIdx.x];
  const T* __restrict__ x_iw = x[blockIdx.x];
  T* __restrict__ y_iw       = y[blockIdx.x];

  const int tid       = threadIdx.x;
  const int row_begin = blockIdx.y * ROWBS;

  if (row_begin + tid < m)
  {
    T sum(0);
    for (int col_id = 0; col_id < n; col_id++)
      sum += A_iw[col_id * lda + row_begin + tid] * x_iw[col_id * incx];
    if (beta[blockIdx.x] == T(0))
      y_iw[(row_begin + tid) * incy] = alpha[blockIdx.x] * sum; // avoid reading y_iw (may be NaN) when beta is zero
    else
      y_iw[(row_begin + tid) * incy] = alpha[blockIdx.x] * sum + beta[blockIdx.x] * y_iw[(row_begin + tid) * incy];
  }
}

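// Common implementation of the batched GEMV. It picks the kernel from trans ('T' uses the
// transposed kernel, anything else takes the non-transposed path) and launches one block
// per batch entry in grid.x and one block per tile of output elements in grid.y on the
// stream associated with handle.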
template<typename T>
cuBLAS_MFs_status gemv_batched_impl(cuBLAS_MFs_handle& handle,
                                    const char trans,
                                    const int m,
                                    const int n,
                                    const T* alpha,
                                    const T* const A[],
                                    const int lda,
                                    const T* const x[],
                                    const int incx,
                                    const T* beta,
                                    T* const y[],
                                    const int incy,
                                    const int batch_count)
{
  if (batch_count == 0 || m == 0 || n == 0)
    return cudaSuccess;

  if (trans == 'T')
  {
    const int ROWBS          = 4;
    const int COLBS          = 64;
    const int num_row_blocks = (n + ROWBS - 1) / ROWBS;
    dim3 dimBlock(COLBS);
    dim3 dimGrid(batch_count, num_row_blocks);
    gemvT_batched_kernel<T, ROWBS, COLBS>
        <<<dimGrid, dimBlock, 0, handle>>>(m, n, alpha, A, lda, x, incx, beta, y, incy);
  }
  else
  {
    const int ROWBS          = 64;
    const int num_row_blocks = (m + ROWBS - 1) / ROWBS;
    dim3 dimBlock(ROWBS);
    dim3 dimGrid(batch_count, num_row_blocks);
    gemvN_batched_kernel<T, ROWBS><<<dimGrid, dimBlock, 0, handle>>>(m, n, alpha, A, lda, x, incx, beta, y, incy);
  }
  return cudaPeekAtLastError();
}

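// Public gemv_batched overloads below forward to gemv_batched_impl; the complex overloads
// reinterpret std::complex pointers as the layout-compatible thrust::complex. A minimal
// usage sketch, assuming alpha, beta, and the per-batch pointer arrays A, x, y already
// reside in device-accessible memory and hstream is a valid handle:
//   cuBLAS_MFs::gemv_batched(hstream, 'T', m, n, alpha, A, lda, x, 1, beta, y, 1, batch_count);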
cuBLAS_MFs_status gemv_batched(cuBLAS_MFs_handle& handle,
                               const char trans,
                               const int m,
                               const int n,
                               const float* alpha,
                               const float* const A[],
                               const int lda,
                               const float* const x[],
                               const int incx,
                               const float* beta,
                               float* const y[],
                               const int incy,
                               const int batch_count)
{
  return gemv_batched_impl(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy, batch_count);
}

cuBLAS_MFs_status gemv_batched(cuBLAS_MFs_handle& handle,
                               const char trans,
                               const int m,
                               const int n,
                               const double* alpha,
                               const double* const A[],
                               const int lda,
                               const double* const x[],
                               const int incx,
                               const double* beta,
                               double* const y[],
                               const int incy,
                               const int batch_count)
{
  return gemv_batched_impl(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy, batch_count);
}

cuBLAS_MFs_status gemv_batched(cuBLAS_MFs_handle& handle,
                               const char trans,
                               const int m,
                               const int n,
                               const std::complex<float>* alpha,
                               const std::complex<float>* const A[],
                               const int lda,
                               const std::complex<float>* const x[],
                               const int incx,
                               const std::complex<float>* beta,
                               std::complex<float>* const y[],
                               const int incy,
                               const int batch_count)
{
  return gemv_batched_impl(handle, trans, m, n, (const thrust::complex<float>*)alpha,
                           (const thrust::complex<float>**)A, lda, (const thrust::complex<float>**)x, incx,
                           (const thrust::complex<float>*)beta, (thrust::complex<float>**)y, incy, batch_count);
}

cuBLAS_MFs_status gemv_batched(cuBLAS_MFs_handle& handle,
                               const char trans,
                               const int m,
                               const int n,
                               const std::complex<double>* alpha,
                               const std::complex<double>* const A[],
                               const int lda,
                               const std::complex<double>* const x[],
                               const int incx,
                               const std::complex<double>* beta,
                               std::complex<double>* const y[],
                               const int incy,
                               const int batch_count)
{
  return gemv_batched_impl(handle, trans, m, n, (const thrust::complex<double>*)alpha,
                           (const thrust::complex<double>**)A, lda, (const thrust::complex<double>**)x, incx,
                           (const thrust::complex<double>*)beta, (thrust::complex<double>**)y, incy, batch_count);
}


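// Batched GER (rank-1 update) kernel, computing A := A + alpha * x * y^T for each batch
// entry. blockIdx.x selects the batch entry while blockIdx.y and blockIdx.z tile the rows
// and columns of the row-major matrix A; the x slice of each column tile is staged through
// shared memory.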
template<typename T, int ROWBS, int COLBS>
__global__ void ger_batched_kernel(const int m, // number of columns in row major A
                                   const int n, // number of rows in row major A
                                   const T* __restrict__ alpha,
                                   const T* const x[],
                                   const int incx,
                                   const T* const y[],
                                   const int incy,
                                   T* const A[],
                                   const int lda)
{
  const int iw               = blockIdx.x;
  const T* __restrict__ x_iw = x[iw];
  const T* __restrict__ y_iw = y[iw];
  T* __restrict__ A_iw       = A[iw];

  const int row_begin = blockIdx.y * ROWBS;
  const int row_end   = (row_begin + ROWBS) < n ? (row_begin + ROWBS) : n;
  const int tid       = threadIdx.x;
  const int col_id    = blockIdx.z * COLBS + tid;

  __shared__ uninitialized_array<T, COLBS> x_part;
  if (col_id < m)
    x_part[tid] = x_iw[col_id * incx];

  for (int row_id = row_begin; row_id < row_end; row_id++)
    if (col_id < m)
      A_iw[row_id * lda + col_id] += alpha[iw] * x_part[tid] * y_iw[row_id * incy];
}

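// Common implementation of the batched GER: one block per (batch entry, row tile, column
// tile) triple, launched on the stream associated with handle.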
template<typename T>
cuBLAS_MFs_status ger_batched_impl(cuBLAS_MFs_handle& handle,
                                   const int m,
                                   const int n,
                                   const T* alpha,
                                   const T* const x[],
                                   const int incx,
                                   const T* const y[],
                                   const int incy,
                                   T* const A[],
                                   const int lda,
                                   const int batch_count)
{
  if (batch_count == 0 || m == 0 || n == 0)
    return cudaSuccess;

  const int ROWBS          = 16;
  const int COLBS          = 64;
  const int num_row_blocks = (n + ROWBS - 1) / ROWBS;
  const int num_col_blocks = (m + COLBS - 1) / COLBS;
  dim3 dimBlock(COLBS);
  dim3 dimGrid(batch_count, num_row_blocks, num_col_blocks);
  ger_batched_kernel<T, ROWBS, COLBS><<<dimGrid, dimBlock, 0, handle>>>(m, n, alpha, x, incx, y, incy, A, lda);
  return cudaPeekAtLastError();
}

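// Public ger_batched overloads, forwarding to ger_batched_impl. As with gemv_batched, the
// complex overloads reinterpret std::complex pointers as thrust::complex.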
cuBLAS_MFs_status ger_batched(cuBLAS_MFs_handle& handle,
                              const int m,
                              const int n,
                              const float* alpha,
                              const float* const x[],
                              const int incx,
                              const float* const y[],
                              const int incy,
                              float* const A[],
                              const int lda,
                              const int batch_count)
{
  return ger_batched_impl(handle, m, n, alpha, x, incx, y, incy, A, lda, batch_count);
}

cuBLAS_MFs_status ger_batched(cuBLAS_MFs_handle& handle,
                              const int m,
                              const int n,
                              const double* alpha,
                              const double* const x[],
                              const int incx,
                              const double* const y[],
                              const int incy,
                              double* const A[],
                              const int lda,
                              const int batch_count)
{
  return ger_batched_impl(handle, m, n, alpha, x, incx, y, incy, A, lda, batch_count);
}

cuBLAS_MFs_status ger_batched(cuBLAS_MFs_handle& handle,
                              const int m,
                              const int n,
                              const std::complex<float>* alpha,
                              const std::complex<float>* const x[],
                              const int incx,
                              const std::complex<float>* const y[],
                              const int incy,
                              std::complex<float>* const A[],
                              const int lda,
                              const int batch_count)
{
  return ger_batched_impl(handle, m, n, (const thrust::complex<float>*)alpha, (const thrust::complex<float>**)x,
                          incx, (const thrust::complex<float>**)y, incy, (thrust::complex<float>**)A, lda,
                          batch_count);
}

cuBLAS_MFs_status ger_batched(cuBLAS_MFs_handle& handle,
                              const int m,
                              const int n,
                              const std::complex<double>* alpha,
                              const std::complex<double>* const x[],
                              const int incx,
                              const std::complex<double>* const y[],
                              const int incy,
                              std::complex<double>* const A[],
                              const int lda,
                              const int batch_count)
{
  return ger_batched_impl(handle, m, n, (const thrust::complex<double>*)alpha, (const thrust::complex<double>**)x,
                          incx, (const thrust::complex<double>**)y, incy, (thrust::complex<double>**)A, lda,
                          batch_count);
}

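// Batched copy kernel: out = in element by element with strides incx/incy.
// blockIdx.x selects the batch entry and blockIdx.y tiles the n elements.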
template<typename T, int COLBS>
__global__ void copy_batched_kernel(const int n, const T* const in[], const int incx, T* const out[], const int incy)
{
  const int iw                = blockIdx.x;
  const T* __restrict__ in_iw = in[iw];
  T* __restrict__ out_iw      = out[iw];

  const int col_id = blockIdx.y * COLBS + threadIdx.x;
  if (incx == 1 && incy == 1)
  {
    if (col_id < n)
      out_iw[col_id] = in_iw[col_id];
  }
  else
  {
    if (col_id < n)
      out_iw[col_id * incy] = in_iw[col_id * incx];
  }
}

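// Common implementation of the batched copy, launched on the given CUDA stream with one
// block per (batch entry, element tile) pair.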
template<typename T>
cuBLAS_MFs_status copy_batched_impl(cudaStream_t& hstream,
                                    const int n,
                                    const T* const in[],
                                    const int incx,
                                    T* const out[],
                                    const int incy,
                                    const int batch_count)
{
  if (batch_count == 0 || n == 0)
    return cudaSuccess;

  const int COLBS          = 128;
  const int num_col_blocks = (n + COLBS - 1) / COLBS;
  dim3 dimBlock(COLBS);
  dim3 dimGrid(batch_count, num_col_blocks);
  copy_batched_kernel<T, COLBS><<<dimGrid, dimBlock, 0, hstream>>>(n, in, incx, out, incy);
  return cudaPeekAtLastError();
}

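// Public copy_batched overloads, forwarding to copy_batched_impl. The complex overloads
// reinterpret std::complex pointers as cuComplex/cuDoubleComplex; a plain copy involves no
// complex arithmetic, so the bitwise-compatible cuComplex types suffice here.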
cuBLAS_MFs_status copy_batched(cudaStream_t& hstream,
                               const int n,
                               const float* const in[],
                               const int incx,
                               float* const out[],
                               const int incy,
                               const int batch_count)
{
  return copy_batched_impl(hstream, n, in, incx, out, incy, batch_count);
}

cuBLAS_MFs_status copy_batched(cudaStream_t& hstream,
                               const int n,
                               const double* const in[],
                               const int incx,
                               double* const out[],
                               const int incy,
                               const int batch_count)
{
  return copy_batched_impl(hstream, n, in, incx, out, incy, batch_count);
}

cuBLAS_MFs_status copy_batched(cudaStream_t& hstream,
                               const int n,
                               const std::complex<float>* const in[],
                               const int incx,
                               std::complex<float>* const out[],
                               const int incy,
                               const int batch_count)
{
  return copy_batched_impl(hstream, n, (const cuComplex**)in, incx, (cuComplex**)out, incy, batch_count);
}

cuBLAS_MFs_status copy_batched(cudaStream_t& hstream,
                               const int n,
                               const std::complex<double>* const in[],
                               const int incx,
                               std::complex<double>* const out[],
                               const int incy,
                               const int batch_count)
{
  return copy_batched_impl(hstream, n, (const cuDoubleComplex**)in, incx, (cuDoubleComplex**)out, incy, batch_count);
}

} // namespace cuBLAS_MFs
} // namespace qmcplusplus