//////////////////////////////////////////////////////////////////////////////////////
// This file is distributed under the University of Illinois/NCSA Open Source License.
// See LICENSE file in top directory for details.
//
// Copyright (c) 2020 QMCPACK developers.
//
// File developed by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
//
// File created by: Ye Luo, yeluo@anl.gov, Argonne National Laboratory
//////////////////////////////////////////////////////////////////////////////////////


#include "cuBLAS_missing_functions.hpp"
#include <stdexcept>
#include <cuComplex.h>
#include <thrust/complex.h>
#include <thrust/system/cuda/detail/core/util.h>

namespace qmcplusplus
{
namespace cuBLAS_MFs
{
using namespace thrust::cuda_cub::core;

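/** Kernel computing y = alpha * A^T * x + beta * y for one batch entry per block
 * along grid.x, where A is a column-major m x n matrix. grid.y tiles the n
 * outputs in chunks of ROWBS. The COLBS threads of a block accumulate partial
 * dot products over the m columns into shared memory, combine them with a tree
 * reduction, and threads 0..row_max-1 write the final y values.
 */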
template<typename T, int ROWBS, int COLBS>
__global__ void gemvT_batched_kernel(const int m, // number of columns in row major A
                                     const int n, // number of rows in row major A
                                     const T* __restrict__ alpha,
                                     const T* const A[],
                                     const int lda,
                                     const T* const x[],
                                     const int incx,
                                     const T* __restrict__ beta,
                                     T* const y[],
                                     const int incy)
{
  static_assert(ROWBS <= COLBS, "Row block size must not be larger than column block size!");

  constexpr int SUM_SIZE = ROWBS * COLBS;
  __shared__ uninitialized_array<T, SUM_SIZE> sum;
  __shared__ uninitialized_array<T, COLBS> x_part;

  const int tid = threadIdx.x;
  for (int i = 0; i < ROWBS; i++)
    sum[i * COLBS + tid] = T(0.0);

  const T* __restrict__ A_iw = A[blockIdx.x];
  const T* __restrict__ x_iw = x[blockIdx.x];

  const int row_begin = blockIdx.y * ROWBS;
  const int row_max   = (n - row_begin) < ROWBS ? (n - row_begin) : ROWBS;

  // each thread accumulates partial dot products over the columns it owns
  const int num_col_blocks = (m + COLBS - 1) / COLBS;
  for (int ib = 0; ib < num_col_blocks; ib++)
  {
    const int col_id = ib * COLBS + tid;
    if (col_id < m)
      x_part[tid] = x_iw[col_id * incx];
    for (int row_id = row_begin; row_id < row_begin + row_max; row_id++)
      if (col_id < m)
        sum[(row_id - row_begin) * COLBS + tid] += x_part[tid] * A_iw[row_id * lda + col_id];
  }

  // tree reduction of the partial sums of each row in shared memory
  for (int iend = COLBS / 2; iend > 0; iend /= 2)
  {
    __syncthreads();
    for (int irow = 0; irow < row_max; irow++)
      if (tid < iend)
        sum[irow * COLBS + tid] += sum[irow * COLBS + tid + iend];
  }

  __syncthreads();
  T* __restrict__ y_iw = y[blockIdx.x];
  if (tid < row_max)
  {
    if (beta[blockIdx.x] == T(0))
      // when beta is 0, skip reading y_iw, which may contain NaN
      y_iw[(row_begin + tid) * incy] = alpha[blockIdx.x] * sum[tid * COLBS];
    else
      y_iw[(row_begin + tid) * incy] =
          alpha[blockIdx.x] * sum[tid * COLBS] + beta[blockIdx.x] * y_iw[(row_begin + tid) * incy];
  }
}

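/** Kernel computing y = alpha * A * x + beta * y for one batch entry per block
 * along grid.x, where A is a column-major m x n matrix. grid.y tiles the m
 * outputs in chunks of ROWBS, and each thread computes one output element as a
 * full dot product over the n columns of A.
 */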
template<typename T, int ROWBS>
__global__ void gemvN_batched_kernel(const int m, // number of columns in row major A
                                     const int n, // number of rows in row major A
                                     const T* __restrict__ alpha,
                                     const T* const A[],
                                     const int lda,
                                     const T* const x[],
                                     const int incx,
                                     const T* __restrict__ beta,
                                     T* const y[],
                                     const int incy)
{
  const T* __restrict__ A_iw = A[blockIdx.x];
  const T* __restrict__ x_iw = x[blockIdx.x];
  T* __restrict__ y_iw       = y[blockIdx.x];

  const int tid       = threadIdx.x;
  const int row_begin = blockIdx.y * ROWBS;

  if (row_begin + tid < m)
  {
    T sum(0);
    for (int col_id = 0; col_id < n; col_id++)
      sum += A_iw[col_id * lda + row_begin + tid] * x_iw[col_id * incx];
    if (beta[blockIdx.x] == T(0))
      // when beta is 0, skip reading y_iw, which may contain NaN
      y_iw[(row_begin + tid) * incy] = alpha[blockIdx.x] * sum;
    else
      y_iw[(row_begin + tid) * incy] = alpha[blockIdx.x] * sum + beta[blockIdx.x] * y_iw[(row_begin + tid) * incy];
  }
}

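/** Dispatch a batched GEMV to the kernel matching the requested operation.
 * 'T' uses the reduction-based transpose kernel with ROWBS x COLBS tiles;
 * 'N' uses one thread per output element. Other trans values throw.
 */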
template<typename T>
cuBLAS_MFs_status gemv_batched_impl(cuBLAS_MFs_handle& handle,
                                    const char trans,
                                    const int m,
                                    const int n,
                                    const T* alpha,
                                    const T* const A[],
                                    const int lda,
                                    const T* const x[],
                                    const int incx,
                                    const T* beta,
                                    T* const y[],
                                    const int incy,
                                    const int batch_count)
{
  if (batch_count == 0 || m == 0 || n == 0)
    return cudaSuccess;

  if (trans == 'T')
  {
    const int ROWBS          = 4;
    const int COLBS          = 64;
    const int num_row_blocks = (n + ROWBS - 1) / ROWBS;
    dim3 dimBlock(COLBS);
    dim3 dimGrid(batch_count, num_row_blocks);
    gemvT_batched_kernel<T, ROWBS, COLBS>
        <<<dimGrid, dimBlock, 0, handle>>>(m, n, alpha, A, lda, x, incx, beta, y, incy);
  }
  else if (trans == 'N')
  {
    const int ROWBS          = 64;
    const int num_row_blocks = (m + ROWBS - 1) / ROWBS;
    dim3 dimBlock(ROWBS);
    dim3 dimGrid(batch_count, num_row_blocks);
    gemvN_batched_kernel<T, ROWBS><<<dimGrid, dimBlock, 0, handle>>>(m, n, alpha, A, lda, x, incx, beta, y, incy);
  }
  else
    throw std::runtime_error("gemv_batched_impl only supports trans = 'T' or 'N'!");
  return cudaPeekAtLastError();
}


cuBLAS_MFs_status gemv_batched(cuBLAS_MFs_handle& handle,
                               const char trans,
                               const int m,
                               const int n,
                               const float* alpha,
                               const float* const A[],
                               const int lda,
                               const float* const x[],
                               const int incx,
                               const float* beta,
                               float* const y[],
                               const int incy,
                               const int batch_count)
{
  return gemv_batched_impl(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy, batch_count);
}

cuBLAS_MFs_status gemv_batched(cuBLAS_MFs_handle& handle,
                               const char trans,
                               const int m,
                               const int n,
                               const double* alpha,
                               const double* const A[],
                               const int lda,
                               const double* const x[],
                               const int incx,
                               const double* beta,
                               double* const y[],
                               const int incy,
                               const int batch_count)
{
  return gemv_batched_impl(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy, batch_count);
}

cuBLAS_MFs_status gemv_batched(cuBLAS_MFs_handle& handle,
                               const char trans,
                               const int m,
                               const int n,
                               const std::complex<float>* alpha,
                               const std::complex<float>* const A[],
                               const int lda,
                               const std::complex<float>* const x[],
                               const int incx,
                               const std::complex<float>* beta,
                               std::complex<float>* const y[],
                               const int incy,
                               const int batch_count)
{
  // map std::complex onto the layout-compatible thrust::complex for device arithmetic
  return gemv_batched_impl(handle, trans, m, n, (const thrust::complex<float>*)alpha,
                           (const thrust::complex<float>**)A, lda, (const thrust::complex<float>**)x, incx,
                           (const thrust::complex<float>*)beta, (thrust::complex<float>**)y, incy, batch_count);
}

cuBLAS_MFs_status gemv_batched(cuBLAS_MFs_handle& handle,
                               const char trans,
                               const int m,
                               const int n,
                               const std::complex<double>* alpha,
                               const std::complex<double>* const A[],
                               const int lda,
                               const std::complex<double>* const x[],
                               const int incx,
                               const std::complex<double>* beta,
                               std::complex<double>* const y[],
                               const int incy,
                               const int batch_count)
{
  return gemv_batched_impl(handle, trans, m, n, (const thrust::complex<double>*)alpha,
                           (const thrust::complex<double>**)A, lda, (const thrust::complex<double>**)x, incx,
                           (const thrust::complex<double>*)beta, (thrust::complex<double>**)y, incy, batch_count);
}
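
// Illustrative host-side call (a sketch with assumed names, not part of this API).
// Every pointer passed here, including alpha/beta and the arrays of per-batch
// pointers, is dereferenced on the device, so they must be device accessible,
// e.g. allocated via cudaMallocManaged:
//
//   const int m = 64, n = 64, batch_count = 8;
//   double **Aptrs, **xptrs, **yptrs; // arrays of per-batch device pointers
//   double *alpha, *beta;             // per-batch scalars
//   cudaMallocManaged(&Aptrs, batch_count * sizeof(double*));
//   // ... allocate xptrs/yptrs/alpha/beta likewise and fill them ...
//   cuBLAS_MFs::gemv_batched(hstream, 'T', m, n, alpha, Aptrs, m, xptrs, 1,
//                            beta, yptrs, 1, batch_count);
//   cudaStreamSynchronize(hstream);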
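/** Kernel computing A = alpha * x * y^T + A (unconjugated, i.e. BLAS GERU) for
 * one batch entry per block along grid.x, where A is a column-major m x n
 * matrix. grid.z tiles the m rows in chunks of COLBS threads and grid.y tiles
 * the n columns in chunks of ROWBS.
 */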
template<typename T, int ROWBS, int COLBS>
__global__ void ger_batched_kernel(const int m, // number of columns in row major A
                                   const int n, // number of rows in row major A
                                   const T* __restrict__ alpha,
                                   const T* const x[],
                                   const int incx,
                                   const T* const y[],
                                   const int incy,
                                   T* const A[],
                                   const int lda)
{
  const int iw               = blockIdx.x;
  const T* __restrict__ x_iw = x[iw];
  const T* __restrict__ y_iw = y[iw];
  T* __restrict__ A_iw       = A[iw];

  const int row_begin = blockIdx.y * ROWBS;
  const int row_end   = (row_begin + ROWBS) < n ? (row_begin + ROWBS) : n;
  const int tid       = threadIdx.x;
  const int col_id    = blockIdx.z * COLBS + tid;

  __shared__ uninitialized_array<T, COLBS> x_part;
  if (col_id < m)
    x_part[tid] = x_iw[col_id * incx];

  for (int row_id = row_begin; row_id < row_end; row_id++)
    if (col_id < m)
      A_iw[row_id * lda + col_id] += alpha[iw] * x_part[tid] * y_iw[row_id * incy];
}

template<typename T>
cuBLAS_MFs_status ger_batched_impl(cuBLAS_MFs_handle& handle,
                                   const int m,
                                   const int n,
                                   const T* alpha,
                                   const T* const x[],
                                   const int incx,
                                   const T* const y[],
                                   const int incy,
                                   T* const A[],
                                   const int lda,
                                   const int batch_count)
{
  if (batch_count == 0 || m == 0 || n == 0)
    return cudaSuccess;

  const int ROWBS          = 16;
  const int COLBS          = 64;
  const int num_row_blocks = (n + ROWBS - 1) / ROWBS;
  const int num_col_blocks = (m + COLBS - 1) / COLBS;
  dim3 dimBlock(COLBS);
  dim3 dimGrid(batch_count, num_row_blocks, num_col_blocks);
  ger_batched_kernel<T, ROWBS, COLBS><<<dimGrid, dimBlock, 0, handle>>>(m, n, alpha, x, incx, y, incy, A, lda);
  return cudaPeekAtLastError();
}

cuBLAS_MFs_status ger_batched(cuBLAS_MFs_handle& handle,
                              const int m,
                              const int n,
                              const float* alpha,
                              const float* const x[],
                              const int incx,
                              const float* const y[],
                              const int incy,
                              float* const A[],
                              const int lda,
                              const int batch_count)
{
  return ger_batched_impl(handle, m, n, alpha, x, incx, y, incy, A, lda, batch_count);
}

cuBLAS_MFs_status ger_batched(cuBLAS_MFs_handle& handle,
                              const int m,
                              const int n,
                              const double* alpha,
                              const double* const x[],
                              const int incx,
                              const double* const y[],
                              const int incy,
                              double* const A[],
                              const int lda,
                              const int batch_count)
{
  return ger_batched_impl(handle, m, n, alpha, x, incx, y, incy, A, lda, batch_count);
}

cuBLAS_MFs_status ger_batched(cuBLAS_MFs_handle& handle,
                              const int m,
                              const int n,
                              const std::complex<float>* alpha,
                              const std::complex<float>* const x[],
                              const int incx,
                              const std::complex<float>* const y[],
                              const int incy,
                              std::complex<float>* const A[],
                              const int lda,
                              const int batch_count)
{
  return ger_batched_impl(handle, m, n, (const thrust::complex<float>*)alpha, (const thrust::complex<float>**)x,
                          incx, (const thrust::complex<float>**)y, incy, (thrust::complex<float>**)A, lda,
                          batch_count);
}

cuBLAS_MFs_status ger_batched(cuBLAS_MFs_handle& handle,
                              const int m,
                              const int n,
                              const std::complex<double>* alpha,
                              const std::complex<double>* const x[],
                              const int incx,
                              const std::complex<double>* const y[],
                              const int incy,
                              std::complex<double>* const A[],
                              const int lda,
                              const int batch_count)
{
  return ger_batched_impl(handle, m, n, (const thrust::complex<double>*)alpha, (const thrust::complex<double>**)x,
                          incx, (const thrust::complex<double>**)y, incy, (thrust::complex<double>**)A, lda,
                          batch_count);
}
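
// Illustrative host-side call (a sketch with assumed names, not part of this API).
// Performs A_iw += alpha_iw * x_iw * y_iw^T for each batch entry iw; as with
// gemv_batched, all pointer arrays and alpha must be device accessible:
//
//   const int m = 64, n = 64, batch_count = 8;
//   double **xptrs, **yptrs, **Aptrs; // arrays of per-batch device pointers
//   double *alpha;                    // per-batch scalars
//   // ... allocate with cudaMallocManaged and fill ...
//   cuBLAS_MFs::ger_batched(hstream, m, n, alpha, xptrs, 1, yptrs, 1,
//                           Aptrs, m, batch_count);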
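/** Kernel copying n elements from in[iw] to out[iw] for one batch entry per
 * block along grid.x; grid.y tiles the n elements in chunks of COLBS threads.
 * A dedicated branch handles the common unit-stride case.
 */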
template<typename T, int COLBS>
__global__ void copy_batched_kernel(const int n, const T* const in[], const int incx, T* const out[], const int incy)
{
  const int iw                = blockIdx.x;
  const T* __restrict__ in_iw = in[iw];
  T* __restrict__ out_iw      = out[iw];

  const int col_id = blockIdx.y * COLBS + threadIdx.x;
  if (incx == 1 && incy == 1)
  {
    if (col_id < n)
      out_iw[col_id] = in_iw[col_id];
  }
  else
  {
    if (col_id < n)
      out_iw[col_id * incy] = in_iw[col_id * incx]; // strided path honors both increments
  }
}

template<typename T>
cuBLAS_MFs_status copy_batched_impl(cudaStream_t& hstream,
                                    const int n,
                                    const T* const in[],
                                    const int incx,
                                    T* const out[],
                                    const int incy,
                                    const int batch_count)
{
  if (batch_count == 0 || n == 0)
    return cudaSuccess;

  const int COLBS          = 128;
  const int num_col_blocks = (n + COLBS - 1) / COLBS;
  dim3 dimBlock(COLBS);
  dim3 dimGrid(batch_count, num_col_blocks);
  copy_batched_kernel<T, COLBS><<<dimGrid, dimBlock, 0, hstream>>>(n, in, incx, out, incy);
  return cudaPeekAtLastError();
}

cuBLAS_MFs_status copy_batched(cudaStream_t& hstream,
                               const int n,
                               const float* const in[],
                               const int incx,
                               float* const out[],
                               const int incy,
                               const int batch_count)
{
  return copy_batched_impl(hstream, n, in, incx, out, incy, batch_count);
}

cuBLAS_MFs_status copy_batched(cudaStream_t& hstream,
                               const int n,
                               const double* const in[],
                               const int incx,
                               double* const out[],
                               const int incy,
                               const int batch_count)
{
  return copy_batched_impl(hstream, n, in, incx, out, incy, batch_count);
}

cuBLAS_MFs_status copy_batched(cudaStream_t& hstream,
                               const int n,
                               const std::complex<float>* const in[],
                               const int incx,
                               std::complex<float>* const out[],
                               const int incy,
                               const int batch_count)
{
  // plain memory copy, so the layout-compatible cuComplex types suffice here
  return copy_batched_impl(hstream, n, (const cuComplex**)in, incx, (cuComplex**)out, incy, batch_count);
}

cuBLAS_MFs_status copy_batched(cudaStream_t& hstream,
                               const int n,
                               const std::complex<double>* const in[],
                               const int incx,
                               std::complex<double>* const out[],
                               const int incy,
                               const int batch_count)
{
  return copy_batched_impl(hstream, n, (const cuDoubleComplex**)in, incx, (cuDoubleComplex**)out, incy, batch_count);
}
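
// Illustrative host-side call (a sketch with assumed names, not part of this API).
// Copies n elements per batch entry on stream `hstream`; the in/out pointer
// arrays must be device accessible:
//
//   const int n = 256, batch_count = 8;
//   float **inptrs, **outptrs; // arrays of per-batch device pointers
//   // ... allocate with cudaMallocManaged and fill ...
//   cuBLAS_MFs::copy_batched(hstream, n, inptrs, 1, outptrs, 1, batch_count);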

} // namespace cuBLAS_MFs
} // namespace qmcplusplus