1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file c_lapack_api.h
22 * \brief Unified interface for CPU-based LAPACK calls.
23 * Purpose is to hide the platform specific differences.
24 */
25 #ifndef MXNET_OPERATOR_C_LAPACK_API_H_
26 #define MXNET_OPERATOR_C_LAPACK_API_H_
27
28 // Manually maintained list of LAPACK interfaces that can be used
29 // within MXNET. Conventions:
30 // - We should only import LAPACK-functions that are useful and
31 // ensure that we support them most efficiently on CPU/GPU. As an
32 // example take "potrs": It can be emulated by two calls to
33 // "trsm" (from BLAS3) so not really needed from functionality point
34 // of view. In addition, trsm on GPU supports batch-mode processing
35 // which is much more efficient for a bunch of smaller matrices while
36 // there is no such batch support for potrs. As a result, we may
37 // not support "potrs" internally and if we want to expose it to the user as
38 // a convenience operator at some time, then we may implement it internally
39 // as a sequence of trsm.
40 // - Interfaces must be compliant with lapacke.h in terms of signature and
41 // naming conventions so wrapping a function "foo" which has the
42 // signature
43 // lapack_int LAPACKE_foo(int, char, lapack_int, float* , lapack_int)
44 // within lapacke.h should result in a wrapper with the following signature
45 // int MXNET_LAPACK_foo(int, char, int, float* , int)
46 // Note that function signatures in lapacke.h will always have as first
47 // argument the storage order (row/col-major). All wrappers have to support
48 // that argument. The underlying fortran functions will always assume a
49 // column-major layout.
50 // - In the (usual) case that a wrapper is called specifying row-major storage
51 // order of input/output data, there are two ways to handle this:
52 // 1) The wrapper may support this without allocating any additional memory
53 // for example by exploiting the fact that a matrix is symmetric and switching
54 // certain flags (upper/lower triangular) when calling the fortran code.
55 // 2) The wrapper may cause a runtime error. In that case it should be clearly
56 // documented that these functions do only support col-major layout.
57 // Rationale: This is a low level interface that is not expected to be called
58 // directly from many upstream functions. Usually all calls should go through
59 // the tensor-based interfaces in linalg.h which simplify calls to lapack further
60 // and are better suited to handle additional transpositions that may be necessary.
//    Also we want to push allocation of temporary storage higher up in order to
//    allow more efficient re-use of temporary storage. We also do not want to burden
//    these interfaces here with additional requirements of providing buffers.
64 // - It is desired to add some basic checking in the C++-wrappers in order
65 // to catch simple mistakes when calling these wrappers.
66 // - Must support compilation without lapack-package but issue runtime error in this case.
67
68 #include <dmlc/logging.h>
69 #include "mshadow/tensor.h"
70
71 using namespace mshadow;
72
73 // Will cause clash with MKL/ArmPL fortran layer headers
74 #if (!MSHADOW_USE_MKL && !MSHADOW_USE_ARMPL)
75
extern "C" {

// Prototypes of the Fortran LAPACK symbols ("<routine>_") that the wrappers
// further down in this header call. Every argument is passed by pointer.
// Under __ANDROID__ several routines are declared with an int return type
// instead of void. Each MXNET_LAPACK_FSIG* macro declares one routine for the
// given element type and is instantiated for float and double right after it.

// potrf / potri: (uplo, n, a, lda, info)
#ifdef __ANDROID__
  #define MXNET_LAPACK_FSIGNATURE1(routine, elem_t) \
    int routine##_(char *uplo, int *n, elem_t *a, int *lda, int *info);
#else
  #define MXNET_LAPACK_FSIGNATURE1(routine, elem_t) \
    void routine##_(char *uplo, int *n, elem_t *a, int *lda, int *info);
#endif

MXNET_LAPACK_FSIGNATURE1(spotrf, float)
MXNET_LAPACK_FSIGNATURE1(dpotrf, double)
MXNET_LAPACK_FSIGNATURE1(spotri, float)
MXNET_LAPACK_FSIGNATURE1(dpotri, double)

// posv: (uplo, n, nrhs, a, lda, b, ldb, info)
void sposv_(char *uplo, int *n, int *nrhs,
            float *a, int *lda, float *b, int *ldb, int *info);

void dposv_(char *uplo, int *n, int *nrhs,
            double *a, int *lda, double *b, int *ldb, int *info);

// geqrf. GELQF in row-major (MXNet) becomes GEQRF in column-major (LAPACK),
// with m and n swapped relative to the row-major call.
#define MXNET_LAPACK_FSIG_GEQRF(routine, elem_t) \
  void routine##_(int *m, int *n, elem_t *a, int *lda, elem_t *tau, \
                  elem_t *work, int *lwork, int *info);

