/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file np_pinv-inl.h
 * \brief Function definition of the (Moore-Penrose) pseudo-inverse operator.
 */
#ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_PINV_INL_H_
#define MXNET_OPERATOR_NUMPY_LINALG_NP_PINV_INL_H_

#include <mxnet/operator_util.h>
#include <vector>
#include <algorithm>
#include "../../operator_common.h"
#include "../../mshadow_op.h"
#include "../../tensor/elemwise_binary_op.h"
#include "../../tensor/elemwise_binary_broadcast_op.h"
#include "../../tensor/la_op.h"
#include "../../tensor/la_op-inl.h"
#include "../../tensor/matrix_op-inl.h"

namespace mxnet {
namespace op {

using namespace mshadow;

struct PinvParam : public dmlc::Parameter<PinvParam> {
  bool hermitian;
  DMLC_DECLARE_PARAMETER(PinvParam) {
    DMLC_DECLARE_FIELD(hermitian)
    .set_default(false)
    .describe("If True, A is assumed to be Hermitian (symmetric if real-valued).");
  }
};

struct PinvScalarRcondParam : public dmlc::Parameter<PinvScalarRcondParam> {
  double rcond;
  bool hermitian;
  DMLC_DECLARE_PARAMETER(PinvScalarRcondParam) {
    DMLC_DECLARE_FIELD(rcond)
    .set_default(1e-15)
    .describe("Cutoff for small singular values.");
    DMLC_DECLARE_FIELD(hermitian)
    .set_default(false)
    .describe("If True, A is assumed to be Hermitian (symmetric if real-valued).");
  }
};
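
// Note on the rcond cutoff (this mirrors NumPy's pinv semantics and the kernels below):
// a singular value s[i] is treated as zero unless s[i] > rcond * max(s); only the kept
// values are inverted when the pseudo-inverse is assembled.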

template<typename xpu, typename DType>
int linalg_gesdd_workspace_query(const int m, const int n,
                                 const Tensor<xpu, 2, DType>& UT,
                                 const Tensor<xpu, 1, DType>& S,
                                 const Tensor<xpu, 2, DType>& V,
                                 Stream<xpu>* s = 0);

template<typename xpu, typename DType>
void linalg_gesdd(const int m, const int n,
                  const Tensor<xpu, 2, DType>& UT,
                  const Tensor<xpu, 1, DType>& S,
                  const Tensor<xpu, 2, DType>& V,
                  const Tensor<xpu, 1, DType>& work,
                  Stream<xpu> *s = 0);

template<typename xpu, typename DType>
void BatchSVDImpl(const int m, const int n,
                  const Tensor<xpu, 3, DType>& UT,
                  const Tensor<xpu, 2, DType>& S,
                  const Tensor<xpu, 3, DType>& V,
                  const Tensor<xpu, 1, DType>& work,
                  Stream<xpu> *s = 0);
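
// A note on the workspace-query convention used below: LAPACK's ?gesdd reports its optimal
// workspace size when called with lwork = -1, writing the result into work[0]. The CPU
// specialization of linalg_gesdd_workspace_query relies on exactly this, passing -1 as
// lwork and casting the returned value back to an int.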

#define LINALG_CPU_GESDD_WORKSPACE_QUERY(func, DType) \
template<> inline \
int linalg_gesdd_workspace_query<cpu, DType>(const int m, const int n, \
                                             const Tensor<cpu, 2, DType>& UT, \
                                             const Tensor<cpu, 1, DType>& S, \
                                             const Tensor<cpu, 2, DType>& V, \
                                             Stream<cpu> *s) { \
  DType work(0.0); \
  std::vector<int> iwork(8 * std::min(m, n), 0); \
  if (m > n) { \
    MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, n, m, \
                        UT.dptr_, UT.stride_, S.dptr_, \
                        V.dptr_, V.stride_, \
                        UT.dptr_, UT.stride_, \
                        &work, -1, iwork.data()); \
  } else { \
    MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, n, m, \
                        V.dptr_, V.stride_, S.dptr_, \
                        V.dptr_, V.stride_, \
                        UT.dptr_, UT.stride_, \
                        &work, -1, iwork.data()); \
  } \
  return static_cast<int>(work); \
}

#define LINALG_CPU_GESDD(func, DType) \
template<> inline \
void linalg_gesdd<cpu, DType>(const int m, \
                              const int n, \
                              const Tensor<cpu, 2, DType>& UT, \
                              const Tensor<cpu, 1, DType>& S, \
                              const Tensor<cpu, 2, DType>& V, \
                              const Tensor<cpu, 1, DType>& work, \
                              Stream<cpu> *s) { \
  std::vector<int> iwork(8 * std::min(m, n), 0); \
  int res(0); \
  if (m > n) { \
    res = MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, n, m, \
                              UT.dptr_, UT.stride_, S.dptr_, \
                              V.dptr_, V.stride_, \
                              UT.dptr_, UT.stride_, \
                              work.dptr_, work.shape_.Size(), iwork.data()); \
  } else { \
    res = MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, n, m, \
                              V.dptr_, V.stride_, S.dptr_, \
                              V.dptr_, V.stride_, \
                              UT.dptr_, UT.stride_, \
                              work.dptr_, work.shape_.Size(), iwork.data()); \
  } \
  CHECK_GE(res, 0) << #func << ": the " << -res \
    << "-th argument had an illegal value"; \
  CHECK_LE(res, 0) << #func << " did not converge, updating process failed."; \
}

LINALG_CPU_GESDD_WORKSPACE_QUERY(sgesdd, float)
LINALG_CPU_GESDD_WORKSPACE_QUERY(dgesdd, double)

LINALG_CPU_GESDD(sgesdd, float)
LINALG_CPU_GESDD(dgesdd, double)

#ifdef __CUDACC__

#define LINALG_GPU_GESDD_WORKSPACE_QUERY(DType) \
template<> inline \
int linalg_gesdd_workspace_query<gpu, DType>(const int m, const int n, \
                                             const Tensor<gpu, 2, DType>& U, \
                                             const Tensor<gpu, 1, DType>& S, \
                                             const Tensor<gpu, 2, DType>& VT, \
                                             Stream<gpu> *s) { \
  LOG(FATAL) << "Lapack gesdd workspace query routine is unsupported on GPU!"; \
  return 0; \
}

#define LINALG_GPU_GESDD(DType) \
template<> inline \
void linalg_gesdd<gpu, DType>(const int m, const int n, \
                              const Tensor<gpu, 2, DType>& U, \
                              const Tensor<gpu, 1, DType>& S, \
                              const Tensor<gpu, 2, DType>& VT, \
                              const Tensor<gpu, 1, DType>& work, \
                              Stream<gpu> *s) { \
  LOG(FATAL) << "Lapack gesdd routine is unsupported on GPU!"; \
}

LINALG_GPU_GESDD_WORKSPACE_QUERY(float)
LINALG_GPU_GESDD_WORKSPACE_QUERY(double)

LINALG_GPU_GESDD(float)
LINALG_GPU_GESDD(double)

#endif  // __CUDACC__
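
// Note: the GPU build deliberately has no gesdd path; the stubs above abort if ever called.
// On GPU, BatchSVDImpl (below) falls back to linalg_gesvd from la_op for each matrix in the
// batch, so the divide-and-conquer gesdd wrappers are effectively CPU-only.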

#define BATCH_SVD_IMPL_CPU(DType) \
template<> inline \
void BatchSVDImpl<cpu, DType>(const int m, const int n, \
                              const Tensor<cpu, 3, DType>& UT, \
                              const Tensor<cpu, 2, DType>& S, \
                              const Tensor<cpu, 3, DType>& V, \
                              const Tensor<cpu, 1, DType>& work, \
                              Stream<cpu> *s) { \
  for (index_t i = 0; i < S.size(0); ++i) { \
    linalg_gesdd(m, n, UT[i], S[i], V[i], work, s); \
  } \
}

BATCH_SVD_IMPL_CPU(float)
BATCH_SVD_IMPL_CPU(double)

#ifdef __CUDACC__

#define BATCH_SVD_IMPL_GPU(DType) \
template<> inline \
void BatchSVDImpl<gpu, DType>(const int m, const int n, \
                              const Tensor<gpu, 3, DType>& UT, \
                              const Tensor<gpu, 2, DType>& S, \
                              const Tensor<gpu, 3, DType>& V, \
                              const Tensor<gpu, 1, DType>& work, \
                              Stream<gpu> *s) { \
  for (index_t i = 0; i < S.size(0); ++i) { \
    linalg_gesvd(UT[i], S[i], V[i], work, s); \
  } \
}

BATCH_SVD_IMPL_GPU(float)
BATCH_SVD_IMPL_GPU(double)

#endif  // __CUDACC__

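// Kernel: for batch matrix i, scan its singular values (count `length`, row stride `lds`)
// and write the largest one to smax_ptr[i]. One thread/work-item handles one matrix.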
struct SingularValSmax {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType *smax_ptr, const DType *s_ptr,
                                  const int length, const int lds) {
    const DType *s_iptr = s_ptr + i * lds;
    DType *smax_iptr = smax_ptr + i;
    *smax_iptr = s_iptr[0];
    for (int j = 1; j < length; ++j) {
      *smax_iptr = s_iptr[j] > *smax_iptr ? s_iptr[j] : *smax_iptr;
    }
  }
};

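// Kernel: elementwise inversion of the kept singular values. Where large_ptr[i] is nonzero
// (s[i] passed the cutoff test), s[i] is replaced by 1 / s[i]; otherwise it is zeroed out.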
struct DiscardSmallSingularVal {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType *s_ptr, const DType *large_ptr) {
    if (large_ptr[i]) {
      s_ptr[i] = DType(1) / s_ptr[i];
    } else {
      s_ptr[i] = DType(0);
    }
  }
};

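// Kernel: fused variant for the scalar-rcond operator. For batch matrix i it first finds
// smax = max_j s[j], then applies the cutoff in place:
//   s[j] <- 1 / s[j]  if s[j] > rcond * smax,  else 0.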
struct DiscardSmallSingularValWithScalarRcond {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType *s_ptr, const int length,
                                  const int lds, const double rcond) {
    DType *s_iptr = s_ptr + i * lds;
    DType smax_i = s_iptr[0];
    for (int j = 1; j < length; ++j) {
      smax_i = s_iptr[j] > smax_i ? s_iptr[j] : smax_i;
    }
    for (int j = 0; j < length; ++j) {
      s_iptr[j] = (s_iptr[j] > rcond * smax_i) ? (DType(1) / s_iptr[j]) : (DType(0));
    }
  }
};

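// Derives the shapes of the SVD factors UT, S, V (and optionally U, VT) from the shape of A.
// Worked example, assuming a hypothetical batched input A of shape (4, 2, 3), i.e. m = 2,
// n = 3 and hence m < n:
//   s_shape  = (4, 2)                            // min(m, n) singular values per matrix
//   v_shape  = (4, 2, 3)                         // same shape as A
//   ut_shape = (4, 2, 2)
//   u_shape  = (4, 2, 2), vt_shape = (4, 3, 2)   // only filled in when requested
// For m >= n the roles of the two factors are mirrored (UT takes the shape of A).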
inline void GetPinvShape(const mxnet::TShape& a_shape,
                         mxnet::TShape *ut_shape,
                         mxnet::TShape *s_shape,
                         mxnet::TShape *v_shape,
                         mxnet::TShape *u_shape = nullptr,
                         mxnet::TShape *vt_shape = nullptr) {
  const int a_ndim = a_shape.ndim();
  const int m = a_shape[a_ndim - 2];
  const int n = a_shape[a_ndim - 1];

  // Calculate S shape.
  std::vector<int> s_shape_vec(a_ndim - 1, -1);
  for (int i = 0; i < a_ndim - 2; ++i) {
    s_shape_vec[i] = a_shape[i];
  }
  s_shape_vec[a_ndim - 2] = std::min(m, n);
  *s_shape = mxnet::TShape(s_shape_vec.begin(), s_shape_vec.end());

  std::vector<int> temp_shape_vec(a_ndim, -1);
  for (int i = 0; i < a_ndim - 2; ++i) {
    temp_shape_vec[i] = a_shape[i];
  }
  temp_shape_vec[a_ndim - 2] = std::min(m, n);
  temp_shape_vec[a_ndim - 1] = std::min(m, n);
  if (m >= n) {
    // UT must have same shape as A.
    *ut_shape = a_shape;
    *v_shape = mxnet::TShape(temp_shape_vec.begin(), temp_shape_vec.end());
    if (u_shape && vt_shape) {
      *vt_shape = mxnet::TShape(temp_shape_vec.begin(), temp_shape_vec.end());
      *u_shape = a_shape;
      (*u_shape)[a_ndim - 2] = a_shape[a_ndim - 1];
      (*u_shape)[a_ndim - 1] = a_shape[a_ndim - 2];
    }
  } else {
    // V must have same shape as A.
    *v_shape = a_shape;
    *ut_shape = mxnet::TShape(temp_shape_vec.begin(), temp_shape_vec.end());
    if (u_shape && vt_shape) {
      *u_shape = mxnet::TShape(temp_shape_vec.begin(), temp_shape_vec.end());
      *vt_shape = a_shape;
      (*vt_shape)[a_ndim - 2] = a_shape[a_ndim - 1];
      (*vt_shape)[a_ndim - 1] = a_shape[a_ndim - 2];
    }
  }
}

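// Shape bookkeeping for the NumPy-style thresholding used by the forward pass:
//   cutoff = rcond[..., newaxis] * amax(s, axis=-1, keepdims=True)
//   large  = s > cutoff
//   s      = divide(1, s, where=large, out=s)
// The broadcast shapes of `cutoff` and `large` are computed (and validated against s_shape)
// here, and optionally returned so the forward pass can size its scratch buffers.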
inline void GetOrCheckCutoffAndLargeShape(const nnvm::NodeAttrs& attrs,
                                          const mxnet::TShape& a_shape,
                                          const mxnet::TShape& rcond_shape,
                                          mxnet::TShape *cutoff_shape = nullptr,
                                          mxnet::TShape *large_shape = nullptr) {
  if (!shape_is_known(a_shape)) { return; }
  const int a_ndim = a_shape.ndim();
  const int rcond_ndim = rcond_shape.ndim();
  mxnet::TShape s_shape(a_ndim - 1, 1);
  mxnet::TShape smax_shape(a_ndim - 1, 1);
  mxnet::TShape new_rcond_shape(rcond_ndim + 1, 1);
  // Get new rcond shape.
  for (int i = 0; i < rcond_ndim; ++i) {
    new_rcond_shape[i] = rcond_shape[i];
  }
  // Get Smax shape.
  for (int i = 0; i < a_ndim - 2; ++i) {
    s_shape[i] = a_shape[i];
    smax_shape[i] = a_shape[i];
  }
  s_shape[s_shape.ndim() - 1] = std::min(a_shape[a_ndim - 2], a_shape[a_ndim - 1]);
  smax_shape[smax_shape.ndim() - 1] = 1;
  // Check cutoff = rcond[..., newaxis] * smax.
  mxnet::ShapeVector in_shape_vec1({ new_rcond_shape, smax_shape });
  mxnet::ShapeVector out_shape_vec1(1);
  mxnet::op::BinaryBroadcastShape(attrs, &in_shape_vec1, &out_shape_vec1);
  // Check large = s > cutoff.
  mxnet::ShapeVector in_shape_vec2({ s_shape, out_shape_vec1[0] });
  mxnet::ShapeVector out_shape_vec2(1);
  mxnet::op::BinaryBroadcastShape(attrs, &in_shape_vec2, &out_shape_vec2);
  // Check s = divide(1, s, where=large, out=s).
  if (s_shape != out_shape_vec2[0]) {
    LOG(FATAL) << "Error: non-broadcastable output operand with shape "
      << s_shape << " doesn't match the broadcast shape " << out_shape_vec2[0];
  }
  if (cutoff_shape) {
    *cutoff_shape = out_shape_vec1[0];
  }
  if (large_shape) {
    *large_shape = out_shape_vec2[0];
  }
}

template<typename xpu>
size_t SVDWorkspaceSize(const TBlob& a,
                        const TBlob& pinv_a,
                        const mxnet::TShape& u_shape,
                        const mxnet::TShape& s_shape,
                        const mxnet::TShape& v_shape,
                        const std::vector<OpReqType>& req,
                        const OpContext& ctx) {
  if (kNullOp == req[0]) { return 0U; }

  // Zero-size input, no need to launch kernel
  if (0U == a.Size()) { return 0U; }

  size_t work_space_size = 0;
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, OType, {
    const int a_ndim = a.shape_.ndim();
    const int u_ndim = u_shape.ndim();
    const int s_ndim = s_shape.ndim();
    const int v_ndim = v_shape.ndim();
    mxnet::TShape u_shape2 = Shape2(u_shape[u_ndim - 2], u_shape[u_ndim - 1]);
    mxnet::TShape s_shape1 = Shape1(s_shape[s_ndim - 1]);
    mxnet::TShape v_shape2 = Shape2(v_shape[v_ndim - 2], v_shape[v_ndim - 1]);
    if (xpu::kDevCPU) {
      std::vector<OType> u_vec(u_shape2.Size(), 0);
      std::vector<OType> s_vec(s_shape1.Size(), 0);
      std::vector<OType> v_vec(v_shape2.Size(), 0);
      // For workspace size in linalg_gesdd.
      work_space_size += linalg_gesdd_workspace_query(
          a.shape_[a_ndim - 2], a.shape_[a_ndim - 1],
          TBlob(u_vec.data(), u_shape2, a.dev_mask(), a.dev_id()).get<xpu, 2, OType>(s),
          TBlob(s_vec.data(), s_shape1, a.dev_mask(), a.dev_id()).get<xpu, 1, OType>(s),
          TBlob(v_vec.data(), v_shape2, a.dev_mask(), a.dev_id()).get<xpu, 2, OType>(s), s);
    } else {
      Storage::Handle u_handle =
        Storage::Get()->Alloc(sizeof(OType) * u_shape2.Size(), Context::GPU());
      Storage::Handle s_handle =
        Storage::Get()->Alloc(sizeof(OType) * s_shape1.Size(), Context::GPU());
      Storage::Handle v_handle =
        Storage::Get()->Alloc(sizeof(OType) * v_shape2.Size(), Context::GPU());
      TBlob u_data(static_cast<OType*>(u_handle.dptr), u_shape2, a.dev_mask(), a.dev_id());
      TBlob s_data(static_cast<OType*>(s_handle.dptr), s_shape1, a.dev_mask(), a.dev_id());
      TBlob v_data(static_cast<OType*>(v_handle.dptr), v_shape2, a.dev_mask(), a.dev_id());
      // For workspace size in linalg_gesvd.
      if (a.shape_[a_ndim - 2] >= a.shape_[a_ndim - 1]) {
        work_space_size += linalg_gesvd_workspace_query(v_data.get<xpu, 2, OType>(s),
                                                        s_data.get<xpu, 1, OType>(s),
                                                        u_data.get<xpu, 2, OType>(s), s);
      } else {
        work_space_size += linalg_gesvd_workspace_query(u_data.get<xpu, 2, OType>(s),
                                                        s_data.get<xpu, 1, OType>(s),
                                                        v_data.get<xpu, 2, OType>(s), s);
      }
      Storage::Get()->Free(u_handle);
      Storage::Get()->Free(s_handle);
      Storage::Get()->Free(v_handle);
    }
  });
  return work_space_size;
}

// Calculates workspace size of pinv op forward.
template<typename xpu>
size_t PinvForwardWorkspaceSize(const TBlob& a,
                                const TBlob& rcond,
                                const TBlob& pinv_a,
                                const nnvm::NodeAttrs& attrs,
                                const std::vector<OpReqType>& req,
                                const OpContext& ctx) {
  if (kNullOp == req[0]) { return 0U; }
  // Zero-size input, no need to launch kernel
  if (0U == a.Size()) { return 0U; }

  size_t work_space_size = 0;
  mxnet::TShape u_shape, s_shape, v_shape;
  GetPinvShape(a.shape_, &u_shape, &s_shape, &v_shape);

  MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, OType, {
    mxnet::TShape smax_shape = s_shape;
    smax_shape[s_shape.ndim() - 1] = 1;
    mxnet::TShape cutoff_shape;
    mxnet::TShape large_shape;
    GetOrCheckCutoffAndLargeShape(attrs, a.shape_, rcond.shape_, &cutoff_shape, &large_shape);
    work_space_size +=  // For gesdd or gesvd workspace size.
      SVDWorkspaceSize<xpu>(a, pinv_a, u_shape, s_shape, v_shape, req, ctx);
    work_space_size += rcond.shape_.Size();  // For rcond.
    work_space_size += 2 * u_shape.Size();   // For U and UT.
    work_space_size += s_shape.Size();       // For S.
    work_space_size += 2 * v_shape.Size();   // For V and VT.
    work_space_size += smax_shape.Size();    // For Smax.
    work_space_size += cutoff_shape.Size();  // For Cutoff.
    work_space_size += large_shape.Size();   // For Large.
    return work_space_size * sizeof(OType);
  });
  LOG(FATAL) << "InternalError: cannot reach here";
  return 0U;
}
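
// The scratch buffer sized above is carved up by PinvOpForwardImpl in this order (all in
// units of the output DType):
//   [lwork for gesdd/gesvd | rcond | UT | U | S | V | VT | Smax | Cutoff | Large]
// The 2x factors on u_shape and v_shape above account for the paired U/UT and V/VT buffers.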

// Returns the axis permutation that swaps the last two axes, e.g. (0, 2, 1) for a 3-D
// input, i.e. a batched matrix transpose.
inline mxnet::TShape GetTransAxis(const mxnet::TShape& in_shape) {
  const int in_ndim = in_shape.ndim();
  std::vector<int> trans_axis(in_ndim, -1);
  for (int i = 0; i < in_ndim - 2; ++i) { trans_axis[i] = i; }
  trans_axis[in_ndim - 2] = in_ndim - 1;
  trans_axis[in_ndim - 1] = in_ndim - 2;
  return mxnet::TShape(trans_axis.begin(), trans_axis.end());
}

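// Forward implementation of pinv with a tensor-valued rcond. It follows the standard
// SVD-based construction of the Moore-Penrose pseudo-inverse,
//   A = U * diag(s) * V^T,   pinv(A) = V * diag(s+) * U^T,
// where s+[i] = 1/s[i] if s[i] > cutoff and 0 otherwise, with cutoff = rcond * max(s) per
// matrix. Steps 1-6 in the body map onto this directly: SVD, per-matrix Smax, Cutoff,
// "Large" mask, reciprocal of the kept values, and a final batched GEMM. When m > n the
// input is transposed up front and the SVD of A^T is taken instead (see the note inside),
// so the same code path serves both orientations.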
template<typename xpu>
void PinvOpForwardImpl(const TBlob& a,
                       const TBlob& rcond,
                       const TBlob& pinv_a,
                       const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<OpReqType>& req,
                       const Tensor<xpu, 1, char>& workspace) {
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const mxnet::TShape a_shape = a.shape_;
  const mxnet::TShape rcond_shape = rcond.shape_;
  const int a_ndim = a_shape.ndim();
  const int rcond_ndim = rcond_shape.ndim();
  mxnet::TShape rcond_shape_newaxis(rcond_ndim + 1, 1);
  for (int i = 0; i < rcond_ndim; ++i) {
    rcond_shape_newaxis[i] = rcond_shape[i];
  }
  mxnet::TShape s_shape;
  mxnet::TShape u_shape;
  mxnet::TShape ut_shape;
  mxnet::TShape v_shape;
  mxnet::TShape vt_shape;
  GetPinvShape(a_shape, &u_shape, &s_shape, &v_shape, &ut_shape, &vt_shape);
  mxnet::TShape smax_shape = s_shape;
  smax_shape[s_shape.ndim() - 1] = 1;
  mxnet::TShape s_shape_newaxis(s_shape.ndim() + 1, 1);
  for (int i = 0; i < s_shape.ndim(); ++i) {
    s_shape_newaxis[i] = s_shape[i];
  }
  mxnet::TShape cutoff_shape;
  mxnet::TShape large_shape;
  GetOrCheckCutoffAndLargeShape(attrs, a_shape, rcond_shape, &cutoff_shape, &large_shape);

  MSHADOW_SGL_DBL_TYPE_SWITCH(a.type_flag_, AType, {
    MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, DType, {
      const size_t workspace_size = (workspace.shape_.Size() + sizeof(DType) - 1) / sizeof(DType);
      const size_t lwork = workspace_size - rcond_shape_newaxis.Size()
        - 2 * u_shape.Size() - s_shape.Size() - 2 * v_shape.Size() - smax_shape.Size()
        - cutoff_shape.Size() - large_shape.Size();
      DType *work_ptr = reinterpret_cast<DType*>(workspace.dptr_);
      DType *rcond_ptr = work_ptr + lwork;
      DType *ut_ptr = rcond_ptr + rcond_shape_newaxis.Size();
      DType *u_ptr = ut_ptr + ut_shape.Size();
      DType *s_ptr = u_ptr + u_shape.Size();
      DType *v_ptr = s_ptr + s_shape.Size();
      DType *vt_ptr = v_ptr + v_shape.Size();
      DType *smax_ptr = vt_ptr + vt_shape.Size();
      DType *cutoff_ptr = smax_ptr + smax_shape.Size();
      DType *large_ptr = cutoff_ptr + cutoff_shape.Size();
      // Step1: Calculate SVD.
      TBlob work_data(work_ptr, Shape1(lwork), a.dev_mask(), a.dev_id());
      TBlob u_data(u_ptr, u_shape, a.dev_mask(), a.dev_id());
      TBlob ut_data(ut_ptr, ut_shape, a.dev_mask(), a.dev_id());
      TBlob v_data(v_ptr, v_shape, a.dev_mask(), a.dev_id());
      TBlob vt_data(vt_ptr, vt_shape, a.dev_mask(), a.dev_id());
      TBlob s_data(s_ptr, s_shape, a.dev_mask(), a.dev_id());
      // Note: a transpose is only needed when a_shape[a_ndim - 2] > a_shape[a_ndim - 1].
      if (a_shape[a_ndim - 2] > a_shape[a_ndim - 1]) {
        mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
          s, a.Size(), u_ptr, a.dptr<AType>());
        mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, u_data, ut_data,  // u_data: src, ut_data: res
                                      GetTransAxis(u_data.shape_));
        BatchSVDImpl(a_shape[a_ndim - 1], a_shape[a_ndim - 2],
                     vt_data.FlatToKD<xpu, 3, DType>(s),
                     s_data.FlatToKD<xpu, 2, DType>(s),
                     ut_data.FlatToKD<xpu, 3, DType>(s),
                     work_data.FlatToKD<xpu, 1, DType>(s), s);
      } else {
        mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
          s, a.Size(), v_ptr, a.dptr<AType>());
        BatchSVDImpl(a_shape[a_ndim - 2], a_shape[a_ndim - 1],
                     u_data.FlatToKD<xpu, 3, DType>(s),
                     s_data.FlatToKD<xpu, 2, DType>(s),
                     v_data.FlatToKD<xpu, 3, DType>(s),
                     work_data.FlatToKD<xpu, 1, DType>(s), s);
      }
      TBlob smax_data(smax_ptr, smax_shape, a.dev_mask(), a.dev_id());
      TBlob cutoff_data(cutoff_ptr, cutoff_shape, a.dev_mask(), a.dev_id());
      TBlob large_data(large_ptr, large_shape, a.dev_mask(), a.dev_id());
      TBlob rcond_data(rcond_ptr, rcond_shape_newaxis, a.dev_mask(), a.dev_id());
      Tensor<xpu, 2, DType> S = s_data.FlatToKD<xpu, 2, DType>(s);
      Tensor<xpu, 2, DType> Smax = smax_data.FlatToKD<xpu, 2, DType>(s);
      mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
        s, rcond_shape_newaxis.Size(), rcond_ptr, rcond.dptr<AType>());
      // Step2: Calculate Smax.
      mxnet_op::Kernel<SingularValSmax, xpu>::Launch(
        s, S.size(0), Smax.dptr_, S.dptr_, S.size(1), S.stride_);
      // Step3: Calculate Cutoff.
      std::vector<OpReqType> temp_req({kWriteTo});
      mxnet::op::BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
                                                                  {rcond_data, smax_data},
                                                                  temp_req, {cutoff_data});
      // Step4: Calculate Large.
      mxnet::op::BinaryBroadcastCompute<xpu, op::mshadow_op::gt>(attrs, ctx,
                                                                 {s_data, cutoff_data},
                                                                 temp_req, {large_data});
      // Step5: Discard small singular values.
      mxnet_op::Kernel<DiscardSmallSingularVal, xpu>::Launch(
        s, s_data.Size(), s_data.dptr<DType>(), large_data.dptr<DType>());
      // Step6: Calculate matmul(transpose(v), multiply(s[..., newaxis], transpose(u))).
      // Note: the extra transposes are only needed when a_shape[a_ndim - 2] <= a_shape[a_ndim - 1];
      // otherwise Step1 already produced UT and VT.
      if (a_shape[a_ndim - 2] <= a_shape[a_ndim - 1]) {
        mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, u_data, ut_data,  // u_data: src, ut_data: res
                                      GetTransAxis(u_data.shape_));
        mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, v_data, vt_data,  // v_data: src, vt_data: res
                                      GetTransAxis(v_data.shape_));
      }
      s_data = s_data.reshape(s_shape_newaxis);
      u_data = ut_data.reshape(ut_shape);
      mxnet::op::BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx, {s_data, ut_data},
                                                                  temp_req, {u_data});
      gemm2::op(vt_data.FlatToKD<xpu, 3, DType>(s),
                u_data.FlatToKD<xpu, 3, DType>(s),
                pinv_a.FlatToKD<xpu, 3, DType>(s),
                DType(1), false, false, s);
    });
  });
}

template<typename xpu>
void PinvOpForward(const nnvm::NodeAttrs& attrs,
                   const OpContext& ctx,
                   const std::vector<TBlob>& inputs,
                   const std::vector<OpReqType>& req,
                   const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const TBlob& a = inputs[0];
  const TBlob& rcond = inputs[1];
  const TBlob& pinv_a = outputs[0];
  const mxnet::TShape a_shape = a.shape_;

  if (kNullOp == req[0]) { return; }

  // Zero-size output, no need to launch kernel
  if (0U == a.Size()) { return; }

  size_t workspace_size = PinvForwardWorkspaceSize<xpu>(a, rcond, pinv_a, attrs, req, ctx);
  Tensor<xpu, 1, char> workspace =
    ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
  PinvOpForwardImpl<xpu>(a, rcond, pinv_a, attrs, ctx, req, workspace);
}

// Calculates workspace size of pinv scalar rcond op forward.
template<typename xpu>
size_t PinvScalarRcondForwardWorkspaceSize(const TBlob& a,
                                           const TBlob& pinv_a,
                                           const nnvm::NodeAttrs& attrs,
                                           const std::vector<OpReqType>& req,
                                           const OpContext& ctx) {
  if (kNullOp == req[0]) { return 0U; }
  // Zero-size input, no need to launch kernel
  if (0U == a.Size()) { return 0U; }

  size_t work_space_size = 0;
  mxnet::TShape u_shape, s_shape, v_shape;
  GetPinvShape(a.shape_, &u_shape, &s_shape, &v_shape);

  MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, OType, {
    mxnet::TShape smax_shape = s_shape;
    smax_shape[s_shape.ndim() - 1] = 1;
    work_space_size +=  // For gesdd or gesvd workspace size.
      SVDWorkspaceSize<xpu>(a, pinv_a, u_shape, s_shape, v_shape, req, ctx);
    work_space_size += 2 * u_shape.Size();  // For U and UT.
    work_space_size += s_shape.Size();      // For S.
    work_space_size += 2 * v_shape.Size();  // For V and VT.
    return work_space_size * sizeof(OType);
  });
  LOG(FATAL) << "InternalError: cannot reach here";
  return 0U;
}
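
// Variant of the forward pass for an attribute-valued (scalar) rcond: the cutoff comes from
// PinvScalarRcondParam (attrs.parsed) instead of a second input tensor, so the Smax, Cutoff
// and Large scratch buffers are not needed and the thresholding is fused into a single
// kernel (DiscardSmallSingularValWithScalarRcond). Everything else mirrors PinvOpForwardImpl.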
template<typename xpu>
void PinvScalarRcondOpForwardImpl(const TBlob& a,
                                  const TBlob& pinv_a,
                                  const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
                                  const std::vector<OpReqType>& req,
                                  const Tensor<xpu, 1, char>& workspace) {
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const mxnet::TShape a_shape = a.shape_;
  const int a_ndim = a_shape.ndim();

  mxnet::TShape s_shape;
  mxnet::TShape u_shape;
  mxnet::TShape ut_shape;
  mxnet::TShape v_shape;
  mxnet::TShape vt_shape;
  GetPinvShape(a_shape, &u_shape, &s_shape, &v_shape, &ut_shape, &vt_shape);
  mxnet::TShape s_shape_newaxis(s_shape.ndim() + 1, 1);
  for (int i = 0; i < s_shape.ndim(); ++i) {
    s_shape_newaxis[i] = s_shape[i];
  }
  MSHADOW_SGL_DBL_TYPE_SWITCH(a.type_flag_, AType, {
    MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, DType, {
      const double rcond = nnvm::get<PinvScalarRcondParam>(attrs.parsed).rcond;
      const size_t workspace_size = (workspace.shape_.Size() + sizeof(DType) - 1) / sizeof(DType);
      const size_t lwork = workspace_size - 2 * u_shape.Size() - s_shape.Size()
        - 2 * v_shape.Size();
      DType *work_ptr = reinterpret_cast<DType*>(workspace.dptr_);
      DType *u_ptr = work_ptr + lwork;
      DType *ut_ptr = u_ptr + u_shape.Size();
      DType *s_ptr = ut_ptr + ut_shape.Size();
      DType *v_ptr = s_ptr + s_shape.Size();
      DType *vt_ptr = v_ptr + v_shape.Size();
      // Step1: Calculate SVD.
      TBlob work_data(work_ptr, Shape1(lwork), a.dev_mask(), a.dev_id());
      TBlob u_data(u_ptr, u_shape, a.dev_mask(), a.dev_id());
      TBlob ut_data(ut_ptr, ut_shape, a.dev_mask(), a.dev_id());
      TBlob v_data(v_ptr, v_shape, a.dev_mask(), a.dev_id());
      TBlob vt_data(vt_ptr, vt_shape, a.dev_mask(), a.dev_id());
      TBlob s_data(s_ptr, s_shape, a.dev_mask(), a.dev_id());
      Tensor<xpu, 2, DType> S = s_data.FlatToKD<xpu, 2, DType>(s);
      // Note: a transpose is only needed when a_shape[a_ndim - 2] > a_shape[a_ndim - 1].
      if (a_shape[a_ndim - 2] > a_shape[a_ndim - 1]) {
        mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
          s, a.Size(), u_ptr, a.dptr<AType>());
        mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, u_data, ut_data,  // u_data: src, ut_data: res
                                      GetTransAxis(u_data.shape_));
        BatchSVDImpl(a_shape[a_ndim - 1], a_shape[a_ndim - 2],
                     vt_data.FlatToKD<xpu, 3, DType>(s),
                     s_data.FlatToKD<xpu, 2, DType>(s),
                     ut_data.FlatToKD<xpu, 3, DType>(s),
                     work_data.FlatToKD<xpu, 1, DType>(s), s);
      } else {
        mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
          s, a.Size(), v_ptr, a.dptr<AType>());
        BatchSVDImpl(a_shape[a_ndim - 2], a_shape[a_ndim - 1],
                     u_data.FlatToKD<xpu, 3, DType>(s),
                     s_data.FlatToKD<xpu, 2, DType>(s),
                     v_data.FlatToKD<xpu, 3, DType>(s),
                     work_data.FlatToKD<xpu, 1, DType>(s), s);
      }
      // Step2: Discard small singular values.
      mxnet_op::Kernel<DiscardSmallSingularValWithScalarRcond, xpu>::Launch(
        s, S.size(0), S.dptr_, S.size(1), S.stride_, rcond);
      // Step3: Calculate matmul(transpose(v), multiply(s[..., newaxis], transpose(u))).
      // Note: the extra transposes are only needed when a_shape[a_ndim - 2] <= a_shape[a_ndim - 1];
      // otherwise Step1 already produced UT and VT.
      if (a_shape[a_ndim - 2] <= a_shape[a_ndim - 1]) {
        mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, u_data, ut_data,  // u_data: src, ut_data: res
                                      GetTransAxis(u_data.shape_));
        mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, v_data, vt_data,  // v_data: src, vt_data: res
                                      GetTransAxis(v_data.shape_));
      }
      s_data = s_data.reshape(s_shape_newaxis);
      u_data = ut_data.reshape(ut_shape);
      mxnet::op::BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx, {s_data, ut_data},
                                                                  {kWriteTo}, {u_data});
      gemm2::op(vt_data.FlatToKD<xpu, 3, DType>(s),
                u_data.FlatToKD<xpu, 3, DType>(s),
                pinv_a.FlatToKD<xpu, 3, DType>(s),
                DType(1), false, false, s);
    });
  });
}

template<typename xpu>
void PinvScalarRcondOpForward(const nnvm::NodeAttrs& attrs,
                              const OpContext& ctx,
                              const std::vector<TBlob>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const TBlob& a = inputs[0];
  const TBlob& pinv_a = outputs[0];

  if (kNullOp == req[0]) { return; }
  // Zero-size output, no need to launch kernel
  if (0U == a.Size()) { return; }

  // Calculate workspace size.
  size_t workspace_size = PinvScalarRcondForwardWorkspaceSize<xpu>(a, pinv_a, attrs, req, ctx);
  Tensor<xpu, 1, char> workspace =
    ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
  PinvScalarRcondOpForwardImpl<xpu>(a, pinv_a, attrs, ctx, req, workspace);
}

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_OPERATOR_NUMPY_LINALG_NP_PINV_INL_H_