1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file c_lapack_api.h
22  * \brief Unified interface for CPU-based LAPACK calls.
23  *  Purpose is to hide the platform specific differences.
24  */
25 #ifndef MXNET_OPERATOR_C_LAPACK_API_H_
26 #define MXNET_OPERATOR_C_LAPACK_API_H_
27 
28 // Manually maintained list of LAPACK interfaces that can be used
29 // within MXNET. Conventions:
30 //    - We should only import LAPACK-functions that are useful and
31 //      ensure that we support them most efficiently on CPU/GPU. As an
32 //      example take "potrs": It can be emulated by two calls to
33 //      "trsm" (from BLAS3) so not really needed from functionality point
34 //      of view. In addition, trsm on GPU supports batch-mode processing
35 //      which is much more efficient for a bunch of smaller matrices while
36 //      there is no such batch support for potrs. As a result, we may
37 //      not support "potrs" internally and if we want to expose it to the user as
38 //      a convenience operator at some time, then we may implement it internally
39 //      as a sequence of trsm.
40 //    - Interfaces must be compliant with lapacke.h in terms of signature and
41 //      naming conventions so wrapping a function "foo" which has the
42 //      signature
43 //         lapack_int LAPACKE_foo(int, char, lapack_int, float* , lapack_int)
44 //      within lapacke.h should result in a wrapper with the following signature
45 //         int MXNET_LAPACK_foo(int, char, int, float* , int)
46 //      Note that function signatures in lapacke.h will always have as first
47 //      argument the storage order (row/col-major). All wrappers have to support
48 //      that argument. The underlying fortran functions will always assume a
49 //      column-major layout.
50 //    - In the (usual) case that a wrapper is called specifying row-major storage
51 //      order of input/output data, there are two ways to handle this:
52 //        1) The wrapper may support this without allocating any additional memory
53 //           for example by exploiting the fact that a matrix is symmetric and switching
54 //           certain flags (upper/lower triangular) when calling the fortran code.
55 //        2) The wrapper may cause a runtime error. In that case it should be clearly
56 //           documented that these functions do only support col-major layout.
57 //      Rationale: This is a low level interface that is not expected to be called
58 //      directly from many upstream functions. Usually all calls should go through
59 //      the tensor-based interfaces in linalg.h which simplify calls to lapack further
60 //      and are better suited to handle additional transpositions that may be necessary.
//      Also we want to push allocation of temporary storage higher up in order to
//      allow more efficient re-use of temporary storage, and we don't want to clutter
//      these interfaces with additional requirements of providing buffers.
64 //    - It is desired to add some basic checking in the C++-wrappers in order
65 //      to catch simple mistakes when calling these wrappers.
66 //    - Must support compilation without lapack-package but issue runtime error in this case.
67 
68 #include <dmlc/logging.h>
69 #include "mshadow/tensor.h"
70 
71 using namespace mshadow;
72 
73 // Will cause clash with MKL/ArmPL fortran layer headers
74 #if (!MSHADOW_USE_MKL && !MSHADOW_USE_ARMPL)
75 
extern "C" {

  // Prototypes of the Fortran LAPACK routines called by the wrappers in this header.
  // Every argument is passed by pointer. Under __ANDROID__ several of the routines
  // are declared with an int return type instead of void.

  // potrf / potri share one argument list.
  #ifdef __ANDROID__
    #define MXNET_LAPACK_FSIGNATURE1(func, dtype) \
      int func##_(char *triangle, int *order, dtype *mat, int *ld_mat, int *status);
  #else
    #define MXNET_LAPACK_FSIGNATURE1(func, dtype) \
      void func##_(char *triangle, int *order, dtype *mat, int *ld_mat, int *status);
  #endif

  MXNET_LAPACK_FSIGNATURE1(spotrf, float)
  MXNET_LAPACK_FSIGNATURE1(dpotrf, double)
  MXNET_LAPACK_FSIGNATURE1(spotri, float)
  MXNET_LAPACK_FSIGNATURE1(dpotri, double)

  // posv, double and single precision.
  void dposv_(char *triangle, int *order, int *num_rhs,
              double *mat, int *ld_mat,
              double *rhs, int *ld_rhs,
              int *status);

  void sposv_(char *triangle, int *order, int *num_rhs,
              float *mat, int *ld_mat,
              float *rhs, int *ld_rhs,
              int *status);

  // A row-major (MXNet) GELQF corresponds to a column-major (LAPACK) GEQRF,
  // with m and n exchanged relative to the row-major call.
  #define MXNET_LAPACK_FSIG_GEQRF(func, dtype) \
    void func##_(int *rows, int *cols, dtype *mat, int *ld_mat, \
                 dtype *reflectors, \
                 dtype *workspace, int *workspace_len, int *status);

  MXNET_LAPACK_FSIG_GEQRF(sgeqrf, float)
  MXNET_LAPACK_FSIG_GEQRF(dgeqrf, double)

  // A row-major (MXNet) ORGLQ corresponds to a column-major (LAPACK) ORGQR,
  // with m and n exchanged relative to the row-major call.
  #define MXNET_LAPACK_FSIG_ORGQR(func, dtype) \
    void func##_(int *rows, int *cols, int *num_reflectors, \
                 dtype *mat, int *ld_mat, dtype *reflectors, \
                 dtype *workspace, int *workspace_len, int *status);

  MXNET_LAPACK_FSIG_ORGQR(sorgqr, float)
  MXNET_LAPACK_FSIG_ORGQR(dorgqr, double)

  // syevd
  #define MXNET_LAPACK_FSIG_SYEVD(func, dtype) \
    void func##_(char *job, char *triangle, int *order, \
                 dtype *mat, int *ld_mat, dtype *eigenvalues, \
                 dtype *workspace, int *workspace_len, \
                 int *int_workspace, int *int_workspace_len, int *status);

  MXNET_LAPACK_FSIG_SYEVD(ssyevd, float)
  MXNET_LAPACK_FSIG_SYEVD(dsyevd, double)

  // gesvd
  #define MXNET_LAPACK_FSIG_GESVD(func, dtype) \
    void func##_(char *job_u, char *job_vt, int *rows, int *cols, \
                 dtype *mat, int *ld_mat, dtype *singular, \
                 dtype *left, int *ld_left, dtype *right_t, int *ld_right_t, \
                 dtype *workspace, int *workspace_len, int *status);

  MXNET_LAPACK_FSIG_GESVD(sgesvd, float)
  MXNET_LAPACK_FSIG_GESVD(dgesvd, double)

  // getrf
  #ifdef __ANDROID__
    #define MXNET_LAPACK_FSIG_GETRF(func, dtype) \
      int func##_(int *rows, int *cols, dtype *mat, int *ld_mat, \
                  int *pivots, int *status);
  #else
    #define MXNET_LAPACK_FSIG_GETRF(func, dtype) \
      void func##_(int *rows, int *cols, dtype *mat, int *ld_mat, \
                   int *pivots, int *status);
  #endif

  MXNET_LAPACK_FSIG_GETRF(sgetrf, float)
  MXNET_LAPACK_FSIG_GETRF(dgetrf, double)

  // getri
  #ifdef __ANDROID__
    #define MXNET_LAPACK_FSIG_GETRI(func, dtype) \
      int func##_(int *order, dtype *mat, int *ld_mat, int *pivots, \
                  dtype *workspace, int *workspace_len, int *status);
  #else
    #define MXNET_LAPACK_FSIG_GETRI(func, dtype) \
      void func##_(int *order, dtype *mat, int *ld_mat, int *pivots, \
                   dtype *workspace, int *workspace_len, int *status);
  #endif

  MXNET_LAPACK_FSIG_GETRI(sgetri, float)
  MXNET_LAPACK_FSIG_GETRI(dgetri, double)

  // gesv
  #ifdef __ANDROID__
    #define MXNET_LAPACK_FSIG_GESV(func, dtype) \
      int func##_(int *order, int *num_rhs, dtype *mat, int *ld_mat, \
                  int *pivots, dtype *rhs, int *ld_rhs, int *status);
  #else
    #define MXNET_LAPACK_FSIG_GESV(func, dtype) \
      void func##_(int *order, int *num_rhs, dtype *mat, int *ld_mat, \
                   int *pivots, dtype *rhs, int *ld_rhs, int *status);
  #endif

  MXNET_LAPACK_FSIG_GESV(sgesv, float)
  MXNET_LAPACK_FSIG_GESV(dgesv, double)

  // gesdd
  #ifdef __ANDROID__
    #define MXNET_LAPACK_FSIG_GESDD(func, dtype) \
    int func##_(char *job, int *rows, int *cols, \
                dtype *mat, int *ld_mat, dtype *singular, \
                dtype *left, int *ld_left, dtype *right_t, int *ld_right_t, \
                dtype *workspace, int *workspace_len, \
                int *int_workspace, int *status);
  #else
    #define MXNET_LAPACK_FSIG_GESDD(func, dtype) \
    void func##_(char *job, int *rows, int *cols, \
                 dtype *mat, int *ld_mat, dtype *singular, \
                 dtype *left, int *ld_left, dtype *right_t, int *ld_right_t, \
                 dtype *workspace, int *workspace_len, \
                 int *int_workspace, int *status);
  #endif

  MXNET_LAPACK_FSIG_GESDD(sgesdd, float)
  MXNET_LAPACK_FSIG_GESDD(dgesdd, double)

  // geev
  #ifdef __ANDROID__
    #define MXNET_LAPACK_FSIG_GEEV(func, dtype) \
    int func##_(char *job_left, char *job_right, int *order, \
                dtype *mat, int *ld_mat, dtype *eig_real, dtype *eig_imag, \
                dtype *left, int *ld_left, dtype *right, int *ld_right, \
                dtype *workspace, int *workspace_len, int *status);
  #else
    #define MXNET_LAPACK_FSIG_GEEV(func, dtype) \
    void func##_(char *job_left, char *job_right, int *order, \
                 dtype *mat, int *ld_mat, dtype *eig_real, dtype *eig_imag, \
                 dtype *left, int *ld_left, dtype *right, int *ld_right, \
                 dtype *workspace, int *workspace_len, int *status);
  #endif

  MXNET_LAPACK_FSIG_GEEV(sgeev, float)
  MXNET_LAPACK_FSIG_GEEV(dgeev, double)
}
201 
202 #endif  // (!MSHADOW_USE_MKL && !MSHADOW_USE_ARMPL)
203 
204 
// Fails a CHECK unless `a` equals 'U' or 'L'.
// The argument is parenthesized so that an arbitrary expression (for example a
// conditional) is compared as a whole. `a` is evaluated up to two times.
// The expansion already ends in ';', so a ';' at the call site is redundant but harmless.
#define CHECK_LAPACK_UPLO(a) \
  CHECK((a) == 'U' || (a) == 'L') << "neither L nor U specified as triangle in lapack call";