MXNET_LAPACK_FSIG_GEQRF(sgeqrf, float)
MXNET_LAPACK_FSIG_GEQRF(dgeqrf, double)

// orgqr. ORGLQ in row-major (MXNet) becomes ORGQR in column-major (LAPACK),
// with m and n swapped relative to the row-major call.
#define MXNET_LAPACK_FSIG_ORGQR(routine, elem_t) \
  void routine##_(int *m, int *n, int *k, elem_t *a, int *lda, elem_t *tau, \
                  elem_t *work, int *lwork, int *info);

MXNET_LAPACK_FSIG_ORGQR(sorgqr, float)
MXNET_LAPACK_FSIG_ORGQR(dorgqr, double)

// syevd
#define MXNET_LAPACK_FSIG_SYEVD(routine, elem_t) \
  void routine##_(char *jobz, char *uplo, int *n, elem_t *a, int *lda, \
                  elem_t *w, elem_t *work, int *lwork, \
                  int *iwork, int *liwork, int *info);

MXNET_LAPACK_FSIG_SYEVD(ssyevd, float)
MXNET_LAPACK_FSIG_SYEVD(dsyevd, double)

// gesvd
#define MXNET_LAPACK_FSIG_GESVD(routine, elem_t) \
  void routine##_(char *jobu, char *jobvt, int *m, int *n, \
                  elem_t *a, int *lda, elem_t *s, \
                  elem_t *u, int *ldu, elem_t *vt, int *ldvt, \
                  elem_t *work, int *lwork, int *info);

MXNET_LAPACK_FSIG_GESVD(sgesvd, float)
MXNET_LAPACK_FSIG_GESVD(dgesvd, double)

// getrf
#ifdef __ANDROID__
  #define MXNET_LAPACK_FSIG_GETRF(routine, elem_t) \
    int routine##_(int *m, int *n, elem_t *a, int *lda, int *ipiv, int *info);
#else
  #define MXNET_LAPACK_FSIG_GETRF(routine, elem_t) \
    void routine##_(int *m, int *n, elem_t *a, int *lda, int *ipiv, int *info);
#endif

MXNET_LAPACK_FSIG_GETRF(sgetrf, float)
MXNET_LAPACK_FSIG_GETRF(dgetrf, double)

// getri
#ifdef __ANDROID__
  #define MXNET_LAPACK_FSIG_GETRI(routine, elem_t) \
    int routine##_(int *n, elem_t *a, int *lda, int *ipiv, elem_t *work, \
                   int *lwork, int *info);
#else
  #define MXNET_LAPACK_FSIG_GETRI(routine, elem_t) \
    void routine##_(int *n, elem_t *a, int *lda, int *ipiv, elem_t *work, \
                    int *lwork, int *info);
#endif

MXNET_LAPACK_FSIG_GETRI(sgetri, float)
MXNET_LAPACK_FSIG_GETRI(dgetri, double)

// gesv
#ifdef __ANDROID__
  #define MXNET_LAPACK_FSIG_GESV(routine, elem_t) \
    int routine##_(int *n, int *nrhs, elem_t *a, int *lda, \
                   int *ipiv, elem_t *b, int *ldb, int *info);
#else
  #define MXNET_LAPACK_FSIG_GESV(routine, elem_t) \
    void routine##_(int *n, int *nrhs, elem_t *a, int *lda, \
                    int *ipiv, elem_t *b, int *ldb, int *info);
#endif

MXNET_LAPACK_FSIG_GESV(sgesv, float)
MXNET_LAPACK_FSIG_GESV(dgesv, double)

// gesdd
#ifdef __ANDROID__
  #define MXNET_LAPACK_FSIG_GESDD(routine, elem_t) \
    int routine##_(char *jobz, int *m, int *n, elem_t *a, int *lda, elem_t *s, \
                   elem_t *u, int *ldu, \
                   elem_t *vt, int *ldvt, \
                   elem_t *work, int *lwork, int *iwork, int *info);
