/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file np_pinv-inl.h
 * \brief Implementation of the pinv (Moore-Penrose pseudo-inverse) operator.
 */
#ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_PINV_INL_H_
#define MXNET_OPERATOR_NUMPY_LINALG_NP_PINV_INL_H_

#include <mxnet/operator_util.h>
#include <vector>
#include <algorithm>
#include "../../operator_common.h"
#include "../../mshadow_op.h"
#include "../../tensor/elemwise_binary_op.h"
#include "../../tensor/elemwise_binary_broadcast_op.h"
#include "../../tensor/la_op.h"
#include "../../tensor/la_op-inl.h"
#include "../../tensor/matrix_op-inl.h"

namespace mxnet {
namespace op {

using namespace mshadow;
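
// Overview: this header implements the forward pass of np.linalg.pinv. The computation
// mirrors NumPy: take the SVD of each matrix in the batch, discard singular values that
// are not larger than cutoff = rcond * max(s), invert the remaining ones, and recombine
// the factors as matmul(transpose(v), multiply(s[..., newaxis], transpose(u)))
// (see the Step1-Step6 comments in PinvOpForwardImpl below).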
struct PinvParam : public dmlc::Parameter<PinvParam> {
  bool hermitian;
  DMLC_DECLARE_PARAMETER(PinvParam) {
    DMLC_DECLARE_FIELD(hermitian)
    .set_default(false)
    .describe("If True, A is assumed to be Hermitian (symmetric if real-valued).");
  }
};

struct PinvScalarRcondParam : public dmlc::Parameter<PinvScalarRcondParam> {
  double rcond;
  bool hermitian;
  DMLC_DECLARE_PARAMETER(PinvScalarRcondParam) {
    DMLC_DECLARE_FIELD(rcond)
    .set_default(1e-15)
    .describe("Cutoff for small singular values.");
    DMLC_DECLARE_FIELD(hermitian)
    .set_default(false)
    .describe("If True, A is assumed to be Hermitian (symmetric if real-valued).");
  }
};
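
// Two operator variants share this file: the pinv op that takes `rcond` as a tensor
// input uses PinvParam (see PinvOpForward, which expects two inputs), while the variant
// that receives `rcond` as a scalar attribute uses PinvScalarRcondParam
// (see PinvScalarRcondOpForward, which expects a single input).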

template<typename xpu, typename DType>
int linalg_gesdd_workspace_query(const int m, const int n,
                                 const Tensor<xpu, 2, DType>& UT,
                                 const Tensor<xpu, 1, DType>& S,
                                 const Tensor<xpu, 2, DType>& V,
                                 Stream<xpu>* s = 0);

template<typename xpu, typename DType>
void linalg_gesdd(const int m, const int n,
                  const Tensor<xpu, 2, DType>& UT,
                  const Tensor<xpu, 1, DType>& S,
                  const Tensor<xpu, 2, DType>& V,
                  const Tensor<xpu, 1, DType>& work,
                  Stream<xpu> *s = 0);

template<typename xpu, typename DType>
void BatchSVDImpl(const int m, const int n,
                  const Tensor<xpu, 3, DType>& UT,
                  const Tensor<xpu, 2, DType>& S,
                  const Tensor<xpu, 3, DType>& V,
                  const Tensor<xpu, 1, DType>& work,
                  Stream<xpu> *s = 0);
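
// The CPU specializations below wrap LAPACK's ?gesdd (divide-and-conquer SVD). The
// workspace-query macro uses the standard LAPACK convention of passing lwork = -1, which
// makes ?gesdd report the optimal workspace size in `work` instead of factorizing; ?gesdd
// additionally needs an integer workspace of 8 * min(m, n). Because the tensors are
// row-major while LAPACK expects column-major data, the calls pass the dimensions as
// (n, m) and swap which buffers receive the U and V factors.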
#define LINALG_CPU_GESDD_WORKSPACE_QUERY(func, DType) \
template<> inline \
int linalg_gesdd_workspace_query<cpu, DType>(const int m, const int n, \
                                             const Tensor<cpu, 2, DType>& UT, \
                                             const Tensor<cpu, 1, DType>& S, \
                                             const Tensor<cpu, 2, DType>& V, \
                                             Stream<cpu> *s) { \
  DType work(0.0); \
  std::vector<int> iwork(8 * std::min(m, n), 0); \
  if (m > n) { \
    MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, n, m, \
                        UT.dptr_, UT.stride_, S.dptr_, \
                        V.dptr_, V.stride_, \
                        UT.dptr_, UT.stride_, \
                        &work, -1, iwork.data()); \
  } else { \
    MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, n, m, \
                        V.dptr_, V.stride_, S.dptr_, \
                        V.dptr_, V.stride_, \
                        UT.dptr_, UT.stride_, \
                        &work, -1, iwork.data()); \
  } \
  return static_cast<int>(work); \
}

#define LINALG_CPU_GESDD(func, DType) \
template<> inline \
void linalg_gesdd<cpu, DType>(const int m, \
                              const int n, \
                              const Tensor<cpu, 2, DType>& UT, \
                              const Tensor<cpu, 1, DType>& S, \
                              const Tensor<cpu, 2, DType>& V, \
                              const Tensor<cpu, 1, DType>& work, \
                              Stream<cpu> *s) { \
  std::vector<int> iwork(8 * std::min(m, n), 0); \
  int res(0); \
  if (m > n) { \
    res = MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, n, m, \
                              UT.dptr_, UT.stride_, S.dptr_, \
                              V.dptr_, V.stride_, \
                              UT.dptr_, UT.stride_, \
                              work.dptr_, work.shape_.Size(), iwork.data()); \
  } else { \
    res = MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, n, m, \
                              V.dptr_, V.stride_, S.dptr_, \
                              V.dptr_, V.stride_, \
                              UT.dptr_, UT.stride_, \
                              work.dptr_, work.shape_.Size(), iwork.data()); \
  } \
  CHECK_GE(res, 0) << #func << ": the " << -res \
                   << "-th argument had an illegal value"; \
  CHECK_LE(res, 0) << #func << " did not converge, updating process failed."; \
}

LINALG_CPU_GESDD_WORKSPACE_QUERY(sgesdd, float)
LINALG_CPU_GESDD_WORKSPACE_QUERY(dgesdd, double)

LINALG_CPU_GESDD(sgesdd, float)
LINALG_CPU_GESDD(dgesdd, double)

#ifdef __CUDACC__

#define LINALG_GPU_GESDD_WORKSPACE_QUERY(DType) \
template<> inline \
int linalg_gesdd_workspace_query<gpu, DType>(const int m, const int n, \
                                             const Tensor<gpu, 2, DType>& U, \
                                             const Tensor<gpu, 1, DType>& S, \
                                             const Tensor<gpu, 2, DType>& VT, \
                                             Stream<gpu> *s) { \
  LOG(FATAL) << "Lapack gesdd workspace query routine is unsupported on GPU."; \
  return 0; \
}

#define LINALG_GPU_GESDD(DType) \
template<> inline \
void linalg_gesdd<gpu, DType>(const int m, const int n, \
                              const Tensor<gpu, 2, DType>& U, \
                              const Tensor<gpu, 1, DType>& S, \
                              const Tensor<gpu, 2, DType>& VT, \
                              const Tensor<gpu, 1, DType>& work, \
                              Stream<gpu> *s) { \
  LOG(FATAL) << "Lapack gesdd routine is unsupported on GPU."; \
}

LINALG_GPU_GESDD_WORKSPACE_QUERY(float)
LINALG_GPU_GESDD_WORKSPACE_QUERY(double)

LINALG_GPU_GESDD(float)
LINALG_GPU_GESDD(double)

#endif  // __CUDACC__
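
// There is no GPU gesdd wrapper here, so on GPU the batched SVD below falls back to the
// existing linalg_gesvd routine per matrix (see BATCH_SVD_IMPL_GPU), and the workspace
// query uses linalg_gesvd_workspace_query instead (see SVDWorkspaceSize).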

#define BATCH_SVD_IMPL_CPU(DType) \
template<> inline \
void BatchSVDImpl<cpu, DType>(const int m, const int n, \
                              const Tensor<cpu, 3, DType>& UT, \
                              const Tensor<cpu, 2, DType>& S, \
                              const Tensor<cpu, 3, DType>& V, \
                              const Tensor<cpu, 1, DType>& work, \
                              Stream<cpu> *s) { \
  for (index_t i = 0; i < S.size(0); ++i) { \
    linalg_gesdd(m, n, UT[i], S[i], V[i], work, s); \
  } \
}

BATCH_SVD_IMPL_CPU(float)
BATCH_SVD_IMPL_CPU(double)

#ifdef __CUDACC__

#define BATCH_SVD_IMPL_GPU(DType) \
template<> inline \
void BatchSVDImpl<gpu, DType>(const int m, const int n, \
                              const Tensor<gpu, 3, DType>& UT, \
                              const Tensor<gpu, 2, DType>& S, \
                              const Tensor<gpu, 3, DType>& V, \
                              const Tensor<gpu, 1, DType>& work, \
                              Stream<gpu> *s) { \
  for (index_t i = 0; i < S.size(0); ++i) { \
    linalg_gesvd(UT[i], S[i], V[i], work, s); \
  } \
}

BATCH_SVD_IMPL_GPU(float)
BATCH_SVD_IMPL_GPU(double)

#endif  // __CUDACC__
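
// The kernels below implement the small-singular-value cutoff used by np.linalg.pinv:
//   smax  = max(s) per matrix                   (SingularValSmax)
//   large = s > rcond * smax
//   s     = 1 / s where large, else 0           (DiscardSmallSingularVal)
// DiscardSmallSingularValWithScalarRcond fuses these steps for the scalar-rcond variant.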

struct SingularValSmax {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType *smax_ptr, const DType *s_ptr,
                                  const int length, const int lds) {
    const DType *s_iptr = s_ptr + i * lds;
    DType *smax_iptr = smax_ptr + i;
    *smax_iptr = s_iptr[0];
    for (int j = 1; j < length; ++j) {
      *smax_iptr = s_iptr[j] > *smax_iptr ? s_iptr[j] : *smax_iptr;
    }
  }
};

struct DiscardSmallSingularVal {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType *s_ptr, const DType *large_ptr) {
    if (large_ptr[i]) {
      s_ptr[i] = DType(1) / s_ptr[i];
    } else {
      s_ptr[i] = DType(0);
    }
  }
};

struct DiscardSmallSingularValWithScalarRcond {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType *s_ptr, const int length,
                                  const int lds, const double rcond) {
    DType *s_iptr = s_ptr + i * lds;
    DType smax_i = s_iptr[0];
    for (int j = 1; j < length; ++j) {
      smax_i = s_iptr[j] > smax_i ? s_iptr[j] : smax_i;
    }
    for (int j = 0; j < length; ++j) {
      s_iptr[j] = (s_iptr[j] > rcond * smax_i) ? (DType(1) / s_iptr[j]) : (DType(0));
    }
  }
};
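
// GetPinvShape derives the shapes of the SVD factors from the input shape. For example,
// for a_shape = (2, 3, 4) (m = 3 < n = 4) it yields s_shape = (2, 3), v_shape = (2, 3, 4),
// ut_shape = u_shape = (2, 3, 3) and vt_shape = (2, 4, 3); for a_shape = (2, 4, 3)
// (m = 4 >= n = 3) it yields ut_shape = (2, 4, 3), v_shape = vt_shape = (2, 3, 3)
// and u_shape = (2, 3, 4).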

inline void GetPinvShape(const mxnet::TShape& a_shape,
                         mxnet::TShape *ut_shape,
                         mxnet::TShape *s_shape,
                         mxnet::TShape *v_shape,
                         mxnet::TShape *u_shape = nullptr,
                         mxnet::TShape *vt_shape = nullptr) {
  const int a_ndim = a_shape.ndim();
  const int m = a_shape[a_ndim - 2];
  const int n = a_shape[a_ndim - 1];

  // Calculate S shape.
  std::vector<int> s_shape_vec(a_ndim - 1, -1);
  for (int i = 0; i < a_ndim - 2; ++i) {
    s_shape_vec[i] = a_shape[i];
  }
  s_shape_vec[a_ndim - 2] = std::min(m, n);
  *s_shape = mxnet::TShape(s_shape_vec.begin(), s_shape_vec.end());

  std::vector<int> temp_shape_vec(a_ndim, -1);
  for (int i = 0; i < a_ndim - 2; ++i) {
    temp_shape_vec[i] = a_shape[i];
  }
  temp_shape_vec[a_ndim - 2] = std::min(m, n);
  temp_shape_vec[a_ndim - 1] = std::min(m, n);
  if (m >= n) {
    // UT must have same shape as A.
    *ut_shape = a_shape;
    *v_shape = mxnet::TShape(temp_shape_vec.begin(), temp_shape_vec.end());
    if (u_shape && vt_shape) {
      *vt_shape = mxnet::TShape(temp_shape_vec.begin(), temp_shape_vec.end());
      *u_shape = a_shape;
      (*u_shape)[a_ndim - 2] = a_shape[a_ndim - 1];
      (*u_shape)[a_ndim - 1] = a_shape[a_ndim - 2];
    }
  } else {
    // V must have same shape as A.
    *v_shape = a_shape;
    *ut_shape = mxnet::TShape(temp_shape_vec.begin(), temp_shape_vec.end());
    if (u_shape && vt_shape) {
      *u_shape = mxnet::TShape(temp_shape_vec.begin(), temp_shape_vec.end());
      *vt_shape = a_shape;
      (*vt_shape)[a_ndim - 2] = a_shape[a_ndim - 1];
      (*vt_shape)[a_ndim - 1] = a_shape[a_ndim - 2];
    }
  }
}
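
// GetOrCheckCutoffAndLargeShape replays the shape inference of the NumPy expression
// cutoff = rcond[..., newaxis] * smax; large = s > cutoff. For example, with
// a_shape = (2, 3, 4) and rcond_shape = (2,), it broadcasts (2, 1) * (2, 1) into
// cutoff_shape = (2, 1) and (2, 3) > (2, 1) into large_shape = (2, 3); if the
// broadcast result does not match s_shape, the function aborts with LOG(FATAL).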

inline void GetOrCheckCutoffAndLargeShape(const nnvm::NodeAttrs& attrs,
                                          const mxnet::TShape& a_shape,
                                          const mxnet::TShape& rcond_shape,
                                          mxnet::TShape *cutoff_shape = nullptr,
                                          mxnet::TShape *large_shape = nullptr) {
  if (!shape_is_known(a_shape)) { return; }
  const int a_ndim = a_shape.ndim();
  const int rcond_ndim = rcond_shape.ndim();
  mxnet::TShape s_shape(a_ndim - 1, 1);
  mxnet::TShape smax_shape(a_ndim - 1, 1);
  mxnet::TShape new_rcond_shape(rcond_ndim + 1, 1);
  // Get new rcond shape.
  for (int i = 0; i < rcond_ndim; ++i) {
    new_rcond_shape[i] = rcond_shape[i];
  }
  // Get S and Smax shapes.
  for (int i = 0; i < a_ndim - 2; ++i) {
    s_shape[i] = a_shape[i];
    smax_shape[i] = a_shape[i];
  }
  s_shape[s_shape.ndim() - 1] = std::min(a_shape[a_ndim - 2], a_shape[a_ndim - 1]);
  smax_shape[smax_shape.ndim() - 1] = 1;
  // Check cutoff = rcond[..., newaxis] * smax.
  mxnet::ShapeVector in_shape_vec1({ new_rcond_shape, smax_shape });
  mxnet::ShapeVector out_shape_vec1(1);
  mxnet::op::BinaryBroadcastShape(attrs, &in_shape_vec1, &out_shape_vec1);
  // Check large = s > cutoff.
  mxnet::ShapeVector in_shape_vec2({ s_shape, out_shape_vec1[0] });
  mxnet::ShapeVector out_shape_vec2(1);
  mxnet::op::BinaryBroadcastShape(attrs, &in_shape_vec2, &out_shape_vec2);
  // Check s = divide(1, s, where=large, out=s).
  if (s_shape != out_shape_vec2[0]) {
    LOG(FATAL) << "Error: non-broadcastable output operand with shape "
               << s_shape << " doesn't match the broadcast shape " << out_shape_vec2[0];
  }
  if (cutoff_shape) {
    *cutoff_shape = out_shape_vec1[0];
  }
  if (large_shape) {
    *large_shape = out_shape_vec2[0];
  }
}

template<typename xpu>
size_t SVDWorkspaceSize(const TBlob& a,
                        const TBlob& pinv_a,
                        const mxnet::TShape& u_shape,
                        const mxnet::TShape& s_shape,
                        const mxnet::TShape& v_shape,
                        const std::vector<OpReqType>& req,
                        const OpContext& ctx) {
  if (kNullOp == req[0]) { return 0U; }

  // Zero-size input, no need to launch kernel
  if (0U == a.Size()) { return 0U; }

  size_t work_space_size = 0;
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, OType, {
    const int a_ndim = a.shape_.ndim();
    const int u_ndim = u_shape.ndim();
    const int s_ndim = s_shape.ndim();
    const int v_ndim = v_shape.ndim();
    mxnet::TShape u_shape2 = Shape2(u_shape[u_ndim - 2], u_shape[u_ndim - 1]);
    mxnet::TShape s_shape1 = Shape1(s_shape[s_ndim - 1]);
    mxnet::TShape v_shape2 = Shape2(v_shape[v_ndim - 2], v_shape[v_ndim - 1]);
    if (xpu::kDevCPU) {
      std::vector<OType> u_vec(u_shape2.Size(), 0);
      std::vector<OType> s_vec(s_shape1.Size(), 0);
      std::vector<OType> v_vec(v_shape2.Size(), 0);
      // For workspace size in linalg_gesdd.
      work_space_size += linalg_gesdd_workspace_query(
        a.shape_[a_ndim - 2], a.shape_[a_ndim - 1],
        TBlob(u_vec.data(), u_shape2, a.dev_mask(), a.dev_id()).get<xpu, 2, OType>(s),
        TBlob(s_vec.data(), s_shape1, a.dev_mask(), a.dev_id()).get<xpu, 1, OType>(s),
        TBlob(v_vec.data(), v_shape2, a.dev_mask(), a.dev_id()).get<xpu, 2, OType>(s), s);
    } else {
      Storage::Handle u_handle =
        Storage::Get()->Alloc(sizeof(OType) * u_shape2.Size(), Context::GPU());
      Storage::Handle s_handle =
        Storage::Get()->Alloc(sizeof(OType) * s_shape1.Size(), Context::GPU());
      Storage::Handle v_handle =
        Storage::Get()->Alloc(sizeof(OType) * v_shape2.Size(), Context::GPU());
      TBlob u_data(static_cast<OType*>(u_handle.dptr), u_shape2, a.dev_mask(), a.dev_id());
      TBlob s_data(static_cast<OType*>(s_handle.dptr), s_shape1, a.dev_mask(), a.dev_id());
      TBlob v_data(static_cast<OType*>(v_handle.dptr), v_shape2, a.dev_mask(), a.dev_id());
      // For workspace size in linalg_gesvd.
      if (a.shape_[a_ndim - 2] >= a.shape_[a_ndim - 1]) {
        work_space_size += linalg_gesvd_workspace_query(v_data.get<xpu, 2, OType>(s),
                                                        s_data.get<xpu, 1, OType>(s),
                                                        u_data.get<xpu, 2, OType>(s), s);
      } else {
        work_space_size += linalg_gesvd_workspace_query(u_data.get<xpu, 2, OType>(s),
                                                        s_data.get<xpu, 1, OType>(s),
                                                        v_data.get<xpu, 2, OType>(s), s);
      }
      Storage::Get()->Free(u_handle);
      Storage::Get()->Free(s_handle);
      Storage::Get()->Free(v_handle);
    }
  });
  return work_space_size;
}

// Calculates workspace size of pinv op forward.
template<typename xpu>
size_t PinvForwardWorkspaceSize(const TBlob& a,
                                const TBlob& rcond,
                                const TBlob& pinv_a,
                                const nnvm::NodeAttrs& attrs,
                                const std::vector<OpReqType>& req,
                                const OpContext& ctx) {
  if (kNullOp == req[0]) { return 0U; }
  // Zero-size input, no need to launch kernel
  if (0U == a.Size()) { return 0U; }

  size_t work_space_size = 0;
  mxnet::TShape u_shape, s_shape, v_shape;
  GetPinvShape(a.shape_, &u_shape, &s_shape, &v_shape);

  MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, OType, {
    mxnet::TShape smax_shape = s_shape;
    smax_shape[s_shape.ndim() - 1] = 1;
    mxnet::TShape cutoff_shape;
    mxnet::TShape large_shape;
    GetOrCheckCutoffAndLargeShape(attrs, a.shape_, rcond.shape_, &cutoff_shape, &large_shape);
    work_space_size +=  // For #gesdd_ or #gesvd work space size.
      SVDWorkspaceSize<xpu>(a, pinv_a, u_shape, s_shape, v_shape, req, ctx);
    work_space_size += rcond.shape_.Size();  // For rcond.
    work_space_size += 2 * u_shape.Size();   // For UT.
    work_space_size += s_shape.Size();       // For S.
    work_space_size += 2 * v_shape.Size();   // For V.
    work_space_size += smax_shape.Size();    // For Smax.
    work_space_size += cutoff_shape.Size();  // For Cutoff.
    work_space_size += large_shape.Size();   // For Large.
    return work_space_size * sizeof(OType);
  });
  LOG(FATAL) << "InternalError: cannot reach here";
  return 0U;
}

inline mxnet::TShape GetTransAxis(const mxnet::TShape& in_shape) {
  const int in_ndim = in_shape.ndim();
  std::vector<int> trans_axis(in_ndim, -1);
  for (int i = 0; i < in_ndim - 2; ++i) { trans_axis[i] = i; }
  trans_axis[in_ndim - 2] = in_ndim - 1;
  trans_axis[in_ndim - 1] = in_ndim - 2;
  return mxnet::TShape(trans_axis.begin(), trans_axis.end());
}
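
// For example, for any 3-D input shape GetTransAxis returns the permutation (0, 2, 1),
// i.e. the axes argument that makes TransposeImpl swap the last two dimensions.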

template<typename xpu>
void PinvOpForwardImpl(const TBlob& a,
                       const TBlob& rcond,
                       const TBlob& pinv_a,
                       const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<OpReqType>& req,
                       const Tensor<xpu, 1, char>& workspace) {
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const mxnet::TShape a_shape = a.shape_;
  const mxnet::TShape rcond_shape = rcond.shape_;
  const int a_ndim = a_shape.ndim();
  const int rcond_ndim = rcond_shape.ndim();
  mxnet::TShape rcond_shape_newaxis(rcond_ndim + 1, 1);
  for (int i = 0; i < rcond_ndim; ++i) {
    rcond_shape_newaxis[i] = rcond_shape[i];
  }
  mxnet::TShape s_shape;
  mxnet::TShape u_shape;
  mxnet::TShape ut_shape;
  mxnet::TShape v_shape;
  mxnet::TShape vt_shape;
  GetPinvShape(a_shape, &u_shape, &s_shape, &v_shape, &ut_shape, &vt_shape);
  mxnet::TShape smax_shape = s_shape;
  smax_shape[s_shape.ndim() - 1] = 1;
  mxnet::TShape s_shape_newaxis(s_shape.ndim() + 1, 1);
  for (int i = 0; i < s_shape.ndim(); ++i) {
    s_shape_newaxis[i] = s_shape[i];
  }
  mxnet::TShape cutoff_shape;
  mxnet::TShape large_shape;
  GetOrCheckCutoffAndLargeShape(attrs, a_shape, rcond_shape, &cutoff_shape, &large_shape);

  MSHADOW_SGL_DBL_TYPE_SWITCH(a.type_flag_, AType, {
    MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, DType, {
      const size_t workspace_size = (workspace.shape_.Size() + sizeof(DType) - 1) / sizeof(DType);
      const size_t lwork = workspace_size - rcond_shape_newaxis.Size()
        - 2 * u_shape.Size() - s_shape.Size() - 2 * v_shape.Size() - smax_shape.Size()
        - cutoff_shape.Size() - large_shape.Size();
      DType *work_ptr = reinterpret_cast<DType*>(workspace.dptr_);
      DType *rcond_ptr = work_ptr + lwork;
      DType *ut_ptr = rcond_ptr + rcond_shape_newaxis.Size();
      DType *u_ptr = ut_ptr + ut_shape.Size();
      DType *s_ptr = u_ptr + u_shape.Size();
      DType *v_ptr = s_ptr + s_shape.Size();
      DType *vt_ptr = v_ptr + v_shape.Size();
      DType *smax_ptr = vt_ptr + vt_shape.Size();
      DType *cutoff_ptr = smax_ptr + smax_shape.Size();
      DType *large_ptr = cutoff_ptr + cutoff_shape.Size();
      // Step1: Calculate SVD.
      TBlob work_data(work_ptr, Shape1(lwork), a.dev_mask(), a.dev_id());
      TBlob u_data(u_ptr, u_shape, a.dev_mask(), a.dev_id());
      TBlob ut_data(ut_ptr, ut_shape, a.dev_mask(), a.dev_id());
      TBlob v_data(v_ptr, v_shape, a.dev_mask(), a.dev_id());
      TBlob vt_data(vt_ptr, vt_shape, a.dev_mask(), a.dev_id());
      TBlob s_data(s_ptr, s_shape, a.dev_mask(), a.dev_id());
      // Note: a transpose is needed only when a_shape[a_ndim - 2] > a_shape[a_ndim - 1].
      if (a_shape[a_ndim - 2] > a_shape[a_ndim - 1]) {
        mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
          s, a.Size(), u_ptr, a.dptr<AType>());
        mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, u_data, ut_data,  // u_data: src, ut_data: res
                                      GetTransAxis(u_data.shape_));
        BatchSVDImpl(a_shape[a_ndim - 1], a_shape[a_ndim - 2],
                     vt_data.FlatToKD<xpu, 3, DType>(s),
                     s_data.FlatToKD<xpu, 2, DType>(s),
                     ut_data.FlatToKD<xpu, 3, DType>(s),
                     work_data.FlatToKD<xpu, 1, DType>(s), s);
      } else {
        mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
          s, a.Size(), v_ptr, a.dptr<AType>());
        BatchSVDImpl(a_shape[a_ndim - 2], a_shape[a_ndim - 1],
                     u_data.FlatToKD<xpu, 3, DType>(s),
                     s_data.FlatToKD<xpu, 2, DType>(s),
                     v_data.FlatToKD<xpu, 3, DType>(s),
                     work_data.FlatToKD<xpu, 1, DType>(s), s);
      }
      TBlob smax_data(smax_ptr, smax_shape, a.dev_mask(), a.dev_id());
      TBlob cutoff_data(cutoff_ptr, cutoff_shape, a.dev_mask(), a.dev_id());
      TBlob large_data(large_ptr, large_shape, a.dev_mask(), a.dev_id());
      TBlob rcond_data(rcond_ptr, rcond_shape_newaxis, a.dev_mask(), a.dev_id());
      Tensor<xpu, 2, DType> S = s_data.FlatToKD<xpu, 2, DType>(s);
      Tensor<xpu, 2, DType> Smax = smax_data.FlatToKD<xpu, 2, DType>(s);
      mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
        s, rcond_shape_newaxis.Size(), rcond_ptr, rcond.dptr<AType>());
      // Step2: Calculate Smax.
      mxnet_op::Kernel<SingularValSmax, xpu>::Launch(
        s, S.size(0), Smax.dptr_, S.dptr_, S.size(1), S.stride_);
      // Step3: Calculate Cutoff.
      std::vector<OpReqType> temp_req({kWriteTo});
      mxnet::op::BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
                                                                  {rcond_data, smax_data},
                                                                  temp_req, {cutoff_data});
      // Step4: Calculate Large.
      mxnet::op::BinaryBroadcastCompute<xpu, op::mshadow_op::gt>(attrs, ctx,
                                                                 {s_data, cutoff_data},
                                                                 temp_req, {large_data});
      // Step5: Discard small singular values.
      mxnet_op::Kernel<DiscardSmallSingularVal, xpu>::Launch(
        s, s_data.Size(), s_data.dptr<DType>(), large_data.dptr<DType>());
      // Step6: Calculate matmul(transpose(v), multiply(s[..., newaxis], transpose(u))).
      // Note: no extra transpose is needed here when a_shape[a_ndim - 2] > a_shape[a_ndim - 1],
      // because A was already transposed before the SVD in Step1.
      if (a_shape[a_ndim - 2] <= a_shape[a_ndim - 1]) {
        mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, u_data, ut_data,  // u_data: src, ut_data: res
                                      GetTransAxis(u_data.shape_));
        mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, v_data, vt_data,  // v_data: src, vt_data: res
                                      GetTransAxis(v_data.shape_));
      }
      s_data = s_data.reshape(s_shape_newaxis);
      u_data = ut_data.reshape(ut_shape);
      mxnet::op::BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx, {s_data, ut_data},
                                                                  temp_req, {u_data});
      gemm2::op(vt_data.FlatToKD<xpu, 3, DType>(s),
                u_data.FlatToKD<xpu, 3, DType>(s),
                pinv_a.FlatToKD<xpu, 3, DType>(s),
                DType(1), false, false, s);
    });
  });
}

template<typename xpu>
void PinvOpForward(const nnvm::NodeAttrs& attrs,
                   const OpContext& ctx,
                   const std::vector<TBlob>& inputs,
                   const std::vector<OpReqType>& req,
                   const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const TBlob& a = inputs[0];
  const TBlob& rcond = inputs[1];
  const TBlob& pinv_a = outputs[0];
  const mxnet::TShape a_shape = a.shape_;

  if (kNullOp == req[0]) { return; }

  // Zero-size output, no need to launch kernel
  if (0U == a.Size()) { return; }

  size_t workspace_size = PinvForwardWorkspaceSize<xpu>(a, rcond, pinv_a, attrs, req, ctx);
  Tensor<xpu, 1, char> workspace =
    ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
  PinvOpForwardImpl<xpu>(a, rcond, pinv_a, attrs, ctx, req, workspace);
}

// Calculates workspace size of pinv scalar rcond op forward.
template<typename xpu>
size_t PinvScalarRcondForwardWorkspaceSize(const TBlob& a,
                                           const TBlob& pinv_a,
                                           const nnvm::NodeAttrs& attrs,
                                           const std::vector<OpReqType>& req,
                                           const OpContext& ctx) {
  if (kNullOp == req[0]) { return 0U; }
  // Zero-size input, no need to launch kernel
  if (0U == a.Size()) { return 0U; }

  size_t work_space_size = 0;
  mxnet::TShape u_shape, s_shape, v_shape;
  GetPinvShape(a.shape_, &u_shape, &s_shape, &v_shape);

  MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, OType, {
    mxnet::TShape smax_shape = s_shape;
    smax_shape[s_shape.ndim() - 1] = 1;
    work_space_size +=  // For #gesdd_ or #gesvd work space size.
      SVDWorkspaceSize<xpu>(a, pinv_a, u_shape, s_shape, v_shape, req, ctx);
    work_space_size += 2 * u_shape.Size();  // For UT.
    work_space_size += s_shape.Size();      // For S.
    work_space_size += 2 * v_shape.Size();  // For V.
    return work_space_size * sizeof(OType);
  });
  LOG(FATAL) << "InternalError: cannot reach here";
  return 0U;
}

template<typename xpu>
void PinvScalarRcondOpForwardImpl(const TBlob& a,
                                  const TBlob& pinv_a,
                                  const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
                                  const std::vector<OpReqType>& req,
                                  const Tensor<xpu, 1, char>& workspace) {
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const mxnet::TShape a_shape = a.shape_;
  const int a_ndim = a_shape.ndim();

  mxnet::TShape s_shape;
  mxnet::TShape u_shape;
  mxnet::TShape ut_shape;
  mxnet::TShape v_shape;
  mxnet::TShape vt_shape;
  GetPinvShape(a_shape, &u_shape, &s_shape, &v_shape, &ut_shape, &vt_shape);
  mxnet::TShape s_shape_newaxis(s_shape.ndim() + 1, 1);
  for (int i = 0; i < s_shape.ndim(); ++i) {
    s_shape_newaxis[i] = s_shape[i];
  }
  MSHADOW_SGL_DBL_TYPE_SWITCH(a.type_flag_, AType, {
    MSHADOW_SGL_DBL_TYPE_SWITCH(pinv_a.type_flag_, DType, {
      const double rcond = nnvm::get<PinvScalarRcondParam>(attrs.parsed).rcond;
      const size_t workspace_size = (workspace.shape_.Size() + sizeof(DType) - 1) / sizeof(DType);
      const size_t lwork = workspace_size - 2 * u_shape.Size() - s_shape.Size()
        - 2 * v_shape.Size();
      DType *work_ptr = reinterpret_cast<DType*>(workspace.dptr_);
      DType *u_ptr = work_ptr + lwork;
      DType *ut_ptr = u_ptr + u_shape.Size();
      DType *s_ptr = ut_ptr + ut_shape.Size();
      DType *v_ptr = s_ptr + s_shape.Size();
      DType *vt_ptr = v_ptr + v_shape.Size();
      // Step1: Calculate SVD.
      TBlob work_data(work_ptr, Shape1(lwork), a.dev_mask(), a.dev_id());
      TBlob u_data(u_ptr, u_shape, a.dev_mask(), a.dev_id());
      TBlob ut_data(ut_ptr, ut_shape, a.dev_mask(), a.dev_id());
      TBlob v_data(v_ptr, v_shape, a.dev_mask(), a.dev_id());
      TBlob vt_data(vt_ptr, vt_shape, a.dev_mask(), a.dev_id());
      TBlob s_data(s_ptr, s_shape, a.dev_mask(), a.dev_id());
      Tensor<xpu, 2, DType> S = s_data.FlatToKD<xpu, 2, DType>(s);
      // Note: a transpose is needed only when a_shape[a_ndim - 2] > a_shape[a_ndim - 1].
      if (a_shape[a_ndim - 2] > a_shape[a_ndim - 1]) {
        mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
          s, a.Size(), u_ptr, a.dptr<AType>());
        mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, u_data, ut_data,  // u_data: src, ut_data: res
                                      GetTransAxis(u_data.shape_));
        BatchSVDImpl(a_shape[a_ndim - 1], a_shape[a_ndim - 2],
                     vt_data.FlatToKD<xpu, 3, DType>(s),
                     s_data.FlatToKD<xpu, 2, DType>(s),
                     ut_data.FlatToKD<xpu, 3, DType>(s),
                     work_data.FlatToKD<xpu, 1, DType>(s), s);
      } else {
        mxnet_op::Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
          s, a.Size(), v_ptr, a.dptr<AType>());
        BatchSVDImpl(a_shape[a_ndim - 2], a_shape[a_ndim - 1],
                     u_data.FlatToKD<xpu, 3, DType>(s),
                     s_data.FlatToKD<xpu, 2, DType>(s),
                     v_data.FlatToKD<xpu, 3, DType>(s),
                     work_data.FlatToKD<xpu, 1, DType>(s), s);
      }
      // Step2: Discard small singular values.
      mxnet_op::Kernel<DiscardSmallSingularValWithScalarRcond, xpu>::Launch(
        s, S.size(0), S.dptr_, S.size(1), S.stride_, rcond);
      // Step3: Calculate matmul(transpose(v), multiply(s[..., newaxis], transpose(u))).
      // Note: no extra transpose is needed here when a_shape[a_ndim - 2] > a_shape[a_ndim - 1],
      // because A was already transposed before the SVD in Step1.
      if (a_shape[a_ndim - 2] <= a_shape[a_ndim - 1]) {
        mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, u_data, ut_data,  // u_data: src, ut_data: res
                                      GetTransAxis(u_data.shape_));
        mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, v_data, vt_data,  // v_data: src, vt_data: res
                                      GetTransAxis(v_data.shape_));
      }
      s_data = s_data.reshape(s_shape_newaxis);
      u_data = ut_data.reshape(ut_shape);
      mxnet::op::BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx, {s_data, ut_data},
                                                                  {kWriteTo}, {u_data});
      gemm2::op(vt_data.FlatToKD<xpu, 3, DType>(s),
                u_data.FlatToKD<xpu, 3, DType>(s),
                pinv_a.FlatToKD<xpu, 3, DType>(s),
                DType(1), false, false, s);
    });
  });
}

template<typename xpu>
void PinvScalarRcondOpForward(const nnvm::NodeAttrs& attrs,
                              const OpContext& ctx,
                              const std::vector<TBlob>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const TBlob& a = inputs[0];
  const TBlob& pinv_a = outputs[0];

  if (kNullOp == req[0]) { return; }
  // Zero-size output, no need to launch kernel
  if (0U == a.Size()) { return; }

  // Calculate workspace size.
  size_t workspace_size = PinvScalarRcondForwardWorkspaceSize<xpu>(a, pinv_a, attrs, req, ctx);
  Tensor<xpu, 1, char> workspace =
    ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
  PinvScalarRcondOpForwardImpl<xpu>(a, pinv_a, attrs, ctx, req, workspace);
}

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_OPERATOR_NUMPY_LINALG_NP_PINV_INL_H_