207 
// Returns `uplo` unchanged when `invert` is false. When `invert` is true,
// returns 'L' for an input of 'U' and 'U' for any other input.
inline char loup(char uplo, bool invert) {
  if (!invert) {
    return uplo;
  }
  if (uplo == 'U') {
    return 'L';
  }
  return 'U';
}
209 
/*!
 * \brief Write the transpose of matrix a into b.
 *
 * Element (i, j) of a, read at a[i * lda + j], is stored at b[j * ldb + i].
 * This is the same as switching the matrix between row-major and
 * column-major layout.
 *
 * \param m number of rows of input matrix a
 * \param n number of columns of input matrix a
 * \param b output matrix
 * \param ldb leading dimension of b
 * \param a input matrix
 * \param lda leading dimension of a
 */
template <typename xpu, typename DType>
inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
  for (int row = 0; row < m; ++row) {
    for (int col = 0; col < n; ++col) {
      const int src = row * lda + col;
      const int dst = col * ldb + row;
      b[dst] = a[src];
    }
  }
}
229 
230 
231 #if (MXNET_USE_LAPACK && (MSHADOW_USE_MKL || MSHADOW_USE_ARMPL))
232 
233   #if (MSHADOW_USE_MKL)
234     // We interface with the C-interface of MKL
235     // as this is the preferred way.
236     #include <mkl_lapacke.h>
237   #endif
238 
239   #define MXNET_LAPACK_ROW_MAJOR LAPACK_ROW_MAJOR
240   #define MXNET_LAPACK_COL_MAJOR LAPACK_COL_MAJOR
241 
242   // These function have already matching signature.
243   #define MXNET_LAPACK_spotrf LAPACKE_spotrf
244   #define MXNET_LAPACK_dpotrf LAPACKE_dpotrf
245   #define MXNET_LAPACK_spotri LAPACKE_spotri
246   #define MXNET_LAPACK_dpotri LAPACKE_dpotri
247   #define mxnet_lapack_sposv  LAPACKE_sposv
248   #define mxnet_lapack_dposv  LAPACKE_dposv
249   #define MXNET_LAPACK_dgesv  LAPACKE_dgesv
250   #define MXNET_LAPACK_sgesv  LAPACKE_sgesv
251 
252   // The following functions differ in signature from the
253   // MXNET_LAPACK-signature and have to be wrapped.
254   #define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \
255   inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, int m, int n, \
256                                           dtype *a, int lda, dtype *tau, \
257                                           dtype *work, int lwork) { \
258     if (lwork != -1) { \
259       return LAPACKE_##prefix##gelqf(matrix_layout, m, n, a, lda, tau); \
260     } \
261     *work = 0; \
262     return 0; \
263   }
264   MXNET_LAPACK_CWRAP_GELQF(s, float)
265   MXNET_LAPACK_CWRAP_GELQF(d, double)
266 
267   #define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \
268   inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, int m, int n, \
269                                           dtype *a, int lda, dtype *tau, \
270                                           dtype *work, int lwork) { \
271     if (lwork != -1) { \
272       return LAPACKE_##prefix##orglq(matrix_layout, m, n, m, a, lda, tau); \
273     } \
274     *work = 0; \
275     return 0; \
276   }
277   MXNET_LAPACK_CWRAP_ORGLQ(s, float)
278   MXNET_LAPACK_CWRAP_ORGLQ(d, double)
279 
280   // This has to be called internally in COL_MAJOR format even when matrix_layout
281   // is row-major as otherwise the eigenvectors would be returned as cols in a
282   // row-major matrix layout (see MKL documentation).
283   // We also have to allocate at least one DType element as workspace as the
284   // calling code assumes that the workspace has at least that size.
285   #define MXNET_LAPACK_CWRAP_SYEVD(prefix, dtype) \
286   inline int MXNET_LAPACK_##prefix##syevd(int matrix_layout, char uplo, int n, dtype *a, \
287                                           int lda, dtype *w, dtype *work, int lwork, \
288                                           int *iwork, int liwork) { \
289     if (lwork != -1) { \
290       char o(loup(uplo, (matrix_layout == MXNET_LAPACK_ROW_MAJOR))); \
291       return LAPACKE_##prefix##syevd(LAPACK_COL_MAJOR, 'V', o, n, a, lda, w); \
292     } \
293     *work = 1; \
294     *iwork = 0; \
295     return 0; \
296   }
297   MXNET_LAPACK_CWRAP_SYEVD(s, float)
298   MXNET_LAPACK_CWRAP_SYEVD(d, double)
299 
300   #define MXNET_LAPACK_sgetrf LAPACKE_sgetrf
301   #define MXNET_LAPACK_dgetrf LAPACKE_dgetrf
302 
303   // Internally A is factorized as U * L * VT, and (according to the tech report)
304   // we want to factorize it as UT * L * V, so we pass ut as u and v as vt.
305   // We also have to allocate at least m - 1 DType elements as workspace as the internal
306   // LAPACKE function needs it to store `superb`. (see MKL documentation)
307   #define MXNET_LAPACK_CWRAP_GESVD(prefix, dtype) \
308   inline int MXNET_LAPACK_##prefix##gesvd(int matrix_layout, int m, int n, dtype* ut, \
309                                           int ldut, dtype* s, dtype* v, int ldv, \
310                                           dtype* work, int lwork) { \
311     if (lwork != -1) { \
312       return LAPACKE_##prefix##gesvd(matrix_layout, 'S', 'O', m, n, v, ldv, s, ut, ldut, \
313                                      v, ldv, work); \
314     } \
315     *work = m - 1; \
316     return 0; \
317   }
318   MXNET_LAPACK_CWRAP_GESVD(s, float)
319   MXNET_LAPACK_CWRAP_GESVD(d, double)
320 
321   // Computes the singular value decomposition of a general rectangular matrix
322   // using a divide and conquer method.
323   #define MXNET_LAPACK_CWRAP_GESDD(prefix, dtype) \
324   inline int MXNET_LAPACK_##prefix##gesdd(int matrix_layout, int m, int n, \
325                                           dtype *a, int lda, dtype *s, \
326                                           dtype *u, int ldu, \
327                                           dtype *vt, int ldvt, \
328                                           dtype *work, int lwork, int *iwork) { \
329     if (lwork != -1) { \
330       return LAPACKE_##prefix##gesdd(matrix_layout, 'O', m, n, a, lda, \
331                                      s, u, ldu, vt, ldvt); \
332     } \
333     *work = 0; \
334     return 0; \
335   }
336   MXNET_LAPACK_CWRAP_GESDD(s, float)
337   MXNET_LAPACK_CWRAP_GESDD(d, double)
338 
339   #define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \
340   inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \
341                                           int *ipiv, dtype *work, int lwork) { \
342     if (lwork != -1) { \
343       return LAPACKE_##prefix##getri(matrix_layout, n, a, lda, ipiv); \
344     } \
345     *work = 0; \
346     return 0; \
347   }
348   MXNET_LAPACK_CWRAP_GETRI(s, float)
349   MXNET_LAPACK_CWRAP_GETRI(d, double)
350 
351   #define MXNET_LAPACK_CWRAP_GEEV(prefix, dtype) \
352   inline int MXNET_LAPACK_##prefix##geev(int matrix_layout, char jobvl, char jobvr, \
353                                          int n, dtype *a, int lda, \
354                                          dtype *wr, dtype *wi, \
355                                          dtype *vl, int ldvl, dtype *vr, int ldvr, \
356                                          dtype *work, int lwork) { \
357     if (lwork != -1) { \
358       return LAPACKE_##prefix##geev(matrix_layout, jobvl, jobvr, \
359                                     n, a, lda, wr, wi, vl, ldvl, vr, ldvr); \
360     } \
361     *work = 0; \
362     return 0; \
363   }
364   MXNET_LAPACK_CWRAP_GEEV(s, float)
365   MXNET_LAPACK_CWRAP_GEEV(d, double)
366 
367 #elif MXNET_USE_LAPACK
368 
369   #define MXNET_LAPACK_ROW_MAJOR 101
370   #define MXNET_LAPACK_COL_MAJOR 102
371 
372   // These functions can be called with either row- or col-major format.
373   #define MXNET_LAPACK_CWRAPPER1(func, dtype) \
374   inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, int lda) { \
375     CHECK_LAPACK_UPLO(uplo); \
376     char o(loup(uplo, (matrix_layout == MXNET_LAPACK_ROW_MAJOR))); \
377     int ret(0); \
378     func##_(&o, &n, a, &lda, &ret); \
379     return ret; \
380   }
381   MXNET_LAPACK_CWRAPPER1(spotrf, float)
382   MXNET_LAPACK_CWRAPPER1(dpotrf, double)
383   MXNET_LAPACK_CWRAPPER1(spotri, float)
384   MXNET_LAPACK_CWRAPPER1(dpotri, double)
385 
386   inline int mxnet_lapack_sposv(int matrix_layout, char uplo, int n, int nrhs,
387     float *a, int lda, float *b, int ldb) {
388     int info;
389     if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) {
390       // Transpose b to b_t of shape (nrhs, n)
391       float *b_t = new float[nrhs * n];
392       flip<cpu, float>(n, nrhs, b_t, n, b, ldb);
393       sposv_(&uplo, &n, &nrhs, a, &lda, b_t, &n, &info);
394       flip<cpu, float>(nrhs, n, b, ldb, b_t, n);
395       delete [] b_t;
396       return info;
397     }
398     sposv_(&uplo, &n, &nrhs, a, &lda, b, &ldb, &info);
399     return info;
400   }
401 
402   inline int mxnet_lapack_dposv(int matrix_layout, char uplo, int n, int nrhs,
403     double *a, int lda, double *b, int ldb) {
404     int info;
405     if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) {
406       // Transpose b to b_t of shape (nrhs, n)
407       double *b_t = new double[nrhs * n];
408       flip<cpu, double>(n, nrhs, b_t, n, b, ldb);
409       dposv_(&uplo, &n, &nrhs, a, &lda, b_t, &n, &info);
410       flip<cpu, double>(nrhs, n, b, ldb, b_t, n);
411       delete [] b_t;
412       return info;
413     }
414     dposv_(&uplo, &n, &nrhs, a, &lda, b, &ldb, &info);
415     return info;
416   }
417 
418   // Note: Both MXNET_LAPACK_*gelqf, MXNET_LAPACK_*orglq can only be called with
419   // row-major format (MXNet). Internally, the QR variants are done in column-major.
420   // In particular, the matrix dimensions m and n are flipped.
421   #define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \
422   inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, int m, int n, \
423                                           dtype *a, int lda, dtype* tau, \
424                                           dtype* work, int lwork) { \
425     if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
426       int info(0); \
427       prefix##geqrf_(&n, &m, a, &lda, tau, work, &lwork, &info); \
428       return info; \
429     } else { \
430       CHECK(false) << "MXNET_LAPACK_" << #prefix << "gelqf implemented for row-major layout only"; \
431       return 1; \
432     } \
433   }
434   MXNET_LAPACK_CWRAP_GELQF(s, float)
435   MXNET_LAPACK_CWRAP_GELQF(d, double)
436 
437   // Note: The k argument (rank) is equal to m as well
438   #define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \
439   inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, int m, int n, \
440                                           dtype *a, int lda, dtype* tau, \
441                                           dtype* work, int lwork) { \
442     if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
443       int info(0); \
444       prefix##orgqr_(&n, &m, &m, a, &lda, tau, work, &lwork, &info);    \
445       return info; \
446     } else { \
447       CHECK(false) << "MXNET_LAPACK_" << #prefix << "orglq implemented for row-major layout only"; \
448       return 1; \
449     } \
450   }
451   MXNET_LAPACK_CWRAP_ORGLQ(s, float)
452   MXNET_LAPACK_CWRAP_ORGLQ(d, double)
453 
454   // Note: Supports row-major format only. Internally, column-major is used, so all
455   // inputs/outputs are flipped (in particular, uplo is flipped).
456   #define MXNET_LAPACK_CWRAP_SYEVD(func, dtype) \
457   inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
458                                  int lda, dtype *w, dtype *work, int lwork, \
459                                  int *iwork, int liwork) { \
460     if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
461       int info(0); \
462       char jobz('V'); \
463       char uplo_(loup(uplo, true)); \
464       func##_(&jobz, &uplo_, &n, a, &lda, w, work, &lwork, iwork, &liwork, &info); \
465       return info; \
466     } else { \
467       CHECK(false) << "MXNET_LAPACK_" << #func << " implemented for row-major layout only"; \
468       return 1; \
469     } \
470   }
471   MXNET_LAPACK_CWRAP_SYEVD(ssyevd, float)
472   MXNET_LAPACK_CWRAP_SYEVD(dsyevd, double)
473 
474   // Note: Supports row-major format only. Internally, column-major is used, so all
475   // inputs/outputs are flipped and transposed. m and n are flipped as well.
476   #define MXNET_LAPACK_CWRAP_GESVD(func, dtype) \
477   inline int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* ut, \
478                                  int ldut, dtype* s, dtype* v, int ldv, \
479                                  dtype* work, int lwork) { \
480     if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
481       int info(0); \
482       char jobu('O'); \
483       char jobvt('S'); \
484       func##_(&jobu, &jobvt, &n, &m, v, &ldv, s, v, &ldv, ut, &ldut, work, &lwork, &info); \
485       return info; \
486     } else { \
487       CHECK(false) << "MXNET_LAPACK_" << #func << " implemented for row-major layout only"; \
488       return 1; \
489     } \
490   }
491   MXNET_LAPACK_CWRAP_GESVD(sgesvd, float)
492   MXNET_LAPACK_CWRAP_GESVD(dgesvd, double)
493 
494   #define MXNET_LAPACK_CWRAP_GESDD(func, dtype) \
495   inline int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
496                                  dtype *a, int lda, dtype *s, \
497                                  dtype *u, int ldu, \
498                                  dtype *vt, int ldvt, \
499                                  dtype *work, int lwork, int *iwork) { \
500     if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
501       CHECK(false) << "MXNET_LAPACK_" << #func << " implemented for row-major layout only"; \
502       return 1; \
503     } else { \
504       int info(0); \
505       char jobz('O'); \
506       func##_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, &info); \
507       return info; \
508     } \
509   }
510   MXNET_LAPACK_CWRAP_GESDD(sgesdd, float)
511   MXNET_LAPACK_CWRAP_GESDD(dgesdd, double)
512 
513   #define MXNET_LAPACK_CWRAP_GEEV(prefix, dtype) \
514   inline int MXNET_LAPACK_##prefix##geev(int matrix_layout, char jobvl, char jobvr, \
515                                          int n, dtype *a, int lda, \
516                                          dtype *wr, dtype *wi, \
517                                          dtype *vl, int ldvl, dtype *vr, int ldvr, \
518                                          dtype *work, int lwork) { \
519     if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
520       CHECK(false) << "MXNET_LAPACK_" << #prefix << "geev implemented for col-major layout only"; \
521       return 1; \
522     } else { \
523       int info(0); \
524       prefix##geev_(&jobvl, &jobvr, \
525                     &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, &info); \
526       return info; \
527     } \
528   }
529   MXNET_LAPACK_CWRAP_GEEV(s, float)
530   MXNET_LAPACK_CWRAP_GEEV(d, double)
531 
532   #define MXNET_LAPACK
533 
534   // Note: Both MXNET_LAPACK_*getrf, MXNET_LAPACK_*getri can only be called with col-major format
535   // (MXNet) for performance.
536   #define MXNET_LAPACK_CWRAP_GETRF(prefix, dtype) \
537   inline int MXNET_LAPACK_##prefix##getrf(int matrix_layout, int m, int n, \
538                                           dtype *a, int lda, int *ipiv) { \
539     if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
540       CHECK(false) << "MXNET_LAPACK_" << #prefix << "getri implemented for col-major layout only"; \
541       return 1; \
542     } else { \
543       int info(0); \
544       prefix##getrf_(&m, &n, a, &lda, ipiv, &info); \
545       return info; \
546     } \
547   }
548   MXNET_LAPACK_CWRAP_GETRF(s, float)
549   MXNET_LAPACK_CWRAP_GETRF(d, double)
550 
551   #define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \
552   inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \
553                                           int *ipiv, dtype *work, int lwork) { \
554     if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
555       CHECK(false) << "MXNET_LAPACK_" << #prefix << "getri implemented for col-major layout only"; \
556       return 1; \
557     } else { \
558       int info(0); \
559       prefix##getri_(&n, a, &lda, ipiv, work, &lwork, &info); \
560       return info; \
561     } \
562   }
563   MXNET_LAPACK_CWRAP_GETRI(s, float)
564   MXNET_LAPACK_CWRAP_GETRI(d, double)
565 
566   #define MXNET_LAPACK_CWRAP_GESV(prefix, dtype) \
567   inline int MXNET_LAPACK_##prefix##gesv(int matrix_layout, \
568                                          int n, int nrhs, dtype *a, int lda, \
569                                          int *ipiv, dtype *b, int ldb) { \
570     if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
571       CHECK(false) << "MXNET_LAPACK_" << #prefix << "gesv implemented for col-major layout only"; \
572       return 1; \
573     } else { \
574       int info(0); \
575       prefix##gesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, &info); \
576       return info; \
577     } \
578   }
579   MXNET_LAPACK_CWRAP_GESV(s, float)
580   MXNET_LAPACK_CWRAP_GESV(d, double)
581 
582 #else
583 
584   #define MXNET_LAPACK_ROW_MAJOR 101
585   #define MXNET_LAPACK_COL_MAJOR 102
586 
587   // Define compilable stubs.
588   #define MXNET_LAPACK_CWRAPPER1(func, dtype) \
589   int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda);
590 
591   #define MXNET_LAPACK_CWRAPPER2(func, dtype) \
592   int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
593                           int lda, dtype* tau, dtype* work, int lwork);
594 
595   #define MXNET_LAPACK_CWRAPPER3(func, dtype) \
596   int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
597                           int lda, dtype *w, dtype *work, int lwork, \
598                           int *iwork, int liwork);
599 
600   #define MXNET_LAPACK_CWRAPPER4(func, dtype) \
601   int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
602                           dtype *a, int lda, int *ipiv);
603 
604   #define MXNET_LAPACK_CWRAPPER5(func, dtype) \
605   int MXNET_LAPACK_##func(int matrix_layout, int n, dtype *a, int lda, \
606                           int *ipiv, dtype *work, int lwork);
607 
608   #define MXNET_LAPACK_CWRAPPER6(func, dtype) \
609   int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* ut, \
610                           int ldut, dtype* s, dtype* v, int ldv, \
611                           dtype* work, int lwork);
612 
613   #define MXNET_LAPACK_CWRAPPER7(func, dtype) \
614   int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \
615                           int lda, int *ipiv, dtype *b, int ldb); \
616 
617   #define MXNET_LAPACK_CWRAPPER8(func, dtype) \
618   int MXNET_LAPACK_##func(int matrix_layout, char jobvl, char jobvr, \
619                           int n, dtype *a, int lda, \
620                           dtype *wr, dtype *wi, \
621                           dtype *vl, int ldvl, dtype *vr, int ldvr, \
622                           dtype *work, int lwork); \
623 
624   #define MXNET_LAPACK_CWRAPPER9(func, dtype) \
625   int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
626                           dtype *a, int lda, dtype *s, \
627                           dtype *u, int ldu, \
628                           dtype *vt, int ldvt, \
629                           dtype *work, int lwork, int *iwork);
630 
631   #define MXNET_LAPACK_UNAVAILABLE(func) \
632   int mxnet_lapack_##func(...);
633   MXNET_LAPACK_CWRAPPER1(spotrf, float)
634   MXNET_LAPACK_CWRAPPER1(dpotrf, double)
635   MXNET_LAPACK_CWRAPPER1(spotri, float)
636   MXNET_LAPACK_CWRAPPER1(dpotri, double)
637 
638   MXNET_LAPACK_UNAVAILABLE(sposv)
639   MXNET_LAPACK_UNAVAILABLE(dposv)
640 
641   MXNET_LAPACK_CWRAPPER2(sgelqf, float)
642   MXNET_LAPACK_CWRAPPER2(dgelqf, double)
643   MXNET_LAPACK_CWRAPPER2(sorglq, float)
644   MXNET_LAPACK_CWRAPPER2(dorglq, double)
645 
646   MXNET_LAPACK_CWRAPPER3(ssyevd, float)
647   MXNET_LAPACK_CWRAPPER3(dsyevd, double)
648 
649   MXNET_LAPACK_CWRAPPER4(sgetrf, float)
650   MXNET_LAPACK_CWRAPPER4(dgetrf, double)
651 
652   MXNET_LAPACK_CWRAPPER5(sgetri, float)
653   MXNET_LAPACK_CWRAPPER5(dgetri, double)
654 
655   MXNET_LAPACK_CWRAPPER6(sgesvd, float)
656   MXNET_LAPACK_CWRAPPER6(dgesvd, double)
657 
658   MXNET_LAPACK_CWRAPPER7(sgesv, float)
659   MXNET_LAPACK_CWRAPPER7(dgesv, double)
660 
661   MXNET_LAPACK_CWRAPPER8(sgeev, float)
662   MXNET_LAPACK_CWRAPPER8(dgeev, double)
663 
664   MXNET_LAPACK_CWRAPPER9(sgesdd, float)
665   MXNET_LAPACK_CWRAPPER9(dgesdd, double)
666 
667   #undef MXNET_LAPACK_CWRAPPER1
668   #undef MXNET_LAPACK_CWRAPPER2
669   #undef MXNET_LAPACK_CWRAPPER3
670   #undef MXNET_LAPACK_CWRAPPER4
671   #undef MXNET_LAPACK_UNAVAILABLE
672 #endif
673 
674 template <typename DType>
675 inline int MXNET_LAPACK_posv(int matrix_layout, char uplo, int n, int nrhs,
676   DType *a, int lda, DType *b, int ldb);
677 
678 template <>
679 inline int MXNET_LAPACK_posv<float>(int matrix_layout, char uplo, int n,
680   int nrhs, float *a, int lda, float *b, int ldb) {
681   return mxnet_lapack_sposv(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
682 }
683 
684 template <>
685 inline int MXNET_LAPACK_posv<double>(int matrix_layout, char uplo, int n,
686   int nrhs, double *a, int lda, double *b, int ldb) {
687   return mxnet_lapack_dposv(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
688 }
689 
690 #endif  // MXNET_OPERATOR_C_LAPACK_API_H_
691