#else
  #define MXNET_LAPACK_FSIG_GESDD(routine, elem_t) \
    void routine##_(char *jobz, int *m, int *n, elem_t *a, int *lda, elem_t *s, \
                    elem_t *u, int *ldu, \
                    elem_t *vt, int *ldvt, \
                    elem_t *work, int *lwork, int *iwork, int *info);
#endif

MXNET_LAPACK_FSIG_GESDD(sgesdd, float)
MXNET_LAPACK_FSIG_GESDD(dgesdd, double)

// geev
#ifdef __ANDROID__
  #define MXNET_LAPACK_FSIG_GEEV(routine, elem_t) \
    int routine##_(char *jobvl, char *jobvr, int *n, elem_t *a, int *lda, \
                   elem_t *wr, elem_t *wi, \
                   elem_t *vl, int *ldvl, elem_t *vr, int *ldvr, \
                   elem_t *work, int *lwork, int *info);
#else
  #define MXNET_LAPACK_FSIG_GEEV(routine, elem_t) \
    void routine##_(char *jobvl, char *jobvr, int *n, elem_t *a, int *lda, \
                    elem_t *wr, elem_t *wi, \
                    elem_t *vl, int *ldvl, elem_t *vr, int *ldvr, \
                    elem_t *work, int *lwork, int *info);
#endif

MXNET_LAPACK_FSIG_GEEV(sgeev, float)
MXNET_LAPACK_FSIG_GEEV(dgeev, double)
}
201
202 #endif // (!MSHADOW_USE_MKL && !MSHADOW_USE_ARMPL)
203
204
// Raises a CHECK failure unless `a` is exactly 'U' or 'L' (upper case only).
// The argument is parenthesized so that compound expressions (for example a
// conditional expression) expand with the intended precedence. Note that the
// argument may be evaluated twice. The trailing ';' is kept so that existing
// call sites keep compiling unchanged.
#define CHECK_LAPACK_UPLO(a) \
  CHECK((a) == 'U' || (a) == 'L') << "neither L nor U specified as triangle in lapack call";
207
// Selects the triangle specifier to hand on to LAPACK.
// Returns `uplo` unchanged when `invert` is false. When `invert` is true,
// returns 'L' for 'U' and 'U' for every other value.
inline char loup(char uplo, bool invert) {
  if (!invert) {
    return uplo;
  }
  if (uplo == 'U') {
    return 'L';
  }
  return 'U';
}
209
/*!
 * \brief Write the transpose of matrix a into matrix b.
 *
 * The element read from a[i * lda + j] is stored at b[j * ldb + i] for
 * 0 <= i < m and 0 <= j < n. This is the same as switching the layout of
 * the matrix between row-major and column-major.
 *
 * \param m number of rows of input matrix a
 * \param n number of columns of input matrix a
 * \param b output matrix
 * \param ldb leading dimension of b
 * \param a input matrix
 * \param lda leading dimension of a
 */
template <typename xpu, typename DType>
inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
  for (int row = 0; row < m; ++row) {
    for (int col = 0; col < n; ++col) {
      const int src = row * lda + col;
      const int dst = col * ldb + row;
      b[dst] = a[src];
    }
  }
}
229
230
231 #if (MXNET_USE_LAPACK && (MSHADOW_USE_MKL || MSHADOW_USE_ARMPL))
232
233 #if (MSHADOW_USE_MKL)
234 // We interface with the C-interface of MKL
235 // as this is the preferred way.
236 #include <mkl_lapacke.h>
237 #endif
238
// Storage-order constants map straight onto the LAPACKE ones.
#define MXNET_LAPACK_COL_MAJOR LAPACK_COL_MAJOR
#define MXNET_LAPACK_ROW_MAJOR LAPACK_ROW_MAJOR

