#include "chainerx/cuda/cuda_device.h"

#include <cstdint>
#include <mutex>
#include <type_traits>

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cusolverDn.h>
#include <cuda_fp16.hpp>

#include "chainerx/array.h"
#include "chainerx/axes.h"
#include "chainerx/backend.h"
#include "chainerx/backend_util.h"
#include "chainerx/cuda/cublas.h"
#include "chainerx/cuda/cuda_runtime.h"
#include "chainerx/cuda/cuda_set_device_scope.h"
#include "chainerx/cuda/cusolver.h"
#include "chainerx/cuda/data_type.cuh"
#include "chainerx/cuda/float16.cuh"
#include "chainerx/cuda/kernel_regist.h"
#include "chainerx/device.h"
#include "chainerx/dtype.h"
#include "chainerx/error.h"
#include "chainerx/float16.h"
#include "chainerx/kernels/creation.h"
#include "chainerx/kernels/linalg.h"
#include "chainerx/kernels/misc.h"
#include "chainerx/macro.h"
#include "chainerx/native/native_device.h"
#include "chainerx/routines/arithmetic.h"
#include "chainerx/routines/creation.h"
#include "chainerx/routines/indexing.h"
#include "chainerx/routines/linalg.h"

namespace chainerx {
namespace cuda {
namespace {

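// The primary templates below are dtype-dispatch fallbacks: they are only reached
// when no float/double specialization applies, and they raise DtypeError at runtime.
// The specializations further down forward to the corresponding cuSOLVER routines.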
template <typename T>
cusolverStatus_t GetrfBuffersize(cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, T* /*a*/, int /*lda*/, int* /*lwork*/) {
    throw DtypeError{"Only Arrays of float or double type are supported by getrf (LU)"};
}

template <typename T>
cusolverStatus_t Getrf(
        cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, T* /*a*/, int /*lda*/, T* /*workspace*/, int* /*devipiv*/, int* /*devinfo*/) {
    throw DtypeError{"Only Arrays of float or double type are supported by getrf (LU)"};
}

template <typename T>
cusolverStatus_t Getrs(
        cusolverDnHandle_t /*handle*/,
        cublasOperation_t /*trans*/,
        int /*n*/,
        int /*nrhs*/,
        T* /*a*/,
        int /*lda*/,
        int* /*devipiv*/,
        T* /*b*/,
        int /*ldb*/,
        int* /*devinfo*/) {
    throw DtypeError{"Only Arrays of float or double type are supported by getrs (Solve)"};
}

template <typename T>
cusolverStatus_t GesvdBuffersize(cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, int* /*lwork*/) {
    throw DtypeError{"Only Arrays of float or double type are supported by gesvd (SVD)"};
}

template <typename T>
cusolverStatus_t Gesvd(
        cusolverDnHandle_t /*handle*/,
        signed char /*jobu*/,
        signed char /*jobvt*/,
        int /*m*/,
        int /*n*/,
        T* /*a*/,
        int /*lda*/,
        T* /*s*/,
        T* /*u*/,
        int /*ldu*/,
        T* /*vt*/,
        int /*ldvt*/,
        T* /*work*/,
        int /*lwork*/,
        T* /*rwork*/,
        int* /*devinfo*/) {
    throw DtypeError{"Only Arrays of float or double type are supported by gesvd (SVD)"};
}

template <typename T>
cusolverStatus_t GeqrfBufferSize(cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, T* /*a*/, int /*lda*/, int* /*lwork*/) {
    throw DtypeError{"Only Arrays of float or double type are supported by geqrf (QR)"};
}

template <typename T>
cusolverStatus_t Geqrf(
        cusolverDnHandle_t /*handle*/,
        int /*m*/,
        int /*n*/,
        T* /*a*/,
        int /*lda*/,
        T* /*tau*/,
        T* /*workspace*/,
        int /*lwork*/,
        int* /*devinfo*/) {
    throw DtypeError{"Only Arrays of float or double type are supported by geqrf (QR)"};
}

template <typename T>
cusolverStatus_t OrgqrBufferSize(
        cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, int /*k*/, T* /*a*/, int /*lda*/, T* /*tau*/, int* /*lwork*/) {
    throw DtypeError{"Only Arrays of float or double type are supported by orgqr (QR)"};
}

template <typename T>
cusolverStatus_t Orgqr(
        cusolverDnHandle_t /*handle*/,
        int /*m*/,
        int /*n*/,
        int /*k*/,
        T* /*a*/,
        int /*lda*/,
        T* /*tau*/,
        T* /*work*/,
        int /*lwork*/,
        int* /*devinfo*/) {
    throw DtypeError{"Only Arrays of float or double type are supported by orgqr (QR)"};
}

template <typename T>
cusolverStatus_t PotrfBuffersize(
        cusolverDnHandle_t /*handle*/, cublasFillMode_t /*uplo*/, int /*n*/, T* /*a*/, int /*lda*/, int* /*lwork*/) {
    throw DtypeError{"Only Arrays of float or double type are supported by potrf (Cholesky)"};
}

template <typename T>
cusolverStatus_t Potrf(
        cusolverDnHandle_t /*handle*/,
        cublasFillMode_t /*uplo*/,
        int /*n*/,
        T* /*a*/,
        int /*lda*/,
        T* /*workspace*/,
        int /*lwork*/,
        int* /*devinfo*/) {
    throw DtypeError{"Only Arrays of float or double type are supported by potrf (Cholesky)"};
}

template <typename T>
cusolverStatus_t SyevdBuffersize(
        cusolverDnHandle_t /*handle*/,
        cusolverEigMode_t /*jobz*/,
        cublasFillMode_t /*uplo*/,
        int /*n*/,
        T* /*a*/,
        int /*lda*/,
        T* /*w*/,
        int* /*lwork*/) {
    throw DtypeError{"Only Arrays of float or double type are supported by syevd (Eigen)"};
}

template <typename T>
cusolverStatus_t Syevd(
        cusolverDnHandle_t /*handle*/,
        cusolverEigMode_t /*jobz*/,
        cublasFillMode_t /*uplo*/,
        int /*n*/,
        T* /*a*/,
        int /*lda*/,
        T* /*w*/,
        T* /*work*/,
        int /*lwork*/,
        int* /*devinfo*/) {
    throw DtypeError{"Only Arrays of float or double type are supported by syevd (Eigen)"};
}

template <>
cusolverStatus_t GetrfBuffersize<double>(cusolverDnHandle_t handle, int m, int n, double* a, int lda, int* lwork) {
    return cusolverDnDgetrf_bufferSize(handle, m, n, a, lda, lwork);
}

template <>
cusolverStatus_t GetrfBuffersize<float>(cusolverDnHandle_t handle, int m, int n, float* a, int lda, int* lwork) {
    return cusolverDnSgetrf_bufferSize(handle, m, n, a, lda, lwork);
}

template <>
cusolverStatus_t Getrf<double>(cusolverDnHandle_t handle, int m, int n, double* a, int lda, double* workspace, int* devipiv, int* devinfo) {
    return cusolverDnDgetrf(handle, m, n, a, lda, workspace, devipiv, devinfo);
}

template <>
cusolverStatus_t Getrf<float>(cusolverDnHandle_t handle, int m, int n, float* a, int lda, float* workspace, int* devipiv, int* devinfo) {
    return cusolverDnSgetrf(handle, m, n, a, lda, workspace, devipiv, devinfo);
}

template <>
cusolverStatus_t Getrs<double>(
        cusolverDnHandle_t handle,
        cublasOperation_t trans,
        int n,
        int nrhs,
        double* a,
        int lda,
        int* devipiv,
        double* b,
        int ldb,
        int* devinfo) {
    return cusolverDnDgetrs(handle, trans, n, nrhs, a, lda, devipiv, b, ldb, devinfo);
}

template <>
cusolverStatus_t Getrs<float>(
        cusolverDnHandle_t handle,
        cublasOperation_t trans,
        int n,
        int nrhs,
        float* a,
        int lda,
        int* devipiv,
        float* b,
        int ldb,
        int* devinfo) {
    return cusolverDnSgetrs(handle, trans, n, nrhs, a, lda, devipiv, b, ldb, devinfo);
}

template <>
cusolverStatus_t GesvdBuffersize<double>(cusolverDnHandle_t handle, int m, int n, int* lwork) {
    return cusolverDnDgesvd_bufferSize(handle, m, n, lwork);
}

template <>
cusolverStatus_t GesvdBuffersize<float>(cusolverDnHandle_t handle, int m, int n, int* lwork) {
    return cusolverDnSgesvd_bufferSize(handle, m, n, lwork);
}

template <>
cusolverStatus_t Gesvd<double>(
        cusolverDnHandle_t handle,
        signed char jobu,
        signed char jobvt,
        int m,
        int n,
        double* a,
        int lda,
        double* s,
        double* u,
        int ldu,
        double* vt,
        int ldvt,
        double* work,
        int lwork,
        double* rwork,
        int* devinfo) {
    return cusolverDnDgesvd(handle, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, rwork, devinfo);
}

template <>
cusolverStatus_t Gesvd<float>(
        cusolverDnHandle_t handle,
        signed char jobu,
        signed char jobvt,
        int m,
        int n,
        float* a,
        int lda,
        float* s,
        float* u,
        int ldu,
        float* vt,
        int ldvt,
        float* work,
        int lwork,
        float* rwork,
        int* devinfo) {
    return cusolverDnSgesvd(handle, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, rwork, devinfo);
}

template <>
cusolverStatus_t GeqrfBufferSize<double>(cusolverDnHandle_t handle, int m, int n, double* a, int lda, int* lwork) {
    return cusolverDnDgeqrf_bufferSize(handle, m, n, a, lda, lwork);
}

template <>
cusolverStatus_t GeqrfBufferSize<float>(cusolverDnHandle_t handle, int m, int n, float* a, int lda, int* lwork) {
    return cusolverDnSgeqrf_bufferSize(handle, m, n, a, lda, lwork);
}

template <>
cusolverStatus_t Geqrf<double>(
        cusolverDnHandle_t handle, int m, int n, double* a, int lda, double* tau, double* workspace, int lwork, int* devinfo) {
    return cusolverDnDgeqrf(handle, m, n, a, lda, tau, workspace, lwork, devinfo);
}

template <>
cusolverStatus_t Geqrf<float>(
        cusolverDnHandle_t handle, int m, int n, float* a, int lda, float* tau, float* workspace, int lwork, int* devinfo) {
    return cusolverDnSgeqrf(handle, m, n, a, lda, tau, workspace, lwork, devinfo);
}

template <>
cusolverStatus_t OrgqrBufferSize<double>(cusolverDnHandle_t handle, int m, int n, int k, double* a, int lda, double* tau, int* lwork) {
    return cusolverDnDorgqr_bufferSize(handle, m, n, k, a, lda, tau, lwork);
}

template <>
cusolverStatus_t OrgqrBufferSize<float>(cusolverDnHandle_t handle, int m, int n, int k, float* a, int lda, float* tau, int* lwork) {
    return cusolverDnSorgqr_bufferSize(handle, m, n, k, a, lda, tau, lwork);
}

template <>
cusolverStatus_t Orgqr<double>(
        cusolverDnHandle_t handle, int m, int n, int k, double* a, int lda, double* tau, double* work, int lwork, int* devinfo) {
    return cusolverDnDorgqr(handle, m, n, k, a, lda, tau, work, lwork, devinfo);
}

template <>
cusolverStatus_t Orgqr<float>(
        cusolverDnHandle_t handle, int m, int n, int k, float* a, int lda, float* tau, float* work, int lwork, int* devinfo) {
    return cusolverDnSorgqr(handle, m, n, k, a, lda, tau, work, lwork, devinfo);
}

template <>
cusolverStatus_t PotrfBuffersize<double>(cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, double* a, int lda, int* lwork) {
    return cusolverDnDpotrf_bufferSize(handle, uplo, n, a, lda, lwork);
}

template <>
cusolverStatus_t PotrfBuffersize<float>(cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, float* a, int lda, int* lwork) {
    return cusolverDnSpotrf_bufferSize(handle, uplo, n, a, lda, lwork);
}

template <>
cusolverStatus_t Potrf<double>(
        cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, double* a, int lda, double* workspace, int lwork, int* devinfo) {
    return cusolverDnDpotrf(handle, uplo, n, a, lda, workspace, lwork, devinfo);
}

template <>
cusolverStatus_t Potrf<float>(
        cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, float* a, int lda, float* workspace, int lwork, int* devinfo) {
    return cusolverDnSpotrf(handle, uplo, n, a, lda, workspace, lwork, devinfo);
}

template <>
cusolverStatus_t SyevdBuffersize<double>(
        cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, double* a, int lda, double* w, int* lwork) {
    return cusolverDnDsyevd_bufferSize(handle, jobz, uplo, n, a, lda, w, lwork);
}

template <>
cusolverStatus_t SyevdBuffersize<float>(
        cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, float* a, int lda, float* w, int* lwork) {
    return cusolverDnSsyevd_bufferSize(handle, jobz, uplo, n, a, lda, w, lwork);
}

template <>
cusolverStatus_t Syevd<double>(
        cusolverDnHandle_t handle,
        cusolverEigMode_t jobz,
        cublasFillMode_t uplo,
        int n,
        double* a,
        int lda,
        double* w,
        double* work,
        int lwork,
        int* devinfo) {
    return cusolverDnDsyevd(handle, jobz, uplo, n, a, lda, w, work, lwork, devinfo);
}

template <>
cusolverStatus_t Syevd<float>(
        cusolverDnHandle_t handle,
        cusolverEigMode_t jobz,
        cublasFillMode_t uplo,
        int n,
        float* a,
        int lda,
        float* w,
        float* work,
        int lwork,
        int* devinfo) {
    return cusolverDnSsyevd(handle, jobz, uplo, n, a, lda, w, work, lwork, devinfo);
}

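// The wrappers above are invoked through cuda_internal::CusolverDnHandle::Call
// (see chainerx/cuda/cusolver.h), which supplies the handle as the first argument
// and is expected to check the returned cusolverStatus_t, throwing on failure.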
template <typename T>
void SolveImpl(const Array& a, const Array& b, const Array& out) {
    Device& device = a.device();
    Dtype dtype = a.dtype();

    cuda_internal::DeviceInternals& device_internals = cuda_internal::GetDeviceInternals(static_cast<CudaDevice&>(device));

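    // cuSOLVER expects column-major (Fortran-order) input; copying the transpose
    // of the row-major array gives getrf the layout it needs.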
    Array lu_matrix = Empty(a.shape(), dtype, device);
    device.backend().CallKernel<CopyKernel>(a.Transpose(), lu_matrix);
    auto lu_ptr = static_cast<T*>(internal::GetRawOffsetData(lu_matrix));

    int64_t m = a.shape()[0];
    int64_t lda = std::max(int64_t{1}, m);
    int64_t nrhs = 1;
    if (b.ndim() == 2) {
        nrhs = b.shape()[1];
    }

    Array ipiv = Empty(Shape{m}, Dtype::kInt32, device);
    auto ipiv_ptr = static_cast<int*>(internal::GetRawOffsetData(ipiv));

    int buffersize = 0;
    device_internals.cusolverdn_handle().Call(GetrfBuffersize<T>, m, m, lu_ptr, lda, &buffersize);

    Array work = Empty(Shape{buffersize}, dtype, device);
    auto work_ptr = static_cast<T*>(internal::GetRawOffsetData(work));

    std::shared_ptr<void> devinfo = device.Allocate(sizeof(int));

    device_internals.cusolverdn_handle().Call(Getrf<T>, m, m, lu_ptr, lda, work_ptr, ipiv_ptr, static_cast<int*>(devinfo.get()));

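    // devinfo follows the LAPACK convention: info > 0 means U(i,i) is exactly zero
    // (the matrix is singular), info < 0 means the i-th argument was invalid.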
    int devinfo_h = 0;
    Device& native_device = GetDefaultContext().GetDevice({"native", 0});
    device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
    if (devinfo_h != 0) {
        throw ChainerxError{"Unsuccessful getrf (LU) execution. Info = ", devinfo_h};
    }

    Array out_transposed = b.Transpose().Copy();
    auto out_ptr = static_cast<T*>(internal::GetRawOffsetData(out_transposed));

    device_internals.cusolverdn_handle().Call(
            Getrs<T>, CUBLAS_OP_N, m, nrhs, lu_ptr, lda, ipiv_ptr, out_ptr, lda, static_cast<int*>(devinfo.get()));

    device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
    if (devinfo_h != 0) {
        throw ChainerxError{"Unsuccessful getrs (Solve) execution. Info = ", devinfo_h};
    }

    device.backend().CallKernel<CopyKernel>(out_transposed.Transpose(), out);
}

template <typename T>
void QrImpl(const Array& a, const Array& q, const Array& r, const Array& tau, QrMode mode) {
    Device& device = a.device();
    Dtype dtype = a.dtype();

    int64_t m = a.shape()[0];
    int64_t n = a.shape()[1];
    int64_t k = std::min(m, n);
    int64_t lda = std::max(int64_t{1}, m);

    // cuSOLVER does not return the correct result in this case, and older versions of
    // cuSOLVER (<10.1) might not work well with zero-sized arrays, so return early.
    if (a.shape().GetTotalSize() == 0) {
        if (mode == QrMode::kComplete) {
            device.backend().CallKernel<IdentityKernel>(q);
        }
        return;
    }

    Array r_temp = a.Transpose().Copy();  // QR decomposition is done in-place

    cuda_internal::DeviceInternals& device_internals = cuda_internal::GetDeviceInternals(static_cast<CudaDevice&>(device));

    auto r_ptr = static_cast<T*>(internal::GetRawOffsetData(r_temp));
    auto tau_ptr = static_cast<T*>(internal::GetRawOffsetData(tau));

    std::shared_ptr<void> devinfo = device.Allocate(sizeof(int));

    int buffersize_geqrf = 0;
    device_internals.cusolverdn_handle().Call(GeqrfBufferSize<T>, m, n, r_ptr, lda, &buffersize_geqrf);

    Array work = Empty(Shape{buffersize_geqrf}, dtype, device);
    auto work_ptr = static_cast<T*>(internal::GetRawOffsetData(work));

    device_internals.cusolverdn_handle().Call(
            Geqrf<T>, m, n, r_ptr, lda, tau_ptr, work_ptr, buffersize_geqrf, static_cast<int*>(devinfo.get()));

    int devinfo_h = 0;
    Device& native_device = GetDefaultContext().GetDevice({"native", 0});
    device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
    if (devinfo_h != 0) {
        throw ChainerxError{"Unsuccessful geqrf (QR) execution. Info = ", devinfo_h};
    }

    if (mode == QrMode::kR) {
        r_temp = r_temp.At(std::vector<ArrayIndex>{Slice{}, Slice{0, k}}).Transpose();  // R = R[:, 0:k].T
        r_temp = Triu(r_temp, 0);
        device.backend().CallKernel<CopyKernel>(r_temp, r);
        return;
    }

    if (mode == QrMode::kRaw) {
        device.backend().CallKernel<CopyKernel>(r_temp, r);
        return;
    }

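    // mc is the number of columns of Q to form: all m columns in complete mode when
    // m > n, otherwise k = min(m, n). q_temp is allocated transposed, since cuSOLVER
    // views the row-major buffer in column-major order.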
    int64_t mc;
    Shape q_shape{0};
    if (mode == QrMode::kComplete && m > n) {
        mc = m;
        q_shape = Shape{m, m};
    } else {
        mc = k;
        q_shape = Shape{n, m};
    }
    Array q_temp = Empty(q_shape, dtype, device);

    device.backend().CallKernel<CopyKernel>(r_temp, q_temp.At(std::vector<ArrayIndex>{Slice{0, n}, Slice{}}));  // Q[0:n, :] = R
    auto q_ptr = static_cast<T*>(internal::GetRawOffsetData(q_temp));

    int buffersize_orgqr = 0;
    device_internals.cusolverdn_handle().Call(OrgqrBufferSize<T>, m, mc, k, q_ptr, lda, tau_ptr, &buffersize_orgqr);

    Array work_orgqr = Empty(Shape{buffersize_orgqr}, dtype, device);
    auto work_orgqr_ptr = static_cast<T*>(internal::GetRawOffsetData(work_orgqr));

    device_internals.cusolverdn_handle().Call(
            Orgqr<T>, m, mc, k, q_ptr, lda, tau_ptr, work_orgqr_ptr, buffersize_orgqr, static_cast<int*>(devinfo.get()));

    device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
    if (devinfo_h != 0) {
        throw ChainerxError{"Unsuccessful orgqr (QR) execution. Info = ", devinfo_h};
    }

    q_temp = q_temp.At(std::vector<ArrayIndex>{Slice{0, mc}, Slice{}}).Transpose();  // Q = Q[0:mc, :].T
    r_temp = r_temp.At(std::vector<ArrayIndex>{Slice{}, Slice{0, mc}}).Transpose();  // R = R[:, 0:mc].T
    r_temp = Triu(r_temp, 0);

    device.backend().CallKernel<CopyKernel>(q_temp, q);
    device.backend().CallKernel<CopyKernel>(r_temp, r);
}

}  // namespace

class CudaSolveKernel : public SolveKernel {
public:
    void Call(const Array& a, const Array& b, const Array& out) override {
        Device& device = a.device();
        CudaSetDeviceScope scope{device.index()};

        CHAINERX_ASSERT(a.ndim() == 2);
        CHAINERX_ASSERT(a.shape()[0] == a.shape()[1]);

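        // Inputs are promoted to the output dtype before the call so that
        // SolveImpl runs in a single floating-point precision.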
        VisitFloatingPointDtype(out.dtype(), [&](auto pt) {
            using T = typename decltype(pt)::type;
            SolveImpl<T>(a.dtype() == out.dtype() ? a : a.AsType(out.dtype()), b.dtype() == out.dtype() ? b : b.AsType(out.dtype()), out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(SolveKernel, CudaSolveKernel);

class CudaInverseKernel : public InverseKernel {
public:
    void Call(const Array& a, const Array& out) override {
        Device& device = a.device();
        Dtype dtype = a.dtype();
        CudaSetDeviceScope scope{device.index()};

        CHAINERX_ASSERT(a.ndim() == 2);
        CHAINERX_ASSERT(a.shape()[0] == a.shape()[1]);

        // LAPACK has a routine ``getri`` for computing the inverse of an LU-factored matrix,
        // but cuSOLVER does not implement it, so the inverse is obtained with ``getrs``:
        // inv(A) == solve(A, Identity)
        Array b = Identity(a.shape()[0], dtype, device);
        device.backend().CallKernel<SolveKernel>(a, b, out);
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(InverseKernel, CudaInverseKernel);

class CudaSvdKernel : public SvdKernel {
public:
    void Call(const Array& a, const Array& u, const Array& s, const Array& vt, bool full_matrices, bool compute_uv) override {
        Device& device = a.device();
        Dtype dtype = a.dtype();
        CudaSetDeviceScope scope{device.index()};

        CHAINERX_ASSERT(a.ndim() == 2);

        if (a.shape().GetTotalSize() == 0) {
            if (full_matrices && compute_uv) {
                device.backend().CallKernel<IdentityKernel>(u);
                device.backend().CallKernel<IdentityKernel>(vt);
            }
            // Even without this early return, the kernel handles zero-sized input correctly.
            return;
        }

        // cuSOLVER assumes arrays are in column-major order.
        // To avoid transposing the input matrix, the matrix dimensions are swapped.
        // Since the input is effectively transposed, the pointers to the u and vt
        // matrices must be swapped when calling Gesvd.
        int64_t n = a.shape()[0];
        int64_t m = a.shape()[1];
        int64_t k = std::min(m, n);

        Array x = EmptyLike(a, device);
        Array u_temp{};
        Array vt_temp{};
        bool trans_flag;

        // Remark: gesvd only supports m >= n.
        // See: https://docs.nvidia.com/cuda/cusolver/index.html#cuds-lt-t-gt-gesvd
        // Therefore, for the case m < n we calculate the SVD of the transposed matrix:
        // instead of svd(A) = U S V^T, we compute svd(A^T) = V S U^T.
        if (m >= n) {
            device.backend().CallKernel<CopyKernel>(a, x);
            trans_flag = false;
        } else {
            m = a.shape()[0];
            n = a.shape()[1];
            x = x.Reshape(Shape{n, m});
            device.backend().CallKernel<CopyKernel>(a.Transpose(), x);
            trans_flag = true;

            // Temporary arrays for u and vt are needed to store the transposed results.
            Shape u_shape;
            Shape vt_shape;
            if (compute_uv) {
                if (full_matrices) {
                    u_shape = Shape{m, m};
                    vt_shape = Shape{n, n};
                } else {
                    u_shape = Shape{k, m};
                    vt_shape = Shape{n, k};
                }
            } else {
                u_shape = Shape{0};
                vt_shape = Shape{0};
            }
            u_temp = Empty(u_shape, dtype, device);
            vt_temp = Empty(vt_shape, dtype, device);
        }

        int64_t lda = std::max(int64_t{1}, m);
        int64_t ldu = std::max(int64_t{1}, m);
        int64_t ldvt = full_matrices ? std::max(int64_t{1}, n) : std::max(int64_t{1}, k);

        auto svd_impl = [&](auto pt) {
            using T = typename decltype(pt)::type;
            cuda_internal::DeviceInternals& device_internals = cuda_internal::GetDeviceInternals(static_cast<CudaDevice&>(device));

            auto x_ptr = static_cast<T*>(internal::GetRawOffsetData(x));
            auto s_ptr = static_cast<T*>(internal::GetRawOffsetData(s));
            auto u_ptr = static_cast<T*>(internal::GetRawOffsetData(u));
            auto vt_ptr = static_cast<T*>(internal::GetRawOffsetData(vt));
            if (trans_flag) {
                u_ptr = static_cast<T*>(internal::GetRawOffsetData(vt_temp));
                vt_ptr = static_cast<T*>(internal::GetRawOffsetData(u_temp));
            }

            std::shared_ptr<void> devinfo = device.Allocate(sizeof(int));

            int buffersize = 0;
            device_internals.cusolverdn_handle().Call(GesvdBuffersize<T>, m, n, &buffersize);

            Array work = Empty(Shape{buffersize}, dtype, device);
            auto work_ptr = static_cast<T*>(internal::GetRawOffsetData(work));

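            // gesvd job flags follow LAPACK: 'A' computes the full U and V^T,
            // 'S' only the leading min(m, n) singular vectors, 'N' none.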
            signed char job;
            if (compute_uv) {
                job = full_matrices ? 'A' : 'S';
            } else {
                job = 'N';
            }

            // When calling Gesvd, the pointers to u and vt are swapped instead of transposing the input matrix.
            device_internals.cusolverdn_handle().Call(
                    Gesvd<T>,
                    job,
                    job,
                    m,
                    n,
                    x_ptr,
                    lda,
                    s_ptr,
                    vt_ptr,
                    ldu,
                    u_ptr,
                    ldvt,
                    work_ptr,
                    buffersize,
                    nullptr,
                    static_cast<int*>(devinfo.get()));

            int devinfo_h = 0;
            Device& native_device = GetDefaultContext().GetDevice({"native", 0});
            device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
            if (devinfo_h != 0) {
                throw ChainerxError{"Unsuccessful gesvd (SVD) execution. Info = ", devinfo_h};
            }

            if (trans_flag) {
                device.backend().CallKernel<CopyKernel>(u_temp.Transpose(), u);
                device.backend().CallKernel<CopyKernel>(vt_temp.Transpose(), vt);
            }
        };

        VisitFloatingPointDtype(dtype, svd_impl);
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(SvdKernel, CudaSvdKernel);

class CudaQrKernel : public QrKernel {
public:
    void Call(const Array& a, const Array& q, const Array& r, const Array& tau, QrMode mode) override {
        Device& device = a.device();
        Dtype dtype = a.dtype();
        CudaSetDeviceScope scope{device.index()};

        CHAINERX_ASSERT(a.ndim() == 2);

        VisitFloatingPointDtype(dtype, [&](auto pt) {
            using T = typename decltype(pt)::type;
            QrImpl<T>(a, q, r, tau, mode);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(QrKernel, CudaQrKernel);

class CudaCholeskyKernel : public CholeskyKernel {
public:
    void Call(const Array& a, const Array& out) override {
        Device& device = a.device();
        device.CheckDevicesCompatible(a, out);
        Dtype dtype = a.dtype();
        CudaSetDeviceScope scope{device.index()};

        CHAINERX_ASSERT(a.ndim() == 2);
        CHAINERX_ASSERT(out.ndim() == 2);
        CHAINERX_ASSERT(a.shape()[0] == a.shape()[1]);
        CHAINERX_ASSERT(out.IsContiguous());
        CHAINERX_ASSERT(a.dtype() == out.dtype());

        // Older versions of cuSOLVER (<10.1) might not work well with zero-sized arrays,
        // so it is better to return early.
        if (a.shape().GetTotalSize() == 0) {
            return;
        }

        // potrf (Cholesky) stores its result in-place; copy ``a`` to ``out`` and pass ``out`` to the routine.
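        // Copying Tril(a, 0) also zeroes the strictly upper triangle of ``out``,
        // which potrf leaves untouched.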
        device.backend().CallKernel<CopyKernel>(Tril(a, 0), out);

        auto cholesky_impl = [&](auto pt) {
            using T = typename decltype(pt)::type;

            // Note that cuSOLVER uses Fortran order, so it sees the row-major ``out`` as its transpose.
            // To obtain the lower triangular L = cholesky(A), we therefore ask cuSOLVER
            // for the upper triangular factor.
            cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER;

            cuda_internal::DeviceInternals& device_internals = cuda_internal::GetDeviceInternals(static_cast<CudaDevice&>(device));

            // Compute the workspace size and prepare the workspace.
            auto out_ptr = static_cast<T*>(internal::GetRawOffsetData(out));
            int work_size = 0;
            int64_t n = a.shape()[0];
            device_internals.cusolverdn_handle().Call(PotrfBuffersize<T>, uplo, n, out_ptr, std::max(int64_t{1}, n), &work_size);

            // POTRF execution
            Array work = Empty(Shape{work_size}, dtype, device);
            auto work_ptr = static_cast<T*>(internal::GetRawOffsetData(work));

            std::shared_ptr<void> devinfo = device.Allocate(sizeof(int));
            device_internals.cusolverdn_handle().Call(
                    Potrf<T>, uplo, n, out_ptr, std::max(int64_t{1}, n), work_ptr, work_size, static_cast<int*>(devinfo.get()));

            int devinfo_h = 0;
            Device& native_device = GetDefaultContext().GetDevice({"native", 0});
            device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
            if (devinfo_h != 0) {
                throw ChainerxError{"Unsuccessful potrf (Cholesky) execution. Info = ", devinfo_h};
            }
        };

        VisitFloatingPointDtype(dtype, cholesky_impl);
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(CholeskyKernel, CudaCholeskyKernel);

class CudaSyevdKernel : public SyevdKernel {
public:
    void Call(const Array& a, const Array& w, const Array& v, char uplo, bool compute_v) override {
        Device& device = a.device();
        Dtype dtype = a.dtype();
        CudaSetDeviceScope scope{device.index()};

        CHAINERX_ASSERT(a.ndim() == 2);

        device.backend().CallKernel<CopyKernel>(a, v);

        int64_t m = a.shape()[0];
        int64_t n = a.shape()[1];

        auto syevd_impl = [&](auto pt) {
            using T = typename decltype(pt)::type;
            cuda_internal::DeviceInternals& device_internals = cuda_internal::GetDeviceInternals(static_cast<CudaDevice&>(device));

            auto v_ptr = static_cast<T*>(internal::GetRawOffsetData(v));
            auto w_ptr = static_cast<T*>(internal::GetRawOffsetData(w));

            cusolverEigMode_t jobz = compute_v ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;

            // cuSOLVER assumes that arrays are stored in column-major order.
            // The uplo argument is swapped instead of transposing the input matrix.
            cublasFillMode_t uplo_cublas = toupper(uplo) == 'U' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;

            int buffersize = 0;
            // When calling Syevd, the matrix dimensions are swapped instead of transposing the input matrix.
            device_internals.cusolverdn_handle().Call(
                    SyevdBuffersize<T>, jobz, uplo_cublas, n, v_ptr, std::max(int64_t{1}, m), w_ptr, &buffersize);

            Array work = Empty(Shape{buffersize}, dtype, device);
            auto work_ptr = static_cast<T*>(internal::GetRawOffsetData(work));

            std::shared_ptr<void> devinfo = device.Allocate(sizeof(int));

            device_internals.cusolverdn_handle().Call(
                    Syevd<T>,
                    jobz,
                    uplo_cublas,
                    n,
                    v_ptr,
                    std::max(int64_t{1}, m),
                    w_ptr,
                    work_ptr,
                    buffersize,
                    static_cast<int*>(devinfo.get()));

            int devinfo_h = 0;
            Device& native_device = GetDefaultContext().GetDevice({"native", 0});
            device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
            if (devinfo_h != 0) {
                throw ChainerxError{"Unsuccessful syevd (Eigen Decomposition) execution. Info = ", devinfo_h};
            }

            // v is now stored in column-major order; transpose it back to row-major.
            device.backend().CallKernel<CopyKernel>(v.Transpose(), v);
        };

        VisitFloatingPointDtype(dtype, syevd_impl);
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(SyevdKernel, CudaSyevdKernel);

}  // namespace cuda
}  // namespace chainerx