/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file np_solve-inl.h
 * \brief Function definitions of the linear-equation solve operator (np.linalg.solve)
 */
#ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_SOLVE_INL_H_
#define MXNET_OPERATOR_NUMPY_LINALG_NP_SOLVE_INL_H_

#include <mxnet/operator_util.h>
#include <vector>
#include "../../tensor/la_op.h"
#include "../../tensor/la_op-inl.h"
#include "../../linalg.h"
#include "../../operator_common.h"
#include "../../mshadow_op.h"

namespace mxnet {
namespace op {

using namespace mshadow;

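// Device-specific helpers for np.linalg.solve. For every matrix in a batch they
// solve A * x = b through an LU factorization with partial pivoting: on CPU this
// is a single LAPACK ?gesv call, on GPU a cuSOLVER getrf (factorize) followed by
// getrs (triangular solves). A and X are modified in place: A ends up holding the
// LU factors, X holds the right-hand sides on entry and the solution on exit, and
// ipiv receives the pivot indices.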
template<typename xpu, typename DType>
void linalg_solve(const Tensor<xpu, 2, DType>& A,
                  const Tensor<xpu, 2, DType>& X,
                  const Tensor<xpu, 1, int>& ipiv,
                  Stream<xpu> *s);

template<typename xpu, typename DType>
void linalg_batch_solve(const Tensor<xpu, 3, DType>& A,
                        const Tensor<xpu, 3, DType>& X,
                        const Tensor<xpu, 2, int>& ipiv,
                        const mxnet::OpContext& ctx);

template<typename xpu, typename DType> inline
int linalg_dn_getrf_workspace_query(const Tensor<xpu, 2, DType>& A,
                                    Stream<xpu> *s);

template<typename xpu, typename DType> inline
void linalg_dn_getrf(const Tensor<xpu, 2, DType>& A,
                     const Tensor<xpu, 1, int>& ipiv,
                     Stream<xpu> *s);

template<typename xpu, typename DType> inline
void linalg_dn_getrs(const Tensor<xpu, 2, DType>& A,
                     const Tensor<xpu, 2, DType>& X,
                     const Tensor<xpu, 1, int>& ipiv,
                     Stream<xpu> *s);

// Kernel that casts and transposes each matrix in a batch (row-major <-> column-major)
struct SolveTypeTransposeHelper {
  template<typename InDType, typename OutDType>
  MSHADOW_XINLINE static void Map(int i, const InDType *in_data, OutDType *out_data,
                                  const int ncol1, const int ncol2, const int step) {
    int idx = i / step, row = (i % step) / ncol1, col = (i % step) % ncol1;
    out_data[idx * step + row + col * ncol2] = static_cast<OutDType>(in_data[i]);
  }
};
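// For illustration: with a single 2x3 row-major input (ncol1 = 3, ncol2 = 2, step = 6),
// element i = 2, i.e. A(0, 2), is written to output offset 0 + 2 * 2 = 4, which is
// position (2, 0) of the row-major 3x2 transpose; equivalently, the output holds the
// same matrix in column-major order, as LAPACK/cuSOLVER expect.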

template<typename xpu, typename DType>
inline void check_solve(const Tensor<xpu, 2, DType>& A,
                        const Tensor<xpu, 2, DType>& B) {
  CHECK_EQ(A.size(0), A.size(1)) << "A must be a square matrix";
  CHECK_EQ(A.size(1), B.size(1)) << "A and B have incompatible sizes";
}

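// Note on layout: the LAPACK/cuSOLVER routines wrapped below operate in column-major
// mode (MXNET_LAPACK_COL_MAJOR on CPU, cuSOLVER's native layout on GPU), so callers
// are expected to pass A and X already transposed into column-major order;
// LaOpForwardSolve below takes care of this.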
#define LINALG_CPU_SOLVE(fname, DType) \
template<> inline \
void linalg_solve<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
                              const Tensor<cpu, 2, DType>& X, \
                              const Tensor<cpu, 1, int>& ipiv, \
                              Stream<cpu> *s) { \
  check_solve(A, X); \
  const int N = X.size(1), nrhs = X.size(0); \
  const int lda = (N == 0 ? 1 : N), ldx = (N == 0 ? 1 : N); \
  int res(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, N, nrhs, \
                               A.dptr_, lda, ipiv.dptr_, X.dptr_, ldx)); \
  CHECK_LE(res, 0) << #fname << ": U(" << res << ", " << res \
    << ") is exactly zero. The factorization has been completed, " \
    << "but the factor U is exactly singular, so the solution could not be computed."; \
  CHECK_GE(res, 0) << #fname << ": the " << -res \
    << "-th argument had an illegal value"; \
}
LINALG_CPU_SOLVE(sgesv, float)
LINALG_CPU_SOLVE(dgesv, double)

#ifdef __CUDACC__

#if CUDA_VERSION >= 8000

#define LINALG_GPU_DN_GETRF_WORKSPACE_QUERY(fname, DType) \
template<> inline \
int linalg_dn_getrf_workspace_query<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                                                Stream<gpu> *s) { \
  using namespace mxnet; \
  using mshadow::gpu; \
  int lwork(0); \
  CUSOLVER_CALL(cusolver##fname##_bufferSize(Stream<gpu>::GetSolverHandle(s), \
                                             A.size(1), A.size(1), A.dptr_, \
                                             (A.size(1) == 0 ? 1 : A.size(1)), &lwork)); \
  return lwork; \
}

#define LINALG_GPU_DN_GETRF(fname, DType) \
template<> inline \
void linalg_dn_getrf<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                                 const Tensor<gpu, 1, int>& ipiv, \
                                 Stream<gpu> *s) { \
  using namespace mxnet; \
  using mshadow::gpu; \
  Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \
  const int lwork = linalg_dn_getrf_workspace_query(A, s); \
  Storage::Handle workspace = Storage::Get()->Alloc(sizeof(DType) * lwork, Context::GPU()); \
  CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
                                A.size(1), A.size(1), A.dptr_, (A.size(1) == 0 ? 1 : A.size(1)), \
                                static_cast<DType*>(workspace.dptr), ipiv.dptr_, \
                                static_cast<int*>(info.dptr))); \
  Storage::Get()->Free(info); \
  Storage::Get()->Free(workspace); \
}
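
// The getrf wrapper above factorizes A in place (P * A = L * U), sizing its scratch
// buffer with the *_bufferSize query; note that the cuSOLVER `info` status word is
// written to device memory and freed without being inspected here.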

#define LINALG_GPU_DN_GETRS(fname, DType) \
template<> inline \
void linalg_dn_getrs<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                                 const Tensor<gpu, 2, DType>& X, \
                                 const Tensor<gpu, 1, int>& ipiv, \
                                 Stream<gpu> *s) { \
  using namespace mxnet; \
  using mshadow::gpu; \
  const int N = A.size(0), nrhs = X.size(0); \
  const int lda = (A.size(1) == 0 ? 1 : A.size(1)), ldx = (X.size(1) == 0 ? 1 : X.size(1)); \
  Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \
  CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
                                CUBLAS_OP_N, N, nrhs, \
                                A.dptr_, lda, ipiv.dptr_, X.dptr_, ldx, \
                                static_cast<int*>(info.dptr))); \
  Storage::Get()->Free(info); \
}

#define LINALG_GPU_SOLVE(DType) \
template<> inline \
void linalg_solve<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                              const Tensor<gpu, 2, DType>& X, \
                              const Tensor<gpu, 1, int>& ipiv, \
                              Stream<gpu> *s) { \
  using namespace mxnet; \
  using mshadow::gpu; \
  CHECK_NOTNULL(s); \
  check_solve(A, X); \
  linalg_dn_getrf(A, ipiv, s); \
  linalg_dn_getrs(A, X, ipiv, s); \
}

#else  // CUDA_VERSION >= 8000

#define LINALG_GPU_DN_GETRF_WORKSPACE_QUERY(fname, DType) \
template<> inline \
int linalg_dn_getrf_workspace_query<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                                                Stream<gpu> *s) { \
  LOG(FATAL) << "Dn_getrf_workspace_query requires CUDA version >= 8.0!"; \
}

#define LINALG_GPU_DN_GETRF(fname, DType) \
template<> inline \
void linalg_dn_getrf<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                                 const Tensor<gpu, 1, int>& ipiv, \
                                 Stream<gpu> *s) { \
  LOG(FATAL) << "Dn_getrf requires CUDA version >= 8.0!"; \
}

#define LINALG_GPU_DN_GETRS(fname, DType) \
template<> inline \
void linalg_dn_getrs<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                                 const Tensor<gpu, 2, DType>& X, \
                                 const Tensor<gpu, 1, int>& ipiv, \
                                 Stream<gpu> *s) { \
  LOG(FATAL) << "Dn_getrs requires CUDA version >= 8.0!"; \
}

#define LINALG_GPU_SOLVE(DType) \
template<> inline \
void linalg_solve<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                              const Tensor<gpu, 2, DType>& X, \
                              const Tensor<gpu, 1, int>& ipiv, \
                              Stream<gpu> *s) { \
  LOG(FATAL) << "gpu solve requires CUDA version >= 8.0!"; \
}

#endif  // CUDA_VERSION >= 8000

LINALG_GPU_DN_GETRF_WORKSPACE_QUERY(DnSgetrf, float)
LINALG_GPU_DN_GETRF_WORKSPACE_QUERY(DnDgetrf, double)

LINALG_GPU_DN_GETRF(DnSgetrf, float)
LINALG_GPU_DN_GETRF(DnDgetrf, double)

LINALG_GPU_DN_GETRS(DnSgetrs, float)
LINALG_GPU_DN_GETRS(DnDgetrs, double)

LINALG_GPU_SOLVE(float)
LINALG_GPU_SOLVE(double)

#endif  // __CUDACC__

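// The batched solve below is a plain loop over the leading (batch) dimension: each
// matrix and its right-hand sides are handed to the 2-D linalg_solve specialization
// above, all on the same stream.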
#define LINALG_XPU_BATCH_SOLVE(xpu, DType) \
template<> inline \
void linalg_batch_solve<xpu, DType>(const Tensor<xpu, 3, DType>& A, \
                                    const Tensor<xpu, 3, DType>& X, \
                                    const Tensor<xpu, 2, int>& ipiv, \
                                    const mxnet::OpContext& ctx) { \
  Stream<xpu> *s = ctx.get_stream<xpu>(); \
  for (index_t i = 0; i < A.size(0); ++i) { \
    linalg_solve(A[i], X[i], ipiv[i], s); \
  } \
}
LINALG_XPU_BATCH_SOLVE(cpu, float)
LINALG_XPU_BATCH_SOLVE(cpu, double)

#ifdef __CUDACC__

LINALG_XPU_BATCH_SOLVE(gpu, float)
LINALG_XPU_BATCH_SOLVE(gpu, double)

#endif  // __CUDACC__

struct solve {
  template<typename xpu, typename DType>
  static void op(const Tensor<xpu, 3, DType>& A,
                 const Tensor<xpu, 3, DType>& X,
                 const Tensor<xpu, 2, int>& ipiv,
                 const OpContext& ctx,
                 const nnvm::NodeAttrs& attrs) {
    linalg_batch_solve(A, X, ipiv, ctx);  // ipiv acts as the pivot workspace for LAPACK ?gesv
  }
};

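// Forward entry point registered for the operator (laop is typically the `solve`
// functor above). The steps are:
//   1. derive the working shape of B from A and the second input,
//   2. request one workspace holding column-major copies of A and B plus the pivots,
//   3. cast-and-transpose A and B into that workspace,
//   4. run the batched solve via laop::op, and
//   5. transpose the solution back into the output blob X.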
template<typename xpu, int idim, int odim, int inum, int onum, typename laop>
void LaOpForwardSolve(const nnvm::NodeAttrs& attrs,
                      const OpContext& ctx,
                      const std::vector<TBlob>& inputs,
                      const std::vector<OpReqType>& req,
                      const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  CHECK_EQ(inputs.size(), inum);
  CHECK_EQ(outputs.size(), onum);
  CHECK_EQ(req.size(), onum);
  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
    const mxnet::TBlob& a_tblob = inputs[0];
    const mxnet::TBlob& b_tblob = inputs[1];
    const mxnet::TBlob& x_tblob = outputs[0];
    const mxnet::TShape& a_shape = a_tblob.shape_;
    mxnet::TShape b_shape(a_shape.ndim(), 1);
    for (int i = 0; i < a_shape.ndim() - 1; ++i) { b_shape[i] = b_tblob.shape_[i]; }
    if (b_tblob.shape_.ndim() == a_shape.ndim()) {
      b_shape[a_shape.ndim() - 1] = b_tblob.shape_[a_shape.ndim() - 1];
    }
    const int ndim = a_shape.ndim();
    mxnet::TShape ipiv_shape(a_shape);
    ipiv_shape[ndim - 1] = 1;
    if (0 == a_shape[ndim - 1] || 0 == a_shape[ndim - 2] ||
        0 == b_shape[ndim - 1] || 0 == b_shape[ndim - 2]) { return; }

    const int work_space_size =
      sizeof(OType) * (a_shape.Size() + b_shape.Size()) + sizeof(int) * ipiv_shape.Size();
    Tensor<xpu, 1, char> work_buffer =
      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(work_space_size), s);
    MSHADOW_TYPE_SWITCH(a_tblob.type_flag_, AType, {
      // cast type and transpose
      mxnet_op::Kernel<SolveTypeTransposeHelper, xpu>::Launch(
        s, a_shape.Size(),
        a_tblob.dptr<AType>(),
        reinterpret_cast<OType*>(work_buffer.dptr_),
        a_shape[ndim - 1], a_shape[ndim - 2], a_shape[ndim - 1] * a_shape[ndim - 2]);
    });
    MSHADOW_TYPE_SWITCH(b_tblob.type_flag_, BType, {
      // cast type and transpose
      mxnet_op::Kernel<SolveTypeTransposeHelper, xpu>::Launch(
        s, b_shape.Size(),
        b_tblob.dptr<BType>(),
        reinterpret_cast<OType*>(work_buffer.dptr_) + a_shape.Size(),
        b_shape[ndim - 1], b_shape[ndim - 2], b_shape[ndim - 1] * b_shape[ndim - 2]);
    });
    // transpose shape
    int temp = b_shape[ndim - 1];
    b_shape[ndim - 1] = b_shape[ndim - 2];
    b_shape[ndim - 2] = temp;
    mxnet::TBlob a_transpose_tblob(reinterpret_cast<OType*>(work_buffer.dptr_),
      a_shape, a_tblob.dev_mask(), a_tblob.dev_id());
    mxnet::TBlob b_transpose_tblob(reinterpret_cast<OType*>(work_buffer.dptr_) + a_shape.Size(),
      b_shape, b_tblob.dev_mask(), b_tblob.dev_id());
    mxnet::TBlob ipiv_tblob(reinterpret_cast<int*>(
      reinterpret_cast<OType*>(work_buffer.dptr_) + a_shape.Size() + b_shape.Size()),
      ipiv_shape, b_tblob.dev_mask(), b_tblob.dev_id());

    laop::op(a_transpose_tblob.FlatToKD<xpu, idim + 1, OType>(s),
             b_transpose_tblob.FlatToKD<xpu, idim + 1, OType>(s),
             ipiv_tblob.FlatToKD<xpu, idim, int>(s),
             ctx,
             attrs);
    // X = transpose(B)
    mxnet_op::Kernel<SolveTypeTransposeHelper, xpu>::Launch(
      s, b_shape.Size(),
      b_transpose_tblob.dptr<OType>(),
      x_tblob.dptr<OType>(),
      b_shape[ndim - 1], b_shape[ndim - 2], b_shape[ndim - 1] * b_shape[ndim - 2]);
  });
}

// X = (inv_A) * B
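// Given X = inv(A) * B, the standard gradients (computed below with two GEMMs) are
//   dB = trans(inv(A)) * dX
//   dA = -dB * trans(X)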
struct solve_backward {
  template<typename xpu, typename DType>
  static void op(const Tensor<xpu, 3, DType>& dX,
                 const Tensor<xpu, 3, DType>& inv_A,
                 const Tensor<xpu, 3, DType>& B,
                 const Tensor<xpu, 3, DType>& X,
                 const Tensor<xpu, 3, DType>& dA,
                 const Tensor<xpu, 3, DType>& dB,
                 const OpContext& ctx,
                 const nnvm::NodeAttrs& attrs) {
    // (1) calculate dB = trans(inv(A)) * dX
    // (2) calculate dA = -dB * trans(X)
    Stream<xpu> *s = ctx.get_stream<xpu>();
    gemm2::op(inv_A, dX, dB, DType(1), true, false, s);
    gemm2::op(dB, X, dA, DType(-1), false, true, s);
  }
};

template<typename xpu, typename DType>
inline void batch_inverse(const Tensor<xpu, 3, DType>& inv_A,
                          const Tensor<xpu, 3, DType>& LU,
                          const Tensor<xpu, 2, int>& pivot,
                          const mxnet::OpContext& ctx);
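
// batch_inverse is specialized per device below: the CPU version loops over the
// batch calling getrf/getri on each matrix, while the GPU version (CUDA >= 8.0)
// uses the batched getrf/getri wrappers. inv_A receives inv(A); LU is an auxiliary
// buffer of the same size (getri scratch space on CPU, a copy holding the LU
// factors on GPU).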

#define CPU_BATCH_INVERSE(xpu, DType) \
template<> inline \
void batch_inverse<xpu, DType>(const Tensor<xpu, 3, DType>& inv_A, \
                               const Tensor<xpu, 3, DType>& LU, \
                               const Tensor<xpu, 2, int>& pivot, \
                               const mxnet::OpContext& ctx) { \
  Stream<xpu> *s = ctx.get_stream<xpu>(); \
  for (index_t i = 0; i < inv_A.size(0); ++i) { \
    linalg_getrf(inv_A[i], pivot[i], true, s); \
    const Tensor<xpu, 1, DType> work( \
      LU[i].dptr_, Shape1(LU.size(1) * LU.size(2))); \
    linalg_getri(inv_A[i], pivot[i], work, s); \
  } \
}
CPU_BATCH_INVERSE(cpu, float)
CPU_BATCH_INVERSE(cpu, double)

#ifdef __CUDACC__

// GETRF and GETRI are only available with CUDA 8 or higher.
#if CUDA_VERSION >= 8000

#define GPU_BATCH_INVERSE(xpu, DType) \
template<> inline \
void batch_inverse<xpu, DType>(const Tensor<xpu, 3, DType>& inv_A, \
                               const Tensor<xpu, 3, DType>& LU, \
                               const Tensor<xpu, 2, int>& pivot, \
                               const mxnet::OpContext& ctx) { \
  Stream<xpu> *s = ctx.get_stream<xpu>(); \
  if (LU.dptr_ != inv_A.dptr_) Copy(LU, inv_A, s); \
  linalg_batch_getrf(LU, pivot, true, s); \
  linalg_batch_getri(inv_A, LU, pivot, s); \
}

#else  // CUDA_VERSION >= 8000

#define GPU_BATCH_INVERSE(xpu, DType) \
template<> inline \
void batch_inverse<xpu, DType>(const Tensor<xpu, 3, DType>& inv_A, \
                               const Tensor<xpu, 3, DType>& LU, \
                               const Tensor<xpu, 2, int>& pivot, \
                               const mxnet::OpContext& ctx) { \
  LOG(FATAL) << "gpu matrix inverse requires CUDA version >= 8.0!"; \
}

#endif  // CUDA_VERSION >= 8000

GPU_BATCH_INVERSE(gpu, float)
GPU_BATCH_INVERSE(gpu, double)

#endif  // __CUDACC__

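// Backward entry point registered for the operator (laop is typically the
// `solve_backward` functor above). In the order used below, inputs = {dX, A, B, X}
// and outputs = {dA, dB}. The workspace holds a copy of A (inverted in place), a
// same-sized getri scratch buffer, copies of B and X, and the pivot array; after
// batch_inverse the gradient math is delegated to laop::op.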
template<typename xpu, int idim, int odim, int inum, int onum, typename laop>
void LaOpBackwardSolve(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  CHECK_EQ(inputs.size(), inum);
  CHECK_EQ(outputs.size(), onum);
  CHECK_EQ(req.size(), onum);
  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
    const mxnet::TBlob& a_tblob = inputs[1];
    const mxnet::TBlob& b_tblob = inputs[2];
    const mxnet::TBlob& x_tblob = inputs[3];

    const mxnet::TShape& a_shape = a_tblob.shape_;
    mxnet::TShape b_shape(a_shape.ndim(), 1);
    for (int i = 0; i < a_shape.ndim() - 1; ++i) { b_shape[i] = b_tblob.shape_[i]; }
    if (b_tblob.shape_.ndim() == a_shape.ndim()) {
      b_shape[a_shape.ndim() - 1] = b_tblob.shape_[a_shape.ndim() - 1];
    }
    const int ndim = a_shape.ndim();
    const int N = a_shape[ndim - 1];
    if (0 == a_shape[ndim - 1] || 0 == a_shape[ndim - 2] ||
        0 == b_shape[ndim - 1] || 0 == b_shape[ndim - 2]) { return; }

    const Tensor<xpu, idim + 1, OType> A = a_tblob.FlatToKD<xpu, idim + 1, OType>(s);
    int work_space_size = sizeof(OType) * a_shape.Size();  // for inverse(A)
    work_space_size += sizeof(OType) * a_shape.Size();  // for getri work space
    work_space_size += 2 * sizeof(OType) * b_shape.Size();  // for B and X
    work_space_size += sizeof(int) * A.size(0) * N;  // for pivot work space
    Tensor<xpu, 1, char> work_buffer =
      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(work_space_size), s);

    MSHADOW_TYPE_SWITCH(a_tblob.type_flag_, AType, {
      mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
        s, a_shape.Size(),
        reinterpret_cast<OType*>(work_buffer.dptr_),
        a_tblob.dptr<AType>());
    });
    mxnet::TBlob a_inverse_tblob(reinterpret_cast<OType*>(work_buffer.dptr_),
      a_shape, a_tblob.dev_mask(), a_tblob.dev_id());
    const Tensor<xpu, idim + 1, OType> inv_A = a_inverse_tblob.FlatToKD<xpu, idim + 1, OType>(s);

    mxnet::TBlob lu_tblob(reinterpret_cast<OType*>(work_buffer.dptr_) + a_shape.Size(),
      inv_A.shape_, a_tblob.dev_mask(), a_tblob.dev_id());
    const Tensor<xpu, idim + 1, OType> LU = lu_tblob.FlatToKD<xpu, idim + 1, OType>(s);

    MSHADOW_TYPE_SWITCH(b_tblob.type_flag_, BType, {
      mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
        s, b_shape.Size(),
        reinterpret_cast<OType*>(work_buffer.dptr_) + 2 * a_shape.Size(),
        b_tblob.dptr<BType>());
    });
    mxnet::TBlob b_cp_tblob(reinterpret_cast<OType*>(work_buffer.dptr_) + 2 * a_shape.Size(),
      b_shape, b_tblob.dev_mask(), b_tblob.dev_id());
    const Tensor<xpu, idim + 1, OType> B = b_cp_tblob.FlatToKD<xpu, idim + 1, OType>(s);

    MSHADOW_TYPE_SWITCH(x_tblob.type_flag_, XType, {
      mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
        s, b_shape.Size(),
        reinterpret_cast<OType*>(work_buffer.dptr_) + 2 * a_shape.Size() + b_shape.Size(),
        x_tblob.dptr<XType>());
    });
    mxnet::TBlob x_cp_tblob(
      reinterpret_cast<OType*>(work_buffer.dptr_) + 2 * a_shape.Size() + b_shape.Size(),
      b_shape, b_tblob.dev_mask(), b_tblob.dev_id());
    const Tensor<xpu, idim + 1, OType> X = x_cp_tblob.FlatToKD<xpu, idim + 1, OType>(s);

    mxnet::TBlob pivot_tblob(reinterpret_cast<int*>(
      reinterpret_cast<OType*>(work_buffer.dptr_) + 2 * a_shape.Size() + 2 * b_shape.Size()),
      Shape2(A.size(0), N), a_tblob.dev_mask(), a_tblob.dev_id());
    const Tensor<xpu, idim, int> pivot = pivot_tblob.FlatToKD<xpu, idim, int>(s);

    // calculate inverse(A) on CPU or GPU
    batch_inverse(inv_A, LU, pivot, ctx);
    laop::op(inputs[0].FlatToKD<xpu, idim + 1, OType>(s),
             inv_A,
             B,
             X,
             outputs[0].FlatToKD<xpu, odim + 1, OType>(s),
             outputs[1].FlatToKD<xpu, odim + 1, OType>(s),
             ctx,
             attrs);
  });
}

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_OPERATOR_NUMPY_LINALG_NP_SOLVE_INL_H_