// These functions already have a matching signature, so plain aliases suffice.
// potrf / potri
#define MXNET_LAPACK_spotrf LAPACKE_spotrf
#define MXNET_LAPACK_dpotrf LAPACKE_dpotrf
#define MXNET_LAPACK_spotri LAPACKE_spotri
#define MXNET_LAPACK_dpotri LAPACKE_dpotri
// posv: lower-case names, reached through the MXNET_LAPACK_posv<DType>
// template specializations at the end of this header.
#define mxnet_lapack_sposv LAPACKE_sposv
#define mxnet_lapack_dposv LAPACKE_dposv
// gesv
#define MXNET_LAPACK_sgesv LAPACKE_sgesv
#define MXNET_LAPACK_dgesv LAPACKE_dgesv
251
252 // The following functions differ in signature from the
253 // MXNET_LAPACK-signature and have to be wrapped.
254 #define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \
255 inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, int m, int n, \
256 dtype *a, int lda, dtype *tau, \
257 dtype *work, int lwork) { \
258 if (lwork != -1) { \
259 return LAPACKE_##prefix##gelqf(matrix_layout, m, n, a, lda, tau); \
260 } \
261 *work = 0; \
262 return 0; \
263 }
264 MXNET_LAPACK_CWRAP_GELQF(s, float)
265 MXNET_LAPACK_CWRAP_GELQF(d, double)
266
267 #define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \
268 inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, int m, int n, \
269 dtype *a, int lda, dtype *tau, \
270 dtype *work, int lwork) { \
271 if (lwork != -1) { \
272 return LAPACKE_##prefix##orglq(matrix_layout, m, n, m, a, lda, tau); \
273 } \
274 *work = 0; \
275 return 0; \
276 }
277 MXNET_LAPACK_CWRAP_ORGLQ(s, float)
278 MXNET_LAPACK_CWRAP_ORGLQ(d, double)
279
280 // This has to be called internally in COL_MAJOR format even when matrix_layout
281 // is row-major as otherwise the eigenvectors would be returned as cols in a
282 // row-major matrix layout (see MKL documentation).
283 // We also have to allocate at least one DType element as workspace as the
284 // calling code assumes that the workspace has at least that size.
285 #define MXNET_LAPACK_CWRAP_SYEVD(prefix, dtype) \
286 inline int MXNET_LAPACK_##prefix##syevd(int matrix_layout, char uplo, int n, dtype *a, \
287 int lda, dtype *w, dtype *work, int lwork, \
288 int *iwork, int liwork) { \
289 if (lwork != -1) { \
290 char o(loup(uplo, (matrix_layout == MXNET_LAPACK_ROW_MAJOR))); \
291 return LAPACKE_##prefix##syevd(LAPACK_COL_MAJOR, 'V', o, n, a, lda, w); \
292 } \
293 *work = 1; \
294 *iwork = 0; \
295 return 0; \
296 }
297 MXNET_LAPACK_CWRAP_SYEVD(s, float)
298 MXNET_LAPACK_CWRAP_SYEVD(d, double)
299
300 #define MXNET_LAPACK_sgetrf LAPACKE_sgetrf
301 #define MXNET_LAPACK_dgetrf LAPACKE_dgetrf
302
303 // Internally A is factorized as U * L * VT, and (according to the tech report)
304 // we want to factorize it as UT * L * V, so we pass ut as u and v as vt.
305 // We also have to allocate at least m - 1 DType elements as workspace as the internal
306 // LAPACKE function needs it to store `superb`. (see MKL documentation)
307 #define MXNET_LAPACK_CWRAP_GESVD(prefix, dtype) \
308 inline int MXNET_LAPACK_##prefix##gesvd(int matrix_layout, int m, int n, dtype* ut, \
309 int ldut, dtype* s, dtype* v, int ldv, \
310 dtype* work, int lwork) { \
311 if (lwork != -1) { \
312 return LAPACKE_##prefix##gesvd(matrix_layout, 'S', 'O', m, n, v, ldv, s, ut, ldut, \
313 v, ldv, work); \
314 } \
315 *work = m - 1; \
316 return 0; \
317 }
318 MXNET_LAPACK_CWRAP_GESVD(s, float)
319 MXNET_LAPACK_CWRAP_GESVD(d, double)
320
321 // Computes the singular value decomposition of a general rectangular matrix
322 // using a divide and conquer method.
323 #define MXNET_LAPACK_CWRAP_GESDD(prefix, dtype) \
324 inline int MXNET_LAPACK_##prefix##gesdd(int matrix_layout, int m, int n, \
325 dtype *a, int lda, dtype *s, \
326 dtype *u, int ldu, \
327 dtype *vt, int ldvt, \
328 dtype *work, int lwork, int *iwork) { \
329 if (lwork != -1) { \
330 return LAPACKE_##prefix##gesdd(matrix_layout, 'O', m, n, a, lda, \
331 s, u, ldu, vt, ldvt); \
332 } \
333 *work = 0; \
334 return 0; \
335 }
336 MXNET_LAPACK_CWRAP_GESDD(s, float)
337 MXNET_LAPACK_CWRAP_GESDD(d, double)
338
339 #define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \
340 inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \
341 int *ipiv, dtype *work, int lwork) { \
342 if (lwork != -1) { \
343 return LAPACKE_##prefix##getri(matrix_layout, n, a, lda, ipiv); \
344 } \
345 *work = 0; \
346 return 0; \
347 }
348 MXNET_LAPACK_CWRAP_GETRI(s, float)
349 MXNET_LAPACK_CWRAP_GETRI(d, double)
350
351 #define MXNET_LAPACK_CWRAP_GEEV(prefix, dtype) \
352 inline int MXNET_LAPACK_##prefix##geev(int matrix_layout, char jobvl, char jobvr, \
353 int n, dtype *a, int lda, \
354 dtype *wr, dtype *wi, \
355 dtype *vl, int ldvl, dtype *vr, int ldvr, \
356 dtype *work, int lwork) { \
357 if (lwork != -1) { \
358 return LAPACKE_##prefix##geev(matrix_layout, jobvl, jobvr, \
359 n, a, lda, wr, wi, vl, ldvl, vr, ldvr); \
360 } \
361 *work = 0; \
362 return 0; \
363 }
364 MXNET_LAPACK_CWRAP_GEEV(s, float)
365 MXNET_LAPACK_CWRAP_GEEV(d, double)
366
367 #elif MXNET_USE_LAPACK
368
369 #define MXNET_LAPACK_ROW_MAJOR 101
370 #define MXNET_LAPACK_COL_MAJOR 102
371
372 // These functions can be called with either row- or col-major format.
373 #define MXNET_LAPACK_CWRAPPER1(func, dtype) \
374 inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, int lda) { \
375 CHECK_LAPACK_UPLO(uplo); \
376 char o(loup(uplo, (matrix_layout == MXNET_LAPACK_ROW_MAJOR))); \
377 int ret(0); \
378 func##_(&o, &n, a, &lda, &ret); \
379 return ret; \
380 }
381 MXNET_LAPACK_CWRAPPER1(spotrf, float)
382 MXNET_LAPACK_CWRAPPER1(dpotrf, double)
383 MXNET_LAPACK_CWRAPPER1(spotri, float)
384 MXNET_LAPACK_CWRAPPER1(dpotri, double)
385
386 inline int mxnet_lapack_sposv(int matrix_layout, char uplo, int n, int nrhs,
387 float *a, int lda, float *b, int ldb) {
388 int info;
389 if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) {
390 // Transpose b to b_t of shape (nrhs, n)
391 float *b_t = new float[nrhs * n];
392 flip<cpu, float>(n, nrhs, b_t, n, b, ldb);
393 sposv_(&uplo, &n, &nrhs, a, &lda, b_t, &n, &info);
394 flip<cpu, float>(nrhs, n, b, ldb, b_t, n);
395 delete [] b_t;
396 return info;
397 }
398 sposv_(&uplo, &n, &nrhs, a, &lda, b, &ldb, &info);
399 return info;
400 }
401
402 inline int mxnet_lapack_dposv(int matrix_layout, char uplo, int n, int nrhs,
403 double *a, int lda, double *b, int ldb) {
404 int info;
405 if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) {
406 // Transpose b to b_t of shape (nrhs, n)
407 double *b_t = new double[nrhs * n];
408 flip<cpu, double>(n, nrhs, b_t, n, b, ldb);
409 dposv_(&uplo, &n, &nrhs, a, &lda, b_t, &n, &info);
410 flip<cpu, double>(nrhs, n, b, ldb, b_t, n);
411 delete [] b_t;
412 return info;
413 }
414 dposv_(&uplo, &n, &nrhs, a, &lda, b, &ldb, &info);
415 return info;
416 }
417
418 // Note: Both MXNET_LAPACK_*gelqf, MXNET_LAPACK_*orglq can only be called with
419 // row-major format (MXNet). Internally, the QR variants are done in column-major.
420 // In particular, the matrix dimensions m and n are flipped.
421 #define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \
422 inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, int m, int n, \
423 dtype *a, int lda, dtype* tau, \
424 dtype* work, int lwork) { \
425 if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
426 int info(0); \
427 prefix##geqrf_(&n, &m, a, &lda, tau, work, &lwork, &info); \
428 return info; \
429 } else { \
430 CHECK(false) << "MXNET_LAPACK_" << #prefix << "gelqf implemented for row-major layout only"; \
431 return 1; \
432 } \
433 }
434 MXNET_LAPACK_CWRAP_GELQF(s, float)
435 MXNET_LAPACK_CWRAP_GELQF(d, double)
436
437 // Note: The k argument (rank) is equal to m as well
438 #define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \
439 inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, int m, int n, \
440 dtype *a, int lda, dtype* tau, \
441 dtype* work, int lwork) { \
442 if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
443 int info(0); \
444 prefix##orgqr_(&n, &m, &m, a, &lda, tau, work, &lwork, &info); \
445 return info; \
446 } else { \
447 CHECK(false) << "MXNET_LAPACK_" << #prefix << "orglq implemented for row-major layout only"; \
448 return 1; \
449 } \
450 }
451 MXNET_LAPACK_CWRAP_ORGLQ(s, float)
452 MXNET_LAPACK_CWRAP_ORGLQ(d, double)
453
454 // Note: Supports row-major format only. Internally, column-major is used, so all
455 // inputs/outputs are flipped (in particular, uplo is flipped).
456 #define MXNET_LAPACK_CWRAP_SYEVD(func, dtype) \
457 inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
458 int lda, dtype *w, dtype *work, int lwork, \
459 int *iwork, int liwork) { \
460 if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
461 int info(0); \
462 char jobz('V'); \
463 char uplo_(loup(uplo, true)); \
464 func##_(&jobz, &uplo_, &n, a, &lda, w, work, &lwork, iwork, &liwork, &info); \
465 return info; \
466 } else { \
467 CHECK(false) << "MXNET_LAPACK_" << #func << " implemented for row-major layout only"; \
468 return 1; \
469 } \
470 }
471 MXNET_LAPACK_CWRAP_SYEVD(ssyevd, float)
472 MXNET_LAPACK_CWRAP_SYEVD(dsyevd, double)
473
474 // Note: Supports row-major format only. Internally, column-major is used, so all
475 // inputs/outputs are flipped and transposed. m and n are flipped as well.
476 #define MXNET_LAPACK_CWRAP_GESVD(func, dtype) \
477 inline int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* ut, \
478 int ldut, dtype* s, dtype* v, int ldv, \
479 dtype* work, int lwork) { \
480 if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
481 int info(0); \
482 char jobu('O'); \
483 char jobvt('S'); \
484 func##_(&jobu, &jobvt, &n, &m, v, &ldv, s, v, &ldv, ut, &ldut, work, &lwork, &info); \
485 return info; \
486 } else { \
487 CHECK(false) << "MXNET_LAPACK_" << #func << " implemented for row-major layout only"; \
488 return 1; \
489 } \
490 }
491 MXNET_LAPACK_CWRAP_GESVD(sgesvd, float)
492 MXNET_LAPACK_CWRAP_GESVD(dgesvd, double)
493
494 #define MXNET_LAPACK_CWRAP_GESDD(func, dtype) \
495 inline int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
496 dtype *a, int lda, dtype *s, \
497 dtype *u, int ldu, \
498 dtype *vt, int ldvt, \
499 dtype *work, int lwork, int *iwork) { \
500 if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
501 CHECK(false) << "MXNET_LAPACK_" << #func << " implemented for row-major layout only"; \
502 return 1; \
503 } else { \
504 int info(0); \
505 char jobz('O'); \
506 func##_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, &info); \
507 return info; \
508 } \
509 }
510 MXNET_LAPACK_CWRAP_GESDD(sgesdd, float)
511 MXNET_LAPACK_CWRAP_GESDD(dgesdd, double)
512
513 #define MXNET_LAPACK_CWRAP_GEEV(prefix, dtype) \
514 inline int MXNET_LAPACK_##prefix##geev(int matrix_layout, char jobvl, char jobvr, \
515 int n, dtype *a, int lda, \
516 dtype *wr, dtype *wi, \
517 dtype *vl, int ldvl, dtype *vr, int ldvr, \
518 dtype *work, int lwork) { \
519 if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
520 CHECK(false) << "MXNET_LAPACK_" << #prefix << "geev implemented for col-major layout only"; \
521 return 1; \
522 } else { \
523 int info(0); \
524 prefix##geev_(&jobvl, &jobvr, \
525 &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, &info); \
526 return info; \
527 } \
528 }
529 MXNET_LAPACK_CWRAP_GEEV(s, float)
530 MXNET_LAPACK_CWRAP_GEEV(d, double)
531
532 #define MXNET_LAPACK
533
534 // Note: Both MXNET_LAPACK_*getrf, MXNET_LAPACK_*getri can only be called with col-major format
535 // (MXNet) for performance.
536 #define MXNET_LAPACK_CWRAP_GETRF(prefix, dtype) \
537 inline int MXNET_LAPACK_##prefix##getrf(int matrix_layout, int m, int n, \
538 dtype *a, int lda, int *ipiv) { \
539 if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
540 CHECK(false) << "MXNET_LAPACK_" << #prefix << "getri implemented for col-major layout only"; \
541 return 1; \
542 } else { \
543 int info(0); \
544 prefix##getrf_(&m, &n, a, &lda, ipiv, &info); \
545 return info; \
546 } \
547 }
548 MXNET_LAPACK_CWRAP_GETRF(s, float)
549 MXNET_LAPACK_CWRAP_GETRF(d, double)
550
551 #define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \
552 inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \
553 int *ipiv, dtype *work, int lwork) { \
554 if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
555 CHECK(false) << "MXNET_LAPACK_" << #prefix << "getri implemented for col-major layout only"; \
556 return 1; \
557 } else { \
558 int info(0); \
559 prefix##getri_(&n, a, &lda, ipiv, work, &lwork, &info); \
560 return info; \
561 } \
562 }
563 MXNET_LAPACK_CWRAP_GETRI(s, float)
564 MXNET_LAPACK_CWRAP_GETRI(d, double)
565
566 #define MXNET_LAPACK_CWRAP_GESV(prefix, dtype) \
567 inline int MXNET_LAPACK_##prefix##gesv(int matrix_layout, \
568 int n, int nrhs, dtype *a, int lda, \
569 int *ipiv, dtype *b, int ldb) { \
570 if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
571 CHECK(false) << "MXNET_LAPACK_" << #prefix << "gesv implemented for col-major layout only"; \
572 return 1; \
573 } else { \
574 int info(0); \
575 prefix##gesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, &info); \
576 return info; \
577 } \
578 }
579 MXNET_LAPACK_CWRAP_GESV(s, float)
580 MXNET_LAPACK_CWRAP_GESV(d, double)
581
582 #else
583
// Storage-order constants used when LAPACKE headers are not available.
#define MXNET_LAPACK_ROW_MAJOR 101
#define MXNET_LAPACK_COL_MAJOR 102

