/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file np_solve-inl.h
 * \brief Function definitions of the linear-equation solver (linalg.solve)
 */
#ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_SOLVE_INL_H_
#define MXNET_OPERATOR_NUMPY_LINALG_NP_SOLVE_INL_H_

#include <mxnet/operator_util.h>
#include <vector>
#include "../../tensor/la_op.h"
#include "../../tensor/la_op-inl.h"
#include "../../linalg.h"
#include "../../operator_common.h"
#include "../../mshadow_op.h"

namespace mxnet {
namespace op {

using namespace mshadow;

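// Solves the linear system A * X = B for a single matrix, in place: on entry
// X holds the right-hand sides, on exit the solution. Both tensors must be in
// column-major (transposed) layout; ipiv receives the pivot indices of the
// LU factorization.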
template<typename xpu, typename DType>
void linalg_solve(const Tensor<xpu, 2, DType>& A,
                  const Tensor<xpu, 2, DType>& X,
                  const Tensor<xpu, 1, int>& ipiv,
                  Stream<xpu> *s);

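// Batched variant: solves A[i] * X[i] = B[i] for each matrix in the batch by
// invoking linalg_solve slice by slice on the context's stream.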
template<typename xpu, typename DType>
void linalg_batch_solve(const Tensor<xpu, 3, DType>& A,
                        const Tensor<xpu, 3, DType>& X,
                        const Tensor<xpu, 2, int>& ipiv,
                        const mxnet::OpContext& ctx);

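// Workspace query for the dense LU factorization: returns the scratch size
// (in DType elements) that cuSolver's ?getrf needs for A. Specialized for
// the GPU only.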
template<typename xpu, typename DType> inline
int linalg_dn_getrf_workspace_query(const Tensor<xpu, 2, DType>& A,
                                    Stream<xpu> *s);

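// Dense LU factorization with partial pivoting (cuSolver ?getrf): overwrites
// A with its LU factors and fills ipiv with the pivot indices. GPU-only.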
template<typename xpu, typename DType> inline
void linalg_dn_getrf(const Tensor<xpu, 2, DType>& A,
                     const Tensor<xpu, 1, int>& ipiv,
                     Stream<xpu> *s);

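// Solves A * X = B from the precomputed LU factors (cuSolver ?getrs): X holds
// the right-hand sides on entry and the solution on exit. GPU-only.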
template<typename xpu, typename DType> inline
void linalg_dn_getrs(const Tensor<xpu, 2, DType>& A,
                     const Tensor<xpu, 2, DType>& X,
                     const Tensor<xpu, 1, int>& ipiv,
                     Stream<xpu> *s);

// Kernel that casts the input to the output type while transposing the two
// trailing dimensions of each batch slice (row-major <-> column-major).
struct SolveTypeTransposeHelper {
  template<typename InDType, typename OutDType>
  MSHADOW_XINLINE static void Map(int i, const InDType *in_data, OutDType *out_data,
                                  const int ncol1, const int ncol2, const int step) {
    int idx = i / step, row = (i % step) / ncol1, col = (i % step) % ncol1;
    out_data[idx * step + row + col * ncol2] = static_cast<OutDType>(in_data[i]);
  }
};

template<typename xpu, typename DType>
inline void check_solve(const Tensor<xpu, 2, DType>& A,
                        const Tensor<xpu, 2, DType>& B) {
  CHECK_EQ(A.size(0), A.size(1)) << "A must be a square matrix";
  CHECK_EQ(A.size(1), B.size(1)) << "A, B have incompatible sizes";
}

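// CPU implementation on top of LAPACK ?gesv: factorizes A in place and
// overwrites X (which holds B on entry) with the solution. A positive return
// value flags a singular factor U, a negative one an illegal argument.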
#define LINALG_CPU_SOLVE(fname, DType) \
template<> inline \
void linalg_solve<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
                              const Tensor<cpu, 2, DType>& X, \
                              const Tensor<cpu, 1, int>& ipiv, \
                              Stream<cpu> *s) { \
  check_solve(A, X); \
  const int N = X.size(1), nrhs = X.size(0); \
  const int lda = (N == 0 ? 1 : N), ldx = (N == 0 ? 1 : N); \
  int res(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, N, nrhs, \
                               A.dptr_, lda, ipiv.dptr_, X.dptr_, ldx)); \
  CHECK_LE(res, 0) << #fname << ": U(" << res << ", " << res \
    << ") is exactly zero. The factorization has been completed, " \
    << "but the factor U is exactly singular, so the solution could not be computed."; \
  CHECK_GE(res, 0) << #fname << ": the " << -res \
    << "-th argument had an illegal value"; \
}
LINALG_CPU_SOLVE(sgesv, float)
LINALG_CPU_SOLVE(dgesv, double)

#ifdef __CUDACC__

#if CUDA_VERSION >= 8000

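// GPU implementation on top of cuSolver's dense LU routines (Dn?getrf and
// Dn?getrs), which are only available from CUDA 8.0 onwards.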
#define LINALG_GPU_DN_GETRF_WORKSPACE_QUERY(fname, DType) \
template<> inline \
int linalg_dn_getrf_workspace_query<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                                                Stream<gpu> *s) { \
  using namespace mxnet; \
  using mshadow::gpu; \
  int lwork(0); \
  CUSOLVER_CALL(cusolver##fname##_bufferSize(Stream<gpu>::GetSolverHandle(s), \
                                             A.size(1), A.size(1), A.dptr_, \
                                             (A.size(1) == 0 ? 1 : A.size(1)), &lwork)); \
  return lwork; \
}

#define LINALG_GPU_DN_GETRF(fname, DType) \
template<> inline \
void linalg_dn_getrf<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                                 const Tensor<gpu, 1, int>& ipiv, \
                                 Stream<gpu> *s) { \
  using namespace mxnet; \
  using mshadow::gpu; \
  Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \
  const int lwork = linalg_dn_getrf_workspace_query(A, s); \
  Storage::Handle workspace = Storage::Get()->Alloc(sizeof(DType) * lwork, Context::GPU()); \
  CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
                                A.size(1), A.size(1), A.dptr_, (A.size(1) == 0 ? 1 : A.size(1)), \
                                static_cast<DType*>(workspace.dptr), ipiv.dptr_, \
                                static_cast<int*>(info.dptr))); \
  Storage::Get()->Free(info); \
  Storage::Get()->Free(workspace); \
}

#define LINALG_GPU_DN_GETRS(fname, DType) \
template<> inline \
void linalg_dn_getrs<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                                 const Tensor<gpu, 2, DType>& X, \
                                 const Tensor<gpu, 1, int>& ipiv, \
                                 Stream<gpu> *s) { \
  using namespace mxnet; \
  using mshadow::gpu; \
  const int N = A.size(0), nrhs = X.size(0); \
  const int lda = (A.size(1) == 0 ? 1 : A.size(1)), ldx = (X.size(1) == 0 ? 1 : X.size(1)); \
  Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \
  CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
                                CUBLAS_OP_N, N, nrhs, \
                                A.dptr_, lda, ipiv.dptr_, X.dptr_, ldx, \
                                static_cast<int*>(info.dptr))); \
  Storage::Get()->Free(info); \
}

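// GPU solve: one LU factorization followed by a triangular solve on the same
// cuSolver handle; the LU factors stay in A and the solution ends up in X.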
#define LINALG_GPU_SOLVE(DType) \
template<> inline \
void linalg_solve<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                              const Tensor<gpu, 2, DType>& X, \
                              const Tensor<gpu, 1, int>& ipiv, \
                              Stream<gpu> *s) { \
  using namespace mxnet; \
  using mshadow::gpu; \
  CHECK_NOTNULL(s); \
  check_solve(A, X); \
  linalg_dn_getrf(A, ipiv, s); \
  linalg_dn_getrs(A, X, ipiv, s); \
}

#else // CUDA_VERSION >= 8000

#define LINALG_GPU_DN_GETRF_WORKSPACE_QUERY(fname, DType) \
template<> inline \
int linalg_dn_getrf_workspace_query<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                                                Stream<gpu> *s) { \
  LOG(FATAL) << "Dn_getrf_workspace_query requires CUDA version >= 8.0!"; \
}

#define LINALG_GPU_DN_GETRF(fname, DType) \
template<> inline \
void linalg_dn_getrf<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                                 const Tensor<gpu, 1, int>& ipiv, \
                                 Stream<gpu> *s) { \
  LOG(FATAL) << "Dn_getrf requires CUDA version >= 8.0!"; \
}

#define LINALG_GPU_DN_GETRS(fname, DType) \
template<> inline \
void linalg_dn_getrs<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                                 const Tensor<gpu, 2, DType>& X, \
                                 const Tensor<gpu, 1, int>& ipiv, \
                                 Stream<gpu> *s) { \
  LOG(FATAL) << "Dn_getrs requires CUDA version >= 8.0!"; \
}

#define LINALG_GPU_SOLVE(DType) \
template<> inline \
void linalg_solve<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
                              const Tensor<gpu, 2, DType>& X, \
                              const Tensor<gpu, 1, int>& ipiv, \
                              Stream<gpu> *s) { \
  LOG(FATAL) << "gpu solve requires CUDA version >= 8.0!"; \
}

#endif // CUDA_VERSION >= 8000

LINALG_GPU_DN_GETRF_WORKSPACE_QUERY(DnSgetrf, float)
LINALG_GPU_DN_GETRF_WORKSPACE_QUERY(DnDgetrf, double)

LINALG_GPU_DN_GETRF(DnSgetrf, float)
LINALG_GPU_DN_GETRF(DnDgetrf, double)

LINALG_GPU_DN_GETRS(DnSgetrs, float)
LINALG_GPU_DN_GETRS(DnDgetrs, double)

LINALG_GPU_SOLVE(float)
LINALG_GPU_SOLVE(double)

#endif // __CUDACC__

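// Naive batching: iterate over the leading batch dimension and solve each
// system with the single-matrix routine above.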
#define LINALG_XPU_BATCH_SOLVE(xpu, DType) \
template<> inline \
void linalg_batch_solve<xpu, DType>(const Tensor<xpu, 3, DType>& A, \
                                    const Tensor<xpu, 3, DType>& X, \
                                    const Tensor<xpu, 2, int>& ipiv, \
                                    const mxnet::OpContext& ctx) { \
  Stream<xpu> *s = ctx.get_stream<xpu>(); \
  for (index_t i = 0; i < A.size(0); ++i) { \
    linalg_solve(A[i], X[i], ipiv[i], s); \
  } \
}
LINALG_XPU_BATCH_SOLVE(cpu, float)
LINALG_XPU_BATCH_SOLVE(cpu, double)

#ifdef __CUDACC__

LINALG_XPU_BATCH_SOLVE(gpu, float)
LINALG_XPU_BATCH_SOLVE(gpu, double)

#endif // __CUDACC__

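// Forward functor dispatched by LaOpForwardSolve on the transposed inputs.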
struct solve {
  template<typename xpu, typename DType>
  static void op(const Tensor<xpu, 3, DType>& A,
                 const Tensor<xpu, 3, DType>& X,
                 const Tensor<xpu, 2, int>& ipiv,
                 const OpContext& ctx,
                 const nnvm::NodeAttrs& attrs) {
    linalg_batch_solve(A, X, ipiv, ctx);  // ipiv is the pivot workspace of LAPACK ?gesv
  }
};

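// Forward pass of linalg.solve. Workspace layout:
//   [ A (cast + transposed) | B (cast + transposed) | ipiv ]
// A and B are cast to the output type and transposed into column-major
// order, solved in place by the batched ?gesv, and the solution is finally
// transposed back into the row-major output blob.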
template<typename xpu, int idim, int odim, int inum, int onum, typename laop>
void LaOpForwardSolve(const nnvm::NodeAttrs& attrs,
                      const OpContext& ctx,
                      const std::vector<TBlob>& inputs,
                      const std::vector<OpReqType>& req,
                      const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  CHECK_EQ(inputs.size(), inum);
  CHECK_EQ(outputs.size(), onum);
  CHECK_EQ(req.size(), onum);
  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
    const mxnet::TBlob& a_tblob = inputs[0];
    const mxnet::TBlob& b_tblob = inputs[1];
    const mxnet::TBlob& x_tblob = outputs[0];
    const mxnet::TShape& a_shape = a_tblob.shape_;
    // b may have one dimension less than a (a 1-D right-hand side); pad its
    // shape with a trailing 1 in that case.
    mxnet::TShape b_shape(a_shape.ndim(), 1);
    for (int i = 0; i < a_shape.ndim() - 1; ++i) { b_shape[i] = b_tblob.shape_[i]; }
    if (b_tblob.shape_.ndim() == a_shape.ndim()) {
      b_shape[a_shape.ndim() - 1] = b_tblob.shape_[a_shape.ndim() - 1];
    }
    const int ndim = a_shape.ndim();
    mxnet::TShape ipiv_shape(a_shape);
    ipiv_shape[ndim - 1] = 1;
    // nothing to do for empty matrices
    if (0 == a_shape[ndim - 1] || 0 == a_shape[ndim - 2] ||
        0 == b_shape[ndim - 1] || 0 == b_shape[ndim - 2]) { return; }

    const int work_space_size =
      sizeof(OType) * (a_shape.Size() + b_shape.Size()) + sizeof(int) * ipiv_shape.Size();
    Tensor<xpu, 1, char> work_buffer =
      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(work_space_size), s);
    MSHADOW_TYPE_SWITCH(a_tblob.type_flag_, AType, {
      // cast type and transpose
      mxnet_op::Kernel<SolveTypeTransposeHelper, xpu>::Launch(
        s, a_shape.Size(),
        a_tblob.dptr<AType>(),
        reinterpret_cast<OType*>(work_buffer.dptr_),
        a_shape[ndim - 1], a_shape[ndim - 2], a_shape[ndim - 1] * a_shape[ndim - 2]);
    });
    MSHADOW_TYPE_SWITCH(b_tblob.type_flag_, BType, {
      // cast type and transpose
      mxnet_op::Kernel<SolveTypeTransposeHelper, xpu>::Launch(
        s, b_shape.Size(),
        b_tblob.dptr<BType>(),
        reinterpret_cast<OType*>(work_buffer.dptr_) + a_shape.Size(),
        b_shape[ndim - 1], b_shape[ndim - 2], b_shape[ndim - 1] * b_shape[ndim - 2]);
    });
    // swap the trailing dims of b_shape: it now describes the transposed B
    int temp = b_shape[ndim - 1];
    b_shape[ndim - 1] = b_shape[ndim - 2];
    b_shape[ndim - 2] = temp;
    mxnet::TBlob a_transpose_tblob(reinterpret_cast<OType*>(work_buffer.dptr_),
                                   a_shape, a_tblob.dev_mask(), a_tblob.dev_id());
    mxnet::TBlob b_transpose_tblob(reinterpret_cast<OType*>(work_buffer.dptr_) + a_shape.Size(),
                                   b_shape, b_tblob.dev_mask(), b_tblob.dev_id());
    mxnet::TBlob ipiv_tblob(reinterpret_cast<int*>(
      reinterpret_cast<OType*>(work_buffer.dptr_) + a_shape.Size() + b_shape.Size()),
      ipiv_shape, b_tblob.dev_mask(), b_tblob.dev_id());

    laop::op(a_transpose_tblob.FlatToKD<xpu, idim + 1, OType>(s),
             b_transpose_tblob.FlatToKD<xpu, idim + 1, OType>(s),
             ipiv_tblob.FlatToKD<xpu, idim, int>(s),
             ctx,
             attrs);
    // copy the solution back into the row-major output: X = transpose(B)
    mxnet_op::Kernel<SolveTypeTransposeHelper, xpu>::Launch(
      s, b_shape.Size(),
      b_transpose_tblob.dptr<OType>(),
      x_tblob.dptr<OType>(),
      b_shape[ndim - 1], b_shape[ndim - 2], b_shape[ndim - 1] * b_shape[ndim - 2]);
  });
}

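// Backward of X = inv(A) * B: using d(inv(A)) = -inv(A) * dA * inv(A), the
// input gradients are
//   dB = trans(inv(A)) * dX
//   dA = -dB * trans(X)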
struct solve_backward {
  template<typename xpu, typename DType>
  static void op(const Tensor<xpu, 3, DType>& dX,
                 const Tensor<xpu, 3, DType>& inv_A,
                 const Tensor<xpu, 3, DType>& B,
                 const Tensor<xpu, 3, DType>& X,
                 const Tensor<xpu, 3, DType>& dA,
                 const Tensor<xpu, 3, DType>& dB,
                 const OpContext& ctx,
                 const nnvm::NodeAttrs& attrs) {
    // (1) calculate dB = trans(inv(A)) * dX
    // (2) calculate dA = -dB * trans(X)
    Stream<xpu> *s = ctx.get_stream<xpu>();
    gemm2::op(inv_A, dX, dB, DType(1), true, false, s);
    gemm2::op(dB, X, dA, DType(-1), false, true, s);
  }
};

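// Computes inv(A) for every matrix in the batch. LU is scratch space: it
// serves as the getri work buffer on CPU and receives the LU factors on GPU.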
template<typename xpu, typename DType>
inline void batch_inverse(const Tensor<xpu, 3, DType>& inv_A,
                          const Tensor<xpu, 3, DType>& LU,
                          const Tensor<xpu, 2, int>& pivot,
                          const mxnet::OpContext& ctx);

#define CPU_BATCH_INVERSE(xpu, DType) \
template<> inline \
void batch_inverse<xpu, DType>(const Tensor<xpu, 3, DType>& inv_A, \
                               const Tensor<xpu, 3, DType>& LU, \
                               const Tensor<xpu, 2, int>& pivot, \
                               const mxnet::OpContext& ctx) { \
  Stream<xpu> *s = ctx.get_stream<xpu>(); \
  for (index_t i = 0; i < inv_A.size(0); ++i) { \
    linalg_getrf(inv_A[i], pivot[i], true, s); \
    const Tensor<xpu, 1, DType> work( \
      LU[i].dptr_, Shape1(LU.size(1) * LU.size(2))); \
    linalg_getri(inv_A[i], pivot[i], work, s); \
  } \
}
CPU_BATCH_INVERSE(cpu, float)
CPU_BATCH_INVERSE(cpu, double)

#ifdef __CUDACC__

// GETRF and GETRI are only available with CUDA 8 or higher.
#if CUDA_VERSION >= 8000

#define GPU_BATCH_INVERSE(xpu, DType) \
template<> inline \
void batch_inverse<xpu, DType>(const Tensor<xpu, 3, DType>& inv_A, \
                               const Tensor<xpu, 3, DType>& LU, \
                               const Tensor<xpu, 2, int>& pivot, \
                               const mxnet::OpContext& ctx) { \
  Stream<xpu> *s = ctx.get_stream<xpu>(); \
  if (LU.dptr_ != inv_A.dptr_) Copy(LU, inv_A, s); \
  linalg_batch_getrf(LU, pivot, true, s); \
  linalg_batch_getri(inv_A, LU, pivot, s); \
}

#else // CUDA_VERSION >= 8000

#define GPU_BATCH_INVERSE(xpu, DType) \
template<> inline \
void batch_inverse<xpu, DType>(const Tensor<xpu, 3, DType>& inv_A, \
                               const Tensor<xpu, 3, DType>& LU, \
                               const Tensor<xpu, 2, int>& pivot, \
                               const mxnet::OpContext& ctx) { \
  LOG(FATAL) << "gpu matrix inverse requires CUDA version >= 8.0!"; \
}

#endif // CUDA_VERSION >= 8000

GPU_BATCH_INVERSE(gpu, float)
GPU_BATCH_INVERSE(gpu, double)

#endif // __CUDACC__

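// Backward pass of linalg.solve. Workspace layout (OType elements, pivots
// as int):
//   [ inv(A) | LU / getri scratch | copy of B | copy of X | pivots ]
// A is copied into the buffer and inverted in place by batch_inverse; the
// functor (solve_backward) then computes dA and dB from dX, inv(A), B and X.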
template<typename xpu, int idim, int odim, int inum, int onum, typename laop>
void LaOpBackwardSolve(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  CHECK_EQ(inputs.size(), inum);
  CHECK_EQ(outputs.size(), onum);
  CHECK_EQ(req.size(), onum);
  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
    // inputs are ordered as [dX, A, B, X]
    const mxnet::TBlob& a_tblob = inputs[1];
    const mxnet::TBlob& b_tblob = inputs[2];
    const mxnet::TBlob& x_tblob = inputs[3];

    const mxnet::TShape& a_shape = a_tblob.shape_;
    // b may have one dimension less than a (a 1-D right-hand side); pad its
    // shape with a trailing 1 in that case.
    mxnet::TShape b_shape(a_shape.ndim(), 1);
    for (int i = 0; i < a_shape.ndim() - 1; ++i) { b_shape[i] = b_tblob.shape_[i]; }
    if (b_tblob.shape_.ndim() == a_shape.ndim()) {
      b_shape[a_shape.ndim() - 1] = b_tblob.shape_[a_shape.ndim() - 1];
    }
    const int ndim = a_shape.ndim();
    const int N = a_shape[ndim - 1];
    // nothing to do for empty matrices
    if (0 == a_shape[ndim - 1] || 0 == a_shape[ndim - 2] ||
        0 == b_shape[ndim - 1] || 0 == b_shape[ndim - 2]) { return; }

    const Tensor<xpu, idim + 1, OType> A = a_tblob.FlatToKD<xpu, idim + 1, OType>(s);
    int work_space_size = sizeof(OType) * a_shape.Size();  // for inverse(A)
    work_space_size += sizeof(OType) * a_shape.Size();  // for getri work space
    work_space_size += 2 * sizeof(OType) * b_shape.Size();  // for B and X
    work_space_size += sizeof(int) * A.size(0) * N;  // for pivot work space
    Tensor<xpu, 1, char> work_buffer =
      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(work_space_size), s);

    MSHADOW_TYPE_SWITCH(a_tblob.type_flag_, AType, {
      mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
        s, a_shape.Size(),
        reinterpret_cast<OType*>(work_buffer.dptr_),
        a_tblob.dptr<AType>());
    });
    mxnet::TBlob a_inverse_tblob(reinterpret_cast<OType*>(work_buffer.dptr_),
                                 a_shape, a_tblob.dev_mask(), a_tblob.dev_id());
    const Tensor<xpu, idim + 1, OType> inv_A = a_inverse_tblob.FlatToKD<xpu, idim + 1, OType>(s);

    mxnet::TBlob lu_tblob(reinterpret_cast<OType*>(work_buffer.dptr_) + a_shape.Size(),
                          inv_A.shape_, a_tblob.dev_mask(), a_tblob.dev_id());
    const Tensor<xpu, idim + 1, OType> LU = lu_tblob.FlatToKD<xpu, idim + 1, OType>(s);

    MSHADOW_TYPE_SWITCH(b_tblob.type_flag_, BType, {
      mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
        s, b_shape.Size(),
        reinterpret_cast<OType*>(work_buffer.dptr_) + 2 * a_shape.Size(),
        b_tblob.dptr<BType>());
    });
    mxnet::TBlob b_cp_tblob(reinterpret_cast<OType*>(work_buffer.dptr_) + 2 * a_shape.Size(),
                            b_shape, b_tblob.dev_mask(), b_tblob.dev_id());
    const Tensor<xpu, idim + 1, OType> B = b_cp_tblob.FlatToKD<xpu, idim + 1, OType>(s);

    MSHADOW_TYPE_SWITCH(x_tblob.type_flag_, XType, {
      mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
        s, b_shape.Size(),
        reinterpret_cast<OType*>(work_buffer.dptr_) + 2 * a_shape.Size() + b_shape.Size(),
        x_tblob.dptr<XType>());
    });
    mxnet::TBlob x_cp_tblob(
      reinterpret_cast<OType*>(work_buffer.dptr_) + 2 * a_shape.Size() + b_shape.Size(),
      b_shape, b_tblob.dev_mask(), b_tblob.dev_id());
    const Tensor<xpu, idim + 1, OType> X = x_cp_tblob.FlatToKD<xpu, idim + 1, OType>(s);

    mxnet::TBlob pivot_tblob(reinterpret_cast<int*>(
      reinterpret_cast<OType*>(work_buffer.dptr_) + 2 * a_shape.Size() + 2 * b_shape.Size()),
      Shape2(A.size(0), N), a_tblob.dev_mask(), a_tblob.dev_id());
    const Tensor<xpu, idim, int> pivot = pivot_tblob.FlatToKD<xpu, idim, int>(s);

    // calculate inverse(A) on CPU or GPU
    batch_inverse(inv_A, LU, pivot, ctx);
    laop::op(inputs[0].FlatToKD<xpu, idim + 1, OType>(s),
             inv_A,
             B,
             X,
             outputs[0].FlatToKD<xpu, odim + 1, OType>(s),
             outputs[1].FlatToKD<xpu, odim + 1, OType>(s),
             ctx,
             attrs);
  });
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_LINALG_NP_SOLVE_INL_H_