1 #include "chainerx/cuda/cuda_device.h"
2 
#include <algorithm>
#include <cstdint>
#include <memory>
#include <mutex>
#include <type_traits>
#include <vector>
6 
7 #include <cublas_v2.h>
8 #include <cuda_runtime.h>
9 #include <cusolverDn.h>
10 #include <cuda_fp16.hpp>
11 
12 #include "chainerx/array.h"
13 #include "chainerx/axes.h"
14 #include "chainerx/backend.h"
15 #include "chainerx/backend_util.h"
16 #include "chainerx/cuda/cublas.h"
17 #include "chainerx/cuda/cuda_runtime.h"
18 #include "chainerx/cuda/cuda_set_device_scope.h"
19 #include "chainerx/cuda/cusolver.h"
20 #include "chainerx/cuda/data_type.cuh"
21 #include "chainerx/cuda/float16.cuh"
22 #include "chainerx/cuda/kernel_regist.h"
23 #include "chainerx/device.h"
24 #include "chainerx/dtype.h"
25 #include "chainerx/error.h"
26 #include "chainerx/float16.h"
27 #include "chainerx/kernels/creation.h"
28 #include "chainerx/kernels/linalg.h"
29 #include "chainerx/kernels/misc.h"
30 #include "chainerx/macro.h"
31 #include "chainerx/native/native_device.h"
32 #include "chainerx/routines/arithmetic.h"
33 #include "chainerx/routines/creation.h"
34 #include "chainerx/routines/indexing.h"
35 #include "chainerx/routines/linalg.h"
36 
37 namespace chainerx {
38 namespace cuda {
39 namespace {
40 
// Generic getrf (LU factorization) entry points.
// cuSOLVER only offers float and double versions, so for any other T these primary templates throw DtypeError.
// The float/double explicit specializations defined later in this file forward to cuSOLVER.
template <typename T>
auto GetrfBuffersize(cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, T* /*a*/, int /*lda*/, int* /*lwork*/)
        -> cusolverStatus_t {
    throw DtypeError{"Only Arrays of float or double type are supported by getrf (LU)"};
}

template <typename T>
auto Getrf(cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, T* /*a*/, int /*lda*/,  // matrix
           T* /*workspace*/, int* /*devipiv*/, int* /*devinfo*/)                       // outputs / scratch
        -> cusolverStatus_t {
    throw DtypeError{"Only Arrays of float or double type are supported by getrf (LU)"};
}
51 
// Generic getrs (solve using an LU factorization) entry point.
// Any T other than float/double throws DtypeError.
template <typename T>
auto Getrs(cusolverDnHandle_t /*handle*/, cublasOperation_t /*trans*/, int /*n*/, int /*nrhs*/,
           T* /*a*/, int /*lda*/, int* /*devipiv*/,
           T* /*b*/, int /*ldb*/, int* /*devinfo*/) -> cusolverStatus_t {
    throw DtypeError{"Only Arrays of float or double type are supported by getrs (Solve)"};
}
66 
// Generic gesvd (singular value decomposition) entry points.
// Only float and double are backed by cuSOLVER; any other T throws DtypeError.
template <typename T>
auto GesvdBuffersize(cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, int* /*lwork*/) -> cusolverStatus_t {
    throw DtypeError{"Only Arrays of float or double type are supported by gesvd (SVD)"};
}

template <typename T>
auto Gesvd(cusolverDnHandle_t /*handle*/, signed char /*jobu*/, signed char /*jobvt*/, int /*m*/, int /*n*/,
           T* /*a*/, int /*lda*/, T* /*s*/,
           T* /*u*/, int /*ldu*/, T* /*vt*/, int /*ldvt*/,
           T* /*work*/, int /*lwork*/, T* /*rwork*/, int* /*devinfo*/) -> cusolverStatus_t {
    throw DtypeError{"Only Arrays of float or double type are supported by gesvd (SVD)"};
}
92 
// Generic geqrf (QR factorization) entry points.
// Any T other than float/double throws DtypeError.
template <typename T>
auto GeqrfBufferSize(cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, T* /*a*/, int /*lda*/, int* /*lwork*/)
        -> cusolverStatus_t {
    throw DtypeError{"Only Arrays of float or double type are supported by geqrf (QR)"};
}

template <typename T>
auto Geqrf(cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, T* /*a*/, int /*lda*/,
           T* /*tau*/, T* /*workspace*/, int /*lwork*/, int* /*devinfo*/) -> cusolverStatus_t {
    throw DtypeError{"Only Arrays of float or double type are supported by geqrf (QR)"};
}
111 
// Generic orgqr (generate Q from geqrf reflectors) entry points.
// Any T other than float/double throws DtypeError.
template <typename T>
auto OrgqrBufferSize(cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, int /*k*/,
                     T* /*a*/, int /*lda*/, T* /*tau*/, int* /*lwork*/) -> cusolverStatus_t {
    throw DtypeError{"Only Arrays of float or double type are supported by orgqr (QR)"};
}

template <typename T>
auto Orgqr(cusolverDnHandle_t /*handle*/, int /*m*/, int /*n*/, int /*k*/, T* /*a*/, int /*lda*/,
           T* /*tau*/, T* /*work*/, int /*lwork*/, int* /*devinfo*/) -> cusolverStatus_t {
    throw DtypeError{"Only Arrays of float or double type are supported by orgqr (QR)"};
}
132 
// Generic potrf (Cholesky factorization) entry points.
// Any T other than float/double throws DtypeError.
template <typename T>
auto PotrfBuffersize(cusolverDnHandle_t /*handle*/, cublasFillMode_t /*uplo*/, int /*n*/,
                     T* /*a*/, int /*lda*/, int* /*lwork*/) -> cusolverStatus_t {
    throw DtypeError{"Only Arrays of float or double type are supported by potrf (Cholesky)"};
}

template <typename T>
auto Potrf(cusolverDnHandle_t /*handle*/, cublasFillMode_t /*uplo*/, int /*n*/, T* /*a*/, int /*lda*/,
           T* /*workspace*/, int /*lwork*/, int* /*devinfo*/) -> cusolverStatus_t {
    throw DtypeError{"Only Arrays of float or double type are supported by potrf (Cholesky)"};
}
151 
// Generic syevd (symmetric eigendecomposition) entry points.
// Any T other than float/double throws DtypeError.
template <typename T>
auto SyevdBuffersize(cusolverDnHandle_t /*handle*/, cusolverEigMode_t /*jobz*/, cublasFillMode_t /*uplo*/, int /*n*/,
                     T* /*a*/, int /*lda*/, T* /*w*/, int* /*lwork*/) -> cusolverStatus_t {
    throw DtypeError{"Only Arrays of float or double type are supported by syevd (Eigen)"};
}

template <typename T>
auto Syevd(cusolverDnHandle_t /*handle*/, cusolverEigMode_t /*jobz*/, cublasFillMode_t /*uplo*/, int /*n*/,
           T* /*a*/, int /*lda*/, T* /*w*/, T* /*work*/, int /*lwork*/, int* /*devinfo*/) -> cusolverStatus_t {
    throw DtypeError{"Only Arrays of float or double type are supported by syevd (Eigen)"};
}
179 
// float / double specializations of the getrf (LU) wrappers.
// Each forwards its arguments unchanged to the matching cuSOLVER routine and returns that routine's status.
template <>
auto GetrfBuffersize<float>(cusolverDnHandle_t handle, int m, int n, float* mat, int ld, int* lwork) -> cusolverStatus_t {
    return cusolverDnSgetrf_bufferSize(handle, m, n, mat, ld, lwork);
}

template <>
auto GetrfBuffersize<double>(cusolverDnHandle_t handle, int m, int n, double* mat, int ld, int* lwork) -> cusolverStatus_t {
    return cusolverDnDgetrf_bufferSize(handle, m, n, mat, ld, lwork);
}

template <>
auto Getrf<float>(cusolverDnHandle_t handle, int m, int n, float* mat, int ld, float* work, int* ipiv, int* info)
        -> cusolverStatus_t {
    return cusolverDnSgetrf(handle, m, n, mat, ld, work, ipiv, info);
}

template <>
auto Getrf<double>(cusolverDnHandle_t handle, int m, int n, double* mat, int ld, double* work, int* ipiv, int* info)
        -> cusolverStatus_t {
    return cusolverDnDgetrf(handle, m, n, mat, ld, work, ipiv, info);
}
199 
// float / double specializations of the getrs (solve with an LU factorization) wrapper.
// Each forwards to cusolverDn{S,D}getrs.
template <>
auto Getrs<float>(cusolverDnHandle_t handle, cublasOperation_t op, int n, int nrhs,
                  float* lu, int ld_lu, int* ipiv,
                  float* rhs, int ld_rhs, int* info) -> cusolverStatus_t {
    return cusolverDnSgetrs(handle, op, n, nrhs, lu, ld_lu, ipiv, rhs, ld_rhs, info);
}

template <>
auto Getrs<double>(cusolverDnHandle_t handle, cublasOperation_t op, int n, int nrhs,
                   double* lu, int ld_lu, int* ipiv,
                   double* rhs, int ld_rhs, int* info) -> cusolverStatus_t {
    return cusolverDnDgetrs(handle, op, n, nrhs, lu, ld_lu, ipiv, rhs, ld_rhs, info);
}
229 
// float / double specializations of the gesvd (SVD) wrappers.
// Each forwards to the matching cuSOLVER routine.
template <>
auto GesvdBuffersize<float>(cusolverDnHandle_t handle, int m, int n, int* lwork) -> cusolverStatus_t {
    return cusolverDnSgesvd_bufferSize(handle, m, n, lwork);
}

template <>
auto GesvdBuffersize<double>(cusolverDnHandle_t handle, int m, int n, int* lwork) -> cusolverStatus_t {
    return cusolverDnDgesvd_bufferSize(handle, m, n, lwork);
}

template <>
auto Gesvd<float>(cusolverDnHandle_t handle, signed char job_u, signed char job_vt, int m, int n,
                  float* mat, int ld_mat, float* sigma,
                  float* left, int ld_left, float* right_t, int ld_right_t,
                  float* work, int lwork, float* rwork, int* info) -> cusolverStatus_t {
    return cusolverDnSgesvd(
            handle, job_u, job_vt, m, n, mat, ld_mat, sigma, left, ld_left, right_t, ld_right_t, work, lwork, rwork, info);
}

template <>
auto Gesvd<double>(cusolverDnHandle_t handle, signed char job_u, signed char job_vt, int m, int n,
                   double* mat, int ld_mat, double* sigma,
                   double* left, int ld_left, double* right_t, int ld_right_t,
                   double* work, int lwork, double* rwork, int* info) -> cusolverStatus_t {
    return cusolverDnDgesvd(
            handle, job_u, job_vt, m, n, mat, ld_mat, sigma, left, ld_left, right_t, ld_right_t, work, lwork, rwork, info);
}
281 
// float / double specializations of the geqrf (QR factorization) wrappers.
// Each forwards to the matching cuSOLVER routine.
template <>
auto GeqrfBufferSize<float>(cusolverDnHandle_t handle, int m, int n, float* mat, int ld, int* lwork) -> cusolverStatus_t {
    return cusolverDnSgeqrf_bufferSize(handle, m, n, mat, ld, lwork);
}

template <>
auto GeqrfBufferSize<double>(cusolverDnHandle_t handle, int m, int n, double* mat, int ld, int* lwork) -> cusolverStatus_t {
    return cusolverDnDgeqrf_bufferSize(handle, m, n, mat, ld, lwork);
}

template <>
auto Geqrf<float>(cusolverDnHandle_t handle, int m, int n, float* mat, int ld, float* tau, float* work, int lwork, int* info)
        -> cusolverStatus_t {
    return cusolverDnSgeqrf(handle, m, n, mat, ld, tau, work, lwork, info);
}

template <>
auto Geqrf<double>(
        cusolverDnHandle_t handle, int m, int n, double* mat, int ld, double* tau, double* work, int lwork, int* info)
        -> cusolverStatus_t {
    return cusolverDnDgeqrf(handle, m, n, mat, ld, tau, work, lwork, info);
}
303 
// float / double specializations of the orgqr (generate Q) wrappers.
// Each forwards to the matching cuSOLVER routine.
template <>
auto OrgqrBufferSize<float>(cusolverDnHandle_t handle, int m, int n, int k, float* mat, int ld, float* tau, int* lwork)
        -> cusolverStatus_t {
    return cusolverDnSorgqr_bufferSize(handle, m, n, k, mat, ld, tau, lwork);
}

template <>
auto OrgqrBufferSize<double>(cusolverDnHandle_t handle, int m, int n, int k, double* mat, int ld, double* tau, int* lwork)
        -> cusolverStatus_t {
    return cusolverDnDorgqr_bufferSize(handle, m, n, k, mat, ld, tau, lwork);
}

template <>
auto Orgqr<float>(
        cusolverDnHandle_t handle, int m, int n, int k, float* mat, int ld, float* tau, float* work, int lwork, int* info)
        -> cusolverStatus_t {
    return cusolverDnSorgqr(handle, m, n, k, mat, ld, tau, work, lwork, info);
}

template <>
auto Orgqr<double>(
        cusolverDnHandle_t handle, int m, int n, int k, double* mat, int ld, double* tau, double* work, int lwork, int* info)
        -> cusolverStatus_t {
    return cusolverDnDorgqr(handle, m, n, k, mat, ld, tau, work, lwork, info);
}
325 
// float / double specializations of the potrf (Cholesky) wrappers.
// Each forwards to the matching cuSOLVER routine.
template <>
auto PotrfBuffersize<float>(cusolverDnHandle_t handle, cublasFillMode_t fill, int n, float* mat, int ld, int* lwork)
        -> cusolverStatus_t {
    return cusolverDnSpotrf_bufferSize(handle, fill, n, mat, ld, lwork);
}

template <>
auto PotrfBuffersize<double>(cusolverDnHandle_t handle, cublasFillMode_t fill, int n, double* mat, int ld, int* lwork)
        -> cusolverStatus_t {
    return cusolverDnDpotrf_bufferSize(handle, fill, n, mat, ld, lwork);
}

template <>
auto Potrf<float>(cusolverDnHandle_t handle, cublasFillMode_t fill, int n, float* mat, int ld, float* work, int lwork, int* info)
        -> cusolverStatus_t {
    return cusolverDnSpotrf(handle, fill, n, mat, ld, work, lwork, info);
}

template <>
auto Potrf<double>(
        cusolverDnHandle_t handle, cublasFillMode_t fill, int n, double* mat, int ld, double* work, int lwork, int* info)
        -> cusolverStatus_t {
    return cusolverDnDpotrf(handle, fill, n, mat, ld, work, lwork, info);
}
347 
// float / double specializations of the syevd (symmetric eigendecomposition) wrappers.
// Each forwards to the matching cuSOLVER routine.
template <>
auto SyevdBuffersize<float>(cusolverDnHandle_t handle, cusolverEigMode_t mode, cublasFillMode_t fill, int n,
                            float* mat, int ld, float* eigvals, int* lwork) -> cusolverStatus_t {
    return cusolverDnSsyevd_bufferSize(handle, mode, fill, n, mat, ld, eigvals, lwork);
}

template <>
auto SyevdBuffersize<double>(cusolverDnHandle_t handle, cusolverEigMode_t mode, cublasFillMode_t fill, int n,
                             double* mat, int ld, double* eigvals, int* lwork) -> cusolverStatus_t {
    return cusolverDnDsyevd_bufferSize(handle, mode, fill, n, mat, ld, eigvals, lwork);
}

template <>
auto Syevd<float>(cusolverDnHandle_t handle, cusolverEigMode_t mode, cublasFillMode_t fill, int n,
                  float* mat, int ld, float* eigvals, float* work, int lwork, int* info) -> cusolverStatus_t {
    return cusolverDnSsyevd(handle, mode, fill, n, mat, ld, eigvals, work, lwork, info);
}

template <>
auto Syevd<double>(cusolverDnHandle_t handle, cusolverEigMode_t mode, cublasFillMode_t fill, int n,
                   double* mat, int ld, double* eigvals, double* work, int lwork, int* info) -> cusolverStatus_t {
    return cusolverDnDsyevd(handle, mode, fill, n, mat, ld, eigvals, work, lwork, info);
}
389 
// Solves the square linear system a @ x = b and writes x into ``out``.
// It uses an LU factorization: cuSOLVER getrf followed by getrs.
//
// ``a`` is expected to be an m x m matrix; squareness is asserted by CudaSolveKernel.
// ``b`` is either a length-m vector or an m x nrhs matrix, and ``out`` receives a result with ``b``'s shape.
// All three are expected to have dtype T.
// Throws DtypeError (from the cuSOLVER wrappers) if T is neither float nor double.
// Throws ChainerxError if getrf or getrs reports a non-zero info; a positive getrf info means U is exactly singular.
template <typename T>
void SolveImpl(const Array& a, const Array& b, const Array& out) {
    Device& device = a.device();
    Dtype dtype = a.dtype();

    // A 0 x 0 system has nothing to solve, and ``out`` (shaped like ``b``) has no elements.
    // Return early, as QrImpl does: older cuSOLVER versions (<10.1) might not work well with zero-sized arrays.
    if (a.shape().GetTotalSize() == 0) {
        return;
    }

    cuda_internal::DeviceInternals& device_internals = cuda_internal::GetDeviceInternals(static_cast<CudaDevice&>(device));

    // getrf factorizes in place, so work on a copy.
    // cuSOLVER uses column-major order, and a contiguous copy of a.T is the column-major layout of ``a``.
    Array lu_matrix = Empty(a.shape(), dtype, device);
    device.backend().CallKernel<CopyKernel>(a.Transpose(), lu_matrix);
    auto lu_ptr = static_cast<T*>(internal::GetRawOffsetData(lu_matrix));

    int64_t m = a.shape()[0];
    int64_t lda = std::max(int64_t{1}, m);
    // A 1-D ``b`` is treated as a single right-hand side.
    int64_t nrhs = 1;
    if (b.ndim() == 2) {
        nrhs = b.shape()[1];
    }

    // Pivot indices produced by getrf and consumed by getrs.
    Array ipiv = Empty(Shape{m}, Dtype::kInt32, device);
    auto ipiv_ptr = static_cast<int*>(internal::GetRawOffsetData(ipiv));

    int buffersize = 0;
    device_internals.cusolverdn_handle().Call(GetrfBuffersize<T>, m, m, lu_ptr, lda, &buffersize);

    Array work = Empty(Shape{buffersize}, dtype, device);
    auto work_ptr = static_cast<T*>(internal::GetRawOffsetData(work));

    // Status word allocated on ``device``; it is reused by getrf and getrs and copied to the host after each call.
    std::shared_ptr<void> devinfo = device.Allocate(sizeof(int));

    device_internals.cusolverdn_handle().Call(Getrf<T>, m, m, lu_ptr, lda, work_ptr, ipiv_ptr, static_cast<int*>(devinfo.get()));

    int devinfo_h = 0;
    Device& native_device = GetDefaultContext().GetDevice({"native", 0});
    device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
    if (devinfo_h != 0) {
        throw ChainerxError{"Unsuccessful getrf (LU) execution. Info = ", devinfo_h};
    }

    // getrs overwrites the right-hand side with the solution.
    // A contiguous copy of b.T is the column-major layout of ``b``.
    Array out_transposed = b.Transpose().Copy();
    auto out_ptr = static_cast<T*>(internal::GetRawOffsetData(out_transposed));

    // The right-hand side has m rows, so its leading dimension equals lda.
    device_internals.cusolverdn_handle().Call(
            Getrs<T>, CUBLAS_OP_N, m, nrhs, lu_ptr, lda, ipiv_ptr, out_ptr, lda, static_cast<int*>(devinfo.get()));

    device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
    if (devinfo_h != 0) {
        throw ChainerxError{"Unsuccessful getrs (Solve) execution. Info = ", devinfo_h};
    }

    // Transpose the column-major solution back into the row-major ``out``.
    device.backend().CallKernel<CopyKernel>(out_transposed.Transpose(), out);
}
441 
// Computes the QR decomposition of the 2-D matrix ``a`` with cuSOLVER and writes the factors selected by ``mode``.
// geqrf is always run; orgqr is run only when Q is needed. Results go into ``q``, ``r`` and ``tau``.
//
// Let m = a.shape()[0], n = a.shape()[1] and k = min(m, n). Based on the copies performed below:
//   - kR:   r <- triu(R), shape (k, n); q is not written.
//   - kRaw: r <- the raw geqrf output in its transposed (n, m) layout; q is not written.
//   - other modes: Q is generated with orgqr using mc = m (kComplete with m > n) or mc = k (otherwise).
//     Then q <- Q with shape (m, mc), and r <- triu(R) with shape (mc, n).
// For non-empty input, geqrf writes the scalar factors into ``tau`` in every mode.
// For zero-sized input, only kComplete writes anything (an identity matrix into ``q``).
// NOTE(review): ``tau`` is written through a raw pointer, so it is presumably contiguous with dtype T — confirm at the caller.
// T must be float or double; other types make the cuSOLVER wrappers throw DtypeError.
// Throws ChainerxError when geqrf or orgqr reports a non-zero info.
template <typename T>
void QrImpl(const Array& a, const Array& q, const Array& r, const Array& tau, QrMode mode) {
    Device& device = a.device();
    Dtype dtype = a.dtype();

    int64_t m = a.shape()[0];
    int64_t n = a.shape()[1];
    int64_t k = std::min(m, n);
    // Leading dimension of the column-major m x n matrix handed to cuSOLVER (must be at least 1).
    int64_t lda = std::max(int64_t{1}, m);

    // cuSOLVER does not return correct result in this case and older versions of cuSOLVER (<10.1)
    // might not work well with zero-sized arrays therefore it's better to return earlier
    if (a.shape().GetTotalSize() == 0) {
        if (mode == QrMode::kComplete) {
            device.backend().CallKernel<IdentityKernel>(q);
        }
        return;
    }

    // A contiguous (n, m) copy of a.T is the column-major layout of ``a`` expected by cuSOLVER.
    Array r_temp = a.Transpose().Copy();  // QR decomposition is done in-place

    cuda_internal::DeviceInternals& device_internals = cuda_internal::GetDeviceInternals(static_cast<CudaDevice&>(device));

    auto r_ptr = static_cast<T*>(internal::GetRawOffsetData(r_temp));
    auto tau_ptr = static_cast<T*>(internal::GetRawOffsetData(tau));

    // Status word allocated on ``device``, shared by geqrf and orgqr and copied to the host after each call.
    std::shared_ptr<void> devinfo = device.Allocate(sizeof(int));

    int buffersize_geqrf = 0;
    device_internals.cusolverdn_handle().Call(GeqrfBufferSize<T>, m, n, r_ptr, lda, &buffersize_geqrf);

    Array work = Empty(Shape{buffersize_geqrf}, dtype, device);
    auto work_ptr = static_cast<T*>(internal::GetRawOffsetData(work));

    device_internals.cusolverdn_handle().Call(
            Geqrf<T>, m, n, r_ptr, lda, tau_ptr, work_ptr, buffersize_geqrf, static_cast<int*>(devinfo.get()));

    int devinfo_h = 0;
    Device& native_device = GetDefaultContext().GetDevice({"native", 0});
    device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
    if (devinfo_h != 0) {
        throw ChainerxError{"Unsuccessful geqrf (QR) execution. Info = ", devinfo_h};
    }

    // R-only mode: keep the upper triangle of the first k columns (column-major), transposed back to row-major.
    if (mode == QrMode::kR) {
        r_temp = r_temp.At(std::vector<ArrayIndex>{Slice{}, Slice{0, k}}).Transpose();  // R = R[:, 0:k].T
        r_temp = Triu(r_temp, 0);
        device.backend().CallKernel<CopyKernel>(r_temp, r);
        return;
    }

    // Raw mode: hand back the geqrf output untouched, still in its transposed (n, m) layout.
    if (mode == QrMode::kRaw) {
        device.backend().CallKernel<CopyKernel>(r_temp, r);
        return;
    }

    // mc = number of columns of Q to generate: all m for a complete decomposition of a tall matrix, k otherwise.
    // q_temp is row-major (rows, m), i.e. a column-major m x rows matrix with leading dimension m (== lda).
    int64_t mc;
    Shape q_shape{0};
    if (mode == QrMode::kComplete && m > n) {
        mc = m;
        q_shape = Shape{m, m};
    } else {
        mc = k;
        q_shape = Shape{n, m};
    }
    Array q_temp = Empty(q_shape, dtype, device);

    // orgqr works in place: seed the first n (column-major) columns with the geqrf output, which holds the reflectors.
    device.backend().CallKernel<CopyKernel>(r_temp, q_temp.At(std::vector<ArrayIndex>{Slice{0, n}, Slice{}}));  // Q[0:n, :] = R
    auto q_ptr = static_cast<T*>(internal::GetRawOffsetData(q_temp));

    int buffersize_orgqr = 0;
    device_internals.cusolverdn_handle().Call(OrgqrBufferSize<T>, m, mc, k, q_ptr, lda, tau_ptr, &buffersize_orgqr);

    Array work_orgqr = Empty(Shape{buffersize_orgqr}, dtype, device);
    auto work_orgqr_ptr = static_cast<T*>(internal::GetRawOffsetData(work_orgqr));

    device_internals.cusolverdn_handle().Call(
            Orgqr<T>, m, mc, k, q_ptr, lda, tau_ptr, work_orgqr_ptr, buffersize_orgqr, static_cast<int*>(devinfo.get()));

    device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
    if (devinfo_h != 0) {
        throw ChainerxError{"Unsuccessful orgqr (QR) execution. Info = ", devinfo_h};
    }

    // Convert the column-major results back to row-major: Q is (m, mc), R is (mc, n) upper triangular.
    q_temp = q_temp.At(std::vector<ArrayIndex>{Slice{0, mc}, Slice{}}).Transpose();  // Q = Q[0:mc, :].T
    r_temp = r_temp.At(std::vector<ArrayIndex>{Slice{}, Slice{0, mc}}).Transpose();  // R = R[:, 0:mc].T
    r_temp = Triu(r_temp, 0);

    device.backend().CallKernel<CopyKernel>(q_temp, q);
    device.backend().CallKernel<CopyKernel>(r_temp, r);
}
533 
534 }  // namespace
535 
class CudaSolveKernel : public SolveKernel {
public:
    // Solves a @ x = b on the CUDA device of ``a`` and writes x into ``out``.
    // ``a`` must be a square 2-D matrix. The computation runs in ``out``'s dtype:
    // ``a`` and ``b`` are cast to that dtype first whenever they differ from it.
    void Call(const Array& a, const Array& b, const Array& out) override {
        Device& device = a.device();
        CudaSetDeviceScope scope{device.index()};

        CHAINERX_ASSERT(a.ndim() == 2);
        CHAINERX_ASSERT(a.shape()[0] == a.shape()[1]);

        const Dtype out_dtype = out.dtype();
        VisitFloatingPointDtype(out_dtype, [&](auto pt) {
            using T = typename decltype(pt)::type;
            const Array& lhs = a.dtype() == out_dtype ? a : a.AsType(out_dtype);
            const Array& rhs = b.dtype() == out_dtype ? b : b.AsType(out_dtype);
            SolveImpl<T>(lhs, rhs, out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(SolveKernel, CudaSolveKernel);
553 
class CudaInverseKernel : public InverseKernel {
public:
    // Computes out = inv(a) for a square 2-D ``a`` by solving a @ X = I with the registered SolveKernel.
    void Call(const Array& a, const Array& out) override {
        Device& device = a.device();
        CudaSetDeviceScope scope{device.index()};

        CHAINERX_ASSERT(a.ndim() == 2);
        CHAINERX_ASSERT(a.shape()[0] == a.shape()[1]);

        // LAPACK's ``getri`` inverts an LU-factored matrix, but cuSOLVER does not provide it.
        // Instead the inverse is obtained through ``getrs``: inv(A) == solve(A, Identity).
        const int64_t order = a.shape()[0];
        Array identity = Identity(order, a.dtype(), device);
        device.backend().CallKernel<SolveKernel>(a, identity, out);
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(InverseKernel, CudaInverseKernel);
573 
class CudaSvdKernel : public SvdKernel {
public:
    // Computes the singular value decomposition a = u @ diag(s) @ vt of the 2-D matrix ``a`` with cuSOLVER gesvd.
    //
    // ``s`` always receives the singular values. When ``compute_uv`` is true, ``u`` and ``vt`` also receive the
    // singular vectors. If ``full_matrices`` is true these are square matrices; otherwise they are the reduced
    // factors, whose inner dimension is k = min(a.shape()[0], a.shape()[1]).
    // When ``compute_uv`` is false, gesvd runs with job 'N'.
    // NOTE(review): u, s and vt are written through raw pointers, so they are presumably contiguous with dtype
    // equal to a.dtype() — confirm against the caller that allocates them.
    // Throws ChainerxError when gesvd reports a non-zero info; for other floating-point dtypes the Gesvd wrappers
    // throw DtypeError.
    void Call(const Array& a, const Array& u, const Array& s, const Array& vt, bool full_matrices, bool compute_uv) override {
        Device& device = a.device();
        Dtype dtype = a.dtype();
        CudaSetDeviceScope scope{device.index()};

        CHAINERX_ASSERT(a.ndim() == 2);

        // Zero-sized input: with full_matrices and compute_uv, u and vt are set to identity matrices;
        // nothing else is written.
        if (a.shape().GetTotalSize() == 0) {
            if (full_matrices && compute_uv) {
                device.backend().CallKernel<IdentityKernel>(u);
                device.backend().CallKernel<IdentityKernel>(vt);
            }
            // This kernel works correctly for zero-sized input also without early return
            return;
        }

        // cuSOLVER assumes arrays are in column-major order.
        // In order to avoid transposing the input matrix, matrix dimensions are swapped.
        // Since the input is assumed to be transposed, it is necessary to
        // swap the pointers to u and vt matrices when calling Gesvd.
        int64_t n = a.shape()[0];
        int64_t m = a.shape()[1];
        int64_t k = std::min(m, n);

        // x is the working copy that gesvd overwrites.
        Array x = EmptyLike(a, device);
        Array u_temp{};
        Array vt_temp{};
        bool trans_flag;

        // Remark: gesvd only supports m>=n.
        // See: https://docs.nvidia.com/cuda/cusolver/index.html#cuds-lt-t-gt-gesvd
        // Therefore for the case m<n we calculate svd of transposed matrix,
        // instead of calculating svd(A) = U S V^T, we compute svd(A^T) = V S U^T
        if (m >= n) {
            // The row-major ``a`` is read by cuSOLVER as the column-major m x n matrix A^T.
            device.backend().CallKernel<CopyKernel>(a, x);
            trans_flag = false;
        } else {
            // Re-read the dimensions so that m is the (larger) row count of ``a``.
            // x holds a.T in row-major order, which cuSOLVER reads as the column-major m x n matrix A itself.
            m = a.shape()[0];
            n = a.shape()[1];
            x = x.Reshape(Shape{n, m});
            device.backend().CallKernel<CopyKernel>(a.Transpose(), x);
            trans_flag = true;

            // Temporary arrays for u, vt are needed to store transposed results
            Shape u_shape;
            Shape vt_shape;
            if (compute_uv) {
                if (full_matrices) {
                    u_shape = Shape{m, m};
                    vt_shape = Shape{n, n};
                } else {
                    u_shape = Shape{k, m};
                    vt_shape = Shape{n, k};
                }
            } else {
                u_shape = Shape{0};
                vt_shape = Shape{0};
            }
            u_temp = Empty(u_shape, dtype, device);
            vt_temp = Empty(vt_shape, dtype, device);
        }

        // Leading dimensions of the column-major operands (each at least 1).
        int64_t lda = std::max(int64_t{1}, m);
        int64_t ldu = std::max(int64_t{1}, m);
        int64_t ldvt = full_matrices ? std::max(int64_t{1}, n) : std::max(int64_t{1}, k);

        auto svd_impl = [&](auto pt) {
            using T = typename decltype(pt)::type;
            cuda_internal::DeviceInternals& device_internals = cuda_internal::GetDeviceInternals(static_cast<CudaDevice&>(device));

            auto x_ptr = static_cast<T*>(internal::GetRawOffsetData(x));
            auto s_ptr = static_cast<T*>(internal::GetRawOffsetData(s));
            auto u_ptr = static_cast<T*>(internal::GetRawOffsetData(u));
            auto vt_ptr = static_cast<T*>(internal::GetRawOffsetData(vt));
            // In the transposed case results go to the temporaries. Combined with the swap in the Gesvd call below,
            // gesvd's U lands in u_temp and its VT lands in vt_temp.
            if (trans_flag) {
                u_ptr = static_cast<T*>(internal::GetRawOffsetData(vt_temp));
                vt_ptr = static_cast<T*>(internal::GetRawOffsetData(u_temp));
            }

            // Status word allocated on ``device``; copied to the host after gesvd.
            std::shared_ptr<void> devinfo = device.Allocate(sizeof(int));

            int buffersize = 0;
            device_internals.cusolverdn_handle().Call(GesvdBuffersize<T>, m, n, &buffersize);

            Array work = Empty(Shape{buffersize}, dtype, device);
            auto work_ptr = static_cast<T*>(internal::GetRawOffsetData(work));

            // 'A': all singular vectors, 'S': the reduced ones, 'N': none.
            signed char job;
            if (compute_uv) {
                job = full_matrices ? 'A' : 'S';
            } else {
                job = 'N';
            }

            // When calling Gesvd pointers to u and vt are swapped instead of transposing the input matrix.
            device_internals.cusolverdn_handle().Call(
                    Gesvd<T>,
                    job,
                    job,
                    m,
                    n,
                    x_ptr,
                    lda,
                    s_ptr,
                    vt_ptr,
                    ldu,
                    u_ptr,
                    ldvt,
                    work_ptr,
                    buffersize,
                    nullptr,
                    static_cast<int*>(devinfo.get()));

            int devinfo_h = 0;
            Device& native_device = GetDefaultContext().GetDevice({"native", 0});
            device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
            if (devinfo_h != 0) {
                throw ChainerxError{"Unsuccessful gesvd (SVD) execution. Info = ", devinfo_h};
            }

            // Transpose the column-major temporaries back into the row-major outputs.
            // NOTE(review): when compute_uv is false these copy the (0,)-shaped temporaries, which presumably relies
            // on the caller allocating u and vt with a compatible empty shape — confirm.
            if (trans_flag) {
                device.backend().CallKernel<CopyKernel>(u_temp.Transpose(), u);
                device.backend().CallKernel<CopyKernel>(vt_temp.Transpose(), vt);
            }
        };

        VisitFloatingPointDtype(dtype, svd_impl);
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(SvdKernel, CudaSvdKernel);
707 
class CudaQrKernel : public QrKernel {
public:
    // Computes the QR decomposition of the 2-D matrix ``a`` on its CUDA device.
    // Outputs go to ``q``, ``r`` and ``tau`` as selected by ``mode``; the work is dispatched to QrImpl by a's dtype.
    void Call(const Array& a, const Array& q, const Array& r, const Array& tau, QrMode mode) override {
        Device& device = a.device();
        CudaSetDeviceScope scope{device.index()};

        CHAINERX_ASSERT(a.ndim() == 2);

        auto qr_for_dtype = [&](auto pt) {
            using T = typename decltype(pt)::type;
            QrImpl<T>(a, q, r, tau, mode);
        };
        VisitFloatingPointDtype(a.dtype(), qr_for_dtype);
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(QrKernel, CudaQrKernel);
725 
class CudaCholeskyKernel : public CholeskyKernel {
public:
    // Writes the lower-triangular Cholesky factor of the square 2-D matrix ``a`` into ``out``.
    //
    // ``out`` must be a contiguous 2-D array with the same dtype as ``a`` on a compatible device. It is first
    // filled with the lower triangle of ``a`` and then factorized in place by cuSOLVER's potrf.
    // Zero-sized input is a no-op. Throws ChainerxError when potrf reports a non-zero info value.
    void Call(const Array& a, const Array& out) override {
        Device& device = a.device();
        device.CheckDevicesCompatible(a, out);
        const Dtype dtype = a.dtype();
        CudaSetDeviceScope scope{device.index()};

        CHAINERX_ASSERT(a.ndim() == 2);
        CHAINERX_ASSERT(out.ndim() == 2);
        CHAINERX_ASSERT(a.shape()[0] == a.shape()[1]);
        CHAINERX_ASSERT(out.IsContiguous());
        CHAINERX_ASSERT(dtype == out.dtype());

        // Older cuSOLVER releases (<10.1) may misbehave on zero-sized matrices, so bail out before touching them.
        if (a.shape().GetTotalSize() == 0) {
            return;
        }

        // potrf factorizes in place: seed ``out`` with the lower triangle of ``a`` and let cuSOLVER overwrite it.
        device.backend().CallKernel<CopyKernel>(Tril(a, 0), out);

        VisitFloatingPointDtype(dtype, [&](auto pt) {
            using T = typename decltype(pt)::type;
            cuda_internal::DeviceInternals& internals = cuda_internal::GetDeviceInternals(static_cast<CudaDevice&>(device));

            // cuSOLVER is column-major, so the row-major lower triangle of ``out`` is its upper triangle.
            // Asking cuSOLVER for the upper factor U therefore produces the row-major lower factor L = U^T.
            cublasFillMode_t fill = CUBLAS_FILL_MODE_UPPER;

            auto factor_ptr = static_cast<T*>(internal::GetRawOffsetData(out));
            int64_t order = a.shape()[0];
            int64_t leading_dim = std::max(int64_t{1}, order);

            // Query the scratch size, then allocate the workspace and the device-side status word.
            int lwork = 0;
            internals.cusolverdn_handle().Call(PotrfBuffersize<T>, fill, order, factor_ptr, leading_dim, &lwork);
            Array workspace = Empty(Shape{lwork}, dtype, device);
            auto workspace_ptr = static_cast<T*>(internal::GetRawOffsetData(workspace));
            std::shared_ptr<void> info_dev = device.Allocate(sizeof(int));

            internals.cusolverdn_handle().Call(
                    Potrf<T>, fill, order, factor_ptr, leading_dim, workspace_ptr, lwork, static_cast<int*>(info_dev.get()));

            // Bring the status word back to the host and report failures.
            int info = 0;
            Device& host_device = GetDefaultContext().GetDevice({"native", 0});
            device.MemoryCopyTo(&info, info_dev.get(), sizeof(int), host_device);
            if (info != 0) {
                throw ChainerxError{"Unsuccessful potrf (Cholesky) execution. Info = ", info};
            }
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(CholeskyKernel, CudaCholeskyKernel);
785 
class CudaSyevdKernel : public SyevdKernel {
public:
    // Runs cuSOLVER syevd on the square 2-D matrix ``a`` (assumed symmetric). Eigenvalues are written to ``w``.
    // When ``compute_v`` is true, the eigenvectors computed by syevd are written to ``v`` in row-major order.
    //
    // ``uplo`` equal to 'U' or 'u' reads the upper triangle of ``a`` (row-major terms); any other value reads
    // the lower triangle. When ``compute_v`` is false the contents written to ``v`` are presumably not
    // meaningful to callers. Zero-sized input is a no-op. Throws ChainerxError when syevd reports a
    // non-zero info value.
    void Call(const Array& a, const Array& w, const Array& v, char uplo, bool compute_v) override {
        Device& device = a.device();
        device.CheckDevicesCompatible(a, w, v);
        Dtype dtype = a.dtype();
        CudaSetDeviceScope scope{device.index()};

        CHAINERX_ASSERT(a.ndim() == 2);
        CHAINERX_ASSERT(a.shape()[0] == a.shape()[1]);

        // cuSOLVER might not work well with zero-sized arrays for older versions of cuSOLVER (<10.1)
        // therefore it's better to return earlier
        if (a.shape().GetTotalSize() == 0) {
            return;
        }

        // syevd overwrites its input matrix with the result (column-major). Run it on a private contiguous
        // buffer rather than on ``v`` itself: transposing ``v`` into itself afterwards would be an element-wise
        // copy between overlapping memory (v.Transpose() is a view of v), which races and corrupts the result.
        // The fresh buffer also guarantees the contiguous layout cuSOLVER expects.
        Array v_temp = Empty(a.shape(), dtype, device);
        device.backend().CallKernel<CopyKernel>(a, v_temp);

        int64_t m = a.shape()[0];
        int64_t n = a.shape()[1];

        auto syevd_impl = [&](auto pt) {
            using T = typename decltype(pt)::type;
            cuda_internal::DeviceInternals& device_internals = cuda_internal::GetDeviceInternals(static_cast<CudaDevice&>(device));

            auto v_ptr = static_cast<T*>(internal::GetRawOffsetData(v_temp));
            auto w_ptr = static_cast<T*>(internal::GetRawOffsetData(w));

            cusolverEigMode_t jobz = compute_v ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;

            // cuSOLVER assumes that arrays are stored in column-major order.
            // The uplo argument is swapped instead of transposing the input matrix.
            cublasFillMode_t uplo_cublas = (uplo == 'U' || uplo == 'u') ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;

            int buffersize = 0;
            // When calling Syevd matrix dimensions are swapped instead of transposing the input matrix
            device_internals.cusolverdn_handle().Call(
                    SyevdBuffersize<T>, jobz, uplo_cublas, n, v_ptr, std::max(int64_t{1}, m), w_ptr, &buffersize);

            Array work = Empty(Shape{buffersize}, dtype, device);
            auto work_ptr = static_cast<T*>(internal::GetRawOffsetData(work));

            std::shared_ptr<void> devinfo = device.Allocate(sizeof(int));

            device_internals.cusolverdn_handle().Call(
                    Syevd<T>,
                    jobz,
                    uplo_cublas,
                    n,
                    v_ptr,
                    std::max(int64_t{1}, m),
                    w_ptr,
                    work_ptr,
                    buffersize,
                    static_cast<int*>(devinfo.get()));

            int devinfo_h = 0;
            Device& native_device = GetDefaultContext().GetDevice({"native", 0});
            device.MemoryCopyTo(&devinfo_h, devinfo.get(), sizeof(int), native_device);
            if (devinfo_h != 0) {
                throw ChainerxError{"Unsuccessful syevd (Eigen Decomposition) execution. Info = ", devinfo_h};
            }

            // v_temp holds the result in column-major order; copying its transpose stores it row-major in v.
            // Source and destination are distinct buffers, so the element-wise copy is race-free.
            device.backend().CallKernel<CopyKernel>(v_temp.Transpose(), v);
        };

        VisitFloatingPointDtype(dtype, syevd_impl);
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(SyevdKernel, CudaSyevdKernel);
851 
852 }  // namespace cuda
853 }  // namespace chainerx
854