// Build without LAPACK: only declarations are emitted here so that code using
// the wrappers still compiles. No definitions appear in this header.
// The MXNET_LAPACK_CWRAPPER<N> helpers each declare one wrapper signature and
// are all #undef'd again at the end of this section.

// (matrix_layout, uplo, n, a, lda): potrf / potri
#define MXNET_LAPACK_CWRAPPER1(func, dtype) \
  int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda);

// (matrix_layout, m, n, a, lda, tau, work, lwork): gelqf / orglq
#define MXNET_LAPACK_CWRAPPER2(func, dtype) \
  int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
                          int lda, dtype* tau, dtype* work, int lwork);

// syevd
#define MXNET_LAPACK_CWRAPPER3(func, dtype) \
  int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
                          int lda, dtype *w, dtype *work, int lwork, \
                          int *iwork, int liwork);

// getrf
#define MXNET_LAPACK_CWRAPPER4(func, dtype) \
  int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
                          dtype *a, int lda, int *ipiv);

// getri
#define MXNET_LAPACK_CWRAPPER5(func, dtype) \
  int MXNET_LAPACK_##func(int matrix_layout, int n, dtype *a, int lda, \
                          int *ipiv, dtype *work, int lwork);

// gesvd
#define MXNET_LAPACK_CWRAPPER6(func, dtype) \
  int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* ut, \
                          int ldut, dtype* s, dtype* v, int ldv, \
                          dtype* work, int lwork);

// gesv
#define MXNET_LAPACK_CWRAPPER7(func, dtype) \
  int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \
                          int lda, int *ipiv, dtype *b, int ldb);

// geev
#define MXNET_LAPACK_CWRAPPER8(func, dtype) \
  int MXNET_LAPACK_##func(int matrix_layout, char jobvl, char jobvr, \
                          int n, dtype *a, int lda, \
                          dtype *wr, dtype *wi, \
                          dtype *vl, int ldvl, dtype *vr, int ldvr, \
                          dtype *work, int lwork);

// gesdd
#define MXNET_LAPACK_CWRAPPER9(func, dtype) \
  int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
                          dtype *a, int lda, dtype *s, \
                          dtype *u, int ldu, \
                          dtype *vt, int ldvt, \
                          dtype *work, int lwork, int *iwork);

// posv: declared with an unspecified (variadic) parameter list under the
// lower-case name used by the MXNET_LAPACK_posv<DType> specializations.
#define MXNET_LAPACK_UNAVAILABLE(func) \
  int mxnet_lapack_##func(...);

MXNET_LAPACK_CWRAPPER1(spotrf, float)
MXNET_LAPACK_CWRAPPER1(dpotrf, double)
MXNET_LAPACK_CWRAPPER1(spotri, float)
MXNET_LAPACK_CWRAPPER1(dpotri, double)

MXNET_LAPACK_UNAVAILABLE(sposv)
MXNET_LAPACK_UNAVAILABLE(dposv)

MXNET_LAPACK_CWRAPPER2(sgelqf, float)
MXNET_LAPACK_CWRAPPER2(dgelqf, double)
MXNET_LAPACK_CWRAPPER2(sorglq, float)
MXNET_LAPACK_CWRAPPER2(dorglq, double)

MXNET_LAPACK_CWRAPPER3(ssyevd, float)
MXNET_LAPACK_CWRAPPER3(dsyevd, double)

MXNET_LAPACK_CWRAPPER4(sgetrf, float)
MXNET_LAPACK_CWRAPPER4(dgetrf, double)

MXNET_LAPACK_CWRAPPER5(sgetri, float)
MXNET_LAPACK_CWRAPPER5(dgetri, double)

MXNET_LAPACK_CWRAPPER6(sgesvd, float)
MXNET_LAPACK_CWRAPPER6(dgesvd, double)

MXNET_LAPACK_CWRAPPER7(sgesv, float)
MXNET_LAPACK_CWRAPPER7(dgesv, double)

MXNET_LAPACK_CWRAPPER8(sgeev, float)
MXNET_LAPACK_CWRAPPER8(dgeev, double)

MXNET_LAPACK_CWRAPPER9(sgesdd, float)
MXNET_LAPACK_CWRAPPER9(dgesdd, double)

// Drop every helper macro defined above so none of them leaks to includers.
#undef MXNET_LAPACK_CWRAPPER1
#undef MXNET_LAPACK_CWRAPPER2
#undef MXNET_LAPACK_CWRAPPER3
#undef MXNET_LAPACK_CWRAPPER4
#undef MXNET_LAPACK_CWRAPPER5
#undef MXNET_LAPACK_CWRAPPER6
#undef MXNET_LAPACK_CWRAPPER7
#undef MXNET_LAPACK_CWRAPPER8
#undef MXNET_LAPACK_CWRAPPER9
#undef MXNET_LAPACK_UNAVAILABLE
672 #endif
673
674 template <typename DType>
675 inline int MXNET_LAPACK_posv(int matrix_layout, char uplo, int n, int nrhs,
676 DType *a, int lda, DType *b, int ldb);
677
678 template <>
679 inline int MXNET_LAPACK_posv<float>(int matrix_layout, char uplo, int n,
680 int nrhs, float *a, int lda, float *b, int ldb) {
681 return mxnet_lapack_sposv(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
682 }
683
684 template <>
685 inline int MXNET_LAPACK_posv<double>(int matrix_layout, char uplo, int n,
686 int nrhs, double *a, int lda, double *b, int ldb) {
687 return mxnet_lapack_dposv(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
688 }
689
690 #endif // MXNET_OPERATOR_C_LAPACK_API_H_
691