1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file mshadow_op.h
22  * \brief
23  * \author Bing Xu
24 */
25 #ifndef MXNET_OPERATOR_MSHADOW_OP_H_
26 #define MXNET_OPERATOR_MSHADOW_OP_H_
27 
28 #include <mxnet/base.h>
29 #include <mshadow/base.h>
30 #include "math.h"
31 #include "math_functions-inl.h"
32 #include "special_functions-inl.h"
33 #include "./operator_tune.h"
34 #include "./contrib/erfinv-inl.h"
35 
36 #ifdef __CUDACC__
37 #include <cuda_fp16.h>
38 #endif
39 
40 namespace mxnet {
41 namespace op {
42 namespace mshadow_op {
43 
using mshadow::isnan_typed::IsNan;
using mshadow::isinf_typed::IsInf;

// Math constants shared by the operators below.  In device code they are
// placed in CUDA constant memory; on the host they are plain file-scope
// constants with the same values.
#ifdef __CUDA_ARCH__
__constant__ const float PI = 3.14159265358979323846;
__constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
__constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
__constant__ const float SQRT_2 = 1.4142135623730950488016887242096;
#else
const float PI = 3.14159265358979323846;
const float SELU_ALPHA = 1.6732632423543772848170429916717;
const float SELU_LAMBDA = 1.0507009873554804934193349852946;
const float SQRT_2 = 1.4142135623730950488016887242096;
#endif
using std::enable_if;
using std::is_unsigned;
using std::is_integral;
61 
/*! \brief Unary op whose expression result is explicitly cast back to DType. */
#define MXNET_UNARY_MATH_OP(name, expr) \
  struct name : public mxnet_op::tunable { \
    template<typename DType> \
    MSHADOW_XINLINE static DType Map(DType a) { \
      return DType(expr); \
    } \
  }

/*! \brief Unary op, no cast (NC): `expr` must already evaluate to DType. */
#define MXNET_UNARY_MATH_OP_NC(name, expr) \
  struct name : public mxnet_op::tunable { \
    template<typename DType> \
    MSHADOW_XINLINE static DType Map(DType a) { \
      return (expr); \
    } \
  }

/*! \brief Unary predicate: Map returns bool regardless of input type. */
#define MXNET_UNARY_LOGIC_OP_NC(name, expr) \
  struct name : public mxnet_op::tunable { \
    template<typename DType> \
    MSHADOW_XINLINE static bool Map(DType a) { \
      return (expr); \
    } \
  }

/*! \brief Binary op whose expression result is cast back to DType. */
#define MXNET_BINARY_MATH_OP(name, expr) \
  struct name : public mxnet_op::tunable { \
    template<typename DType> \
    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
      return DType(expr); \
    } \
  }

/*! \brief Binary op, no cast: `expr` must already evaluate to DType. */
#define MXNET_BINARY_MATH_OP_NC(name, expr) \
  struct name : public mxnet_op::tunable  { \
    template<typename DType> \
    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
      return (expr); \
    } \
  }

/*! \brief Binary no-cast op with a dedicated bool overload, so that bool
 *  tensors do not go through arithmetic on bool (which compilers warn on). */
#define MXNET_BINARY_MATH_OP_NC_WITH_BOOL(name, expr) \
  struct name : public mxnet_op::tunable  { \
    template<typename DType, \
             typename std::enable_if<!std::is_same<DType, bool>::value, int>::type = 0> \
    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
      return (expr); \
    } \
    MSHADOW_XINLINE static bool Map(bool a, bool b) { \
      return (expr); \
    } \
  }

/*! \brief Binary predicate: Map returns bool regardless of input type. */
#define MXNET_BINARY_LOGIC_OP_NC(name, expr) \
  struct name : public mxnet_op::tunable  { \
    template<typename DType> \
    MSHADOW_XINLINE static bool Map(DType a, DType b) { \
      return (expr); \
    } \
  }

/*! \brief Shorthand: op `name` simply forwards to math::name(a) / math::name(a, b). */
#define MXNET_SIMPLE_UNARY_MATH_OP(name) MXNET_UNARY_MATH_OP(name, math::name(a))

#define MXNET_SIMPLE_BINARY_MATH_OP(name) MXNET_BINARY_MATH_OP(name, math::name(a, b))
125 
// identity: f(a) = a; its gradient is the constant 1.
MXNET_UNARY_MATH_OP_NC(identity, a);

MXNET_UNARY_MATH_OP(identity_grad, 1);

/*! \brief Elementwise copy that converts each element from DTypeIn to DTypeOut. */
struct identity_with_cast {
  template<typename DTypeIn, typename DTypeOut>
  MSHADOW_XINLINE static void Map(index_t i, DTypeOut *out, DTypeIn *in) {
    out[i] = DTypeOut(in[i]);
  }
};
136 
/*!
 * \brief numpy-style true division a / b: integral inputs are promoted to
 * floating point before dividing, so integer / integer yields float.
 * The overloads below select the output type from the operand types.
 */
struct true_divide : public mxnet_op::tunable  {
  // Floating-point DType: plain division, result stays DType.
  template<typename DType,
           typename std::enable_if<!std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return a / b;
  }

  // Integral DType: promote both operands to float.
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, DType b) {
    return static_cast<float>(a) / static_cast<float>(b);
  }

  // Mixed integral / half: compute in half precision.
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return static_cast<mshadow::half::half_t>(a) / b;
  }

  // Mixed integral / float: compute in float.
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return static_cast<float>(a) / b;
  }

  // Mixed integral / double: compute in double.
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return static_cast<double>(a) / b;
  }
};
168 
/*!
 * \brief Reverse true division b / a, with the same type-promotion rules
 * as true_divide (integral operands are promoted to floating point).
 */
struct rtrue_divide : public mxnet_op::tunable  {
  // Floating-point DType: plain division, operands swapped.
  template<typename DType,
           typename std::enable_if<!std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return b / a;
  }

  // Integral DType: promote both operands to float.
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, DType b) {
    return static_cast<float>(b) / static_cast<float>(a);
  }

  // Mixed integral / half: compute in half precision.
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return b / static_cast<mshadow::half::half_t>(a);
  }

  // Mixed integral / float: compute in float.
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return b / static_cast<float>(a);
  }

  // Mixed integral / double: compute in double.
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return b / static_cast<double>(a);
  }
};
200 
// Operand selectors: left returns a unchanged, right returns b unchanged.
MXNET_BINARY_MATH_OP_NC(left, a);

MXNET_BINARY_MATH_OP_NC(right, b);
204 
// The mixed_* functors below implement binary ops between tensors of
// different dtypes.  The pattern is the same in each: the overload set
// promotes the lower-precision operand `a` to the type of `b`
// (integral -> half, {integral, half} -> float, {integral, half, float}
// -> double) and computes in that wider type.

/*! \brief Mixed-precision addition: a + b computed in b's (wider) type. */
struct mixed_plus {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return static_cast<mshadow::half::half_t>(a) + b;
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return static_cast<float>(a) + b;
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return static_cast<double>(a) + b;
  }
};

/*! \brief Mixed-precision subtraction: a - b computed in b's (wider) type. */
struct mixed_minus {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return static_cast<mshadow::half::half_t>(a) - b;
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return static_cast<float>(a) - b;
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return static_cast<double>(a) - b;
  }
};

/*! \brief Mixed-precision reverse subtraction: b - a in b's (wider) type. */
struct mixed_rminus {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return b - static_cast<mshadow::half::half_t>(a);
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return b - static_cast<float>(a);
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return b - static_cast<double>(a);
  }
};

/*! \brief Mixed-precision multiplication: a * b in b's (wider) type. */
struct mixed_mul {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return static_cast<mshadow::half::half_t>(a) * b;
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return static_cast<float>(a) * b;
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return static_cast<double>(a) * b;
  }
};

/*! \brief Mixed-precision power: a ** b via math::pow, result in b's type. */
struct mixed_power {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return static_cast<mshadow::half::half_t>(math::pow(a, b));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return static_cast<float>(math::pow(a, b));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return static_cast<double>(math::pow(a, b));
  }
};

/*! \brief Mixed-precision reverse power: b ** a via math::pow, result in b's type. */
struct mixed_rpower {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return static_cast<mshadow::half::half_t>(math::pow(b, a));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return static_cast<float>(math::pow(b, a));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return static_cast<double>(math::pow(b, a));
  }
};
342 
// Basic arithmetic ops, with bool overloads.  The diagnostics below are
// silenced because the WITH_BOOL macro instantiates `expr` for bool too.
#pragma GCC diagnostic push
#if __GNUC__ >= 7
#pragma GCC diagnostic ignored "-Wint-in-bool-context"
#pragma GCC diagnostic ignored "-Wbool-compare"
#endif
MXNET_BINARY_MATH_OP_NC_WITH_BOOL(mul, a * b);

MXNET_BINARY_MATH_OP_NC_WITH_BOOL(div, a / b);

MXNET_BINARY_MATH_OP_NC_WITH_BOOL(plus, a + b);

MXNET_BINARY_MATH_OP_NC_WITH_BOOL(minus, a - b);

MXNET_UNARY_MATH_OP(negation, -a);

MXNET_UNARY_MATH_OP(reciprocal, 1.0f / math::id(a));
359 
360 struct bitwise_not : public mxnet_op::tunable {
361   template<typename DType,
362            typename std::enable_if<!std::is_same<DType, bool>::value, int>::type = 0>
Mapbitwise_not363   MSHADOW_XINLINE static DType Map(DType a) {
364     return ~static_cast<int64_t>(a);
365   }
366 
Mapbitwise_not367   MSHADOW_XINLINE static bool Map(bool a) {
368     return !a;
369   }
370 };
371 
MXNET_UNARY_MATH_OP(reciprocal_grad, -1.0f / math::sqr(a));

// sigmoid(a) = 1 / (1 + exp(-a)); its grad is written in terms of the
// forward output: y * (1 - y).
MXNET_UNARY_MATH_OP(sigmoid, 1.0f / (1.0f + math::exp(-a)));

MXNET_UNARY_MATH_OP(sigmoid_grad, math::id(a) * (1.0f - math::id(a)));

MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));

MXNET_UNARY_MATH_OP(softsign_grad, 1.0f /  math::sqr(1.0f + math::fabs(a)));

// SELU: lambda * (a for a > 0, else alpha * (exp(a) - 1)).
MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
                         (a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));

// selu_grad is written in terms of the forward output `a`.
MXNET_UNARY_MATH_OP_NC(selu_grad,
                       DType(SELU_LAMBDA) * (a > DType(0) ? DType(1) : DType(SELU_ALPHA + a)));

MXNET_BINARY_MATH_OP_NC(prelu_grad, a > DType(0) ? DType(0) : a);

// xelu / elu: leaky variants with slope/scale parameter b.
MXNET_BINARY_MATH_OP_NC(xelu, a > DType(0) ? a :
                        DType(static_cast<float>(a) * static_cast<float>(b)));

MXNET_BINARY_MATH_OP_NC(xelu_grad, a > DType(0) ? DType(1) : b);

MXNET_BINARY_MATH_OP_NC(elu, a > DType(0) ? a :
                        DType(math::id(b) * math::expm1(a)));

// elu_grad uses the forward output `a` (hence b + a in the negative branch).
MXNET_BINARY_MATH_OP_NC(elu_grad, a > DType(0) ? DType(1) : DType(b + a));

MXNET_SIMPLE_UNARY_MATH_OP(tanh);

MXNET_UNARY_MATH_OP(tanh_grad, 1.0f - math::sqr(a));
403 
404 /*! \brief SoftReLU, also known as softplus activation */
405 struct softrelu : public mxnet_op::tunable {
406   template<typename DType>
Mapsoftrelu407   MSHADOW_XINLINE static DType Map(DType a) {
408     // Avoid overflow of exp for large inputs.
409     // Thresholds 20.0 is chosen such that softrelu(a) = a
410     // for a > 20 using floating precision
411     if (a > DType(20.0f)) {
412       return a;
413     } else {
414       return DType(math::log1p(math::exp(a)));
415     }
416   }
417 };
418 
MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a));

// erfinv_grad takes the forward output as input: d/dx erfinv(x) =
// sqrt(pi)/2 * exp(erfinv(x)^2).
MXNET_UNARY_MATH_OP(erfinv_grad, 0.5 * math::sqrt(PI) * math::exp(math::sqr(a)));

MXNET_UNARY_MATH_OP(erf_grad, 2.0 / math::sqrt(PI) * math::exp(-(a * a)));

MXNET_SIMPLE_UNARY_MATH_OP(erf);

// GELU: 0.5 * a * (1 + erf(a / sqrt(2))).
MXNET_UNARY_MATH_OP(gelu,
  DType(0.5f * static_cast<float>(a) * (1.0f + math::erf(static_cast<float>(a) / SQRT_2))));

MXNET_BINARY_MATH_OP_NC(gelu_grad,
  DType(0.5f * (1.0f + math::erf(static_cast<float>(a) / SQRT_2) +
                static_cast<float>(a) * erf_grad::Map(static_cast<float>(a) / SQRT_2) / SQRT_2)));

MXNET_SIMPLE_UNARY_MATH_OP(exp);

MXNET_SIMPLE_UNARY_MATH_OP(expm1);

MXNET_SIMPLE_UNARY_MATH_OP(log);

MXNET_UNARY_MATH_OP(log_grad, 1.0f / math::id(a));

MXNET_SIMPLE_UNARY_MATH_OP(log10);
444 // Constant is 1 / log(10)
// d/dx log10(x) = 1 / (x * ln(10)).  Constant is 1 / log(10); the double
// specialization below uses the full-precision constant.
struct log10_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    return DType(0.4342944819f / static_cast<float>(a));
  }
};

template<>
MSHADOW_XINLINE double log10_grad::Map<double>(double a) {
  return 0.43429448190325182765 / a;
}

MXNET_SIMPLE_UNARY_MATH_OP(log2);

// d/dx log2(x) = 1 / (x * ln(2)).  Constant is 1 / log(2); double
// specialization keeps full precision.
struct log2_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    return DType(1.442695041f / static_cast<float>(a));
  }
};

template<>
MSHADOW_XINLINE double log2_grad::Map<double>(double a) {
  return 1.44269504088896340737 / a;
}
471 
// Trigonometric / hyperbolic ops and their gradients.  Each *_grad is the
// derivative of the matching forward op, expressed in terms of the input
// (or, where noted, the forward output).
MXNET_SIMPLE_UNARY_MATH_OP(sin);

MXNET_UNARY_MATH_OP(sin_grad, math::cos(a));

MXNET_SIMPLE_UNARY_MATH_OP(log1p);

MXNET_UNARY_MATH_OP(log1p_grad, 1.0f / (1.0f + math::id(a)));

MXNET_SIMPLE_UNARY_MATH_OP(cos);

MXNET_UNARY_MATH_OP(cos_grad, -math::sin(a));

MXNET_SIMPLE_UNARY_MATH_OP(tan);

// tan_grad uses the forward output: d/dx tan(x) = 1 + tan(x)^2.
MXNET_UNARY_MATH_OP(tan_grad, math::sqr(a) + 1.0f);

MXNET_UNARY_MATH_OP(arcsin, math::asin(a));

MXNET_UNARY_MATH_OP(arcsin_grad, 1.0f / math::sqrt(1.0f - math::sqr(a)));

MXNET_UNARY_MATH_OP(arccos, math::acos(a));

MXNET_UNARY_MATH_OP(arccos_grad, -1.0f / math::sqrt(1.0f - math::sqr(a)));

MXNET_UNARY_MATH_OP(arctan, math::atan(a));

MXNET_UNARY_MATH_OP(arctan_grad, 1.0f / (math::sqr(a) + 1.0f));

MXNET_SIMPLE_BINARY_MATH_OP(hypot);

MXNET_BINARY_MATH_OP(hypot_grad_left, math::id(a) / math::hypot(a, b));

MXNET_BINARY_MATH_OP(hypot_grad_right, math::id(b) / math::hypot(a, b));

// Degree/radian conversions.
MXNET_UNARY_MATH_OP(degrees, 180.0f / PI * math::id(a));

MXNET_UNARY_MATH_OP(degrees_grad, 180.0f / PI);

MXNET_UNARY_MATH_OP(radians, PI / 180.0f * math::id(a));

MXNET_UNARY_MATH_OP(radians_grad, PI / 180.0f);

MXNET_SIMPLE_UNARY_MATH_OP(sinh);

MXNET_UNARY_MATH_OP(sinh_grad, math::cosh(a));

MXNET_SIMPLE_UNARY_MATH_OP(cosh);

MXNET_UNARY_MATH_OP(cosh_grad, math::sinh(a));

MXNET_UNARY_MATH_OP(arcsinh, math::asinh(a));

MXNET_UNARY_MATH_OP(arcsinh_grad, 1.0f / math::hypot(a, DType(1)));

MXNET_UNARY_MATH_OP(arccosh, math::acosh(a));

MXNET_UNARY_MATH_OP(arccosh_grad, 1.0f / math::sqrt(math::sqr(a) - 1.0f));

MXNET_UNARY_MATH_OP(arctanh, math::atanh(a));

MXNET_UNARY_MATH_OP(arctanh_grad, 1.0f / (1.0f - math::sqr(a)));

MXNET_UNARY_MATH_OP(square, math::sqr(a));

MXNET_UNARY_MATH_OP(square_grad, 2.0f * math::id(a));

/*! \brief used for generate Bernoulli mask */
MXNET_BINARY_MATH_OP_NC(threshold, a < b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(threshold_eq, a <= b ? DType(1) : DType(0));

/*! \brief used for generate element of abs */
MXNET_UNARY_MATH_OP(abs, math::fabs(a)); // NOLINT(*)
544 
545 /*! \brief used for generate element of sign */
/*! \brief used for generate element of sign */
struct sign : public mxnet_op::tunable {
  // Signed types: -1 / 0 / +1.
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
  Map(DType a) {
    if (a < DType(0)) return DType(-DType(1));
    if (a > DType(0)) return DType(1);
    return DType(0);
  }
  // Unsigned types can never be negative, so only 0 / +1 are possible.
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type
  Map(DType a) {
    if (a > DType(0)) return DType(1);
    return DType(0);
  }
};
561 
MXNET_UNARY_MATH_OP_NC(sign_grad, DType(0));

/*! \brief used for generate element of power */
MXNET_BINARY_MATH_OP(power, math::pow(a, b));

// Gradients of power: w.r.t. base (power_grad) and exponent (power_rgrad).
MXNET_BINARY_MATH_OP(power_grad, math::pow(a, b - DType(1)) * math::id(b));

MXNET_BINARY_MATH_OP(power_rgrad, math::pow(a, b) * math::log(a));

MXNET_BINARY_MATH_OP(rpower, math::pow(b, a));

// rpower_grad takes the forward output as `a`.
MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b));

// arctan2(a, b) = atan(a / b) with correct quadrant, plus its gradients.
MXNET_BINARY_MATH_OP(arctan2, math::atan2(a, b));
MXNET_BINARY_MATH_OP(arctan2_grad, math::id(b) / (math::id(a * a + b * b)));

MXNET_BINARY_MATH_OP(arctan2_rgrad, -math::id(a) / (math::id(a * a + b * b)));

MXNET_BINARY_MATH_OP(rarctan2, math::atan2(b, a));

MXNET_BINARY_MATH_OP(rarctan2_grad, math::id(a) / (math::id(a * a + b * b)));

// "not": 1 where a == 0, else 0.
MXNET_UNARY_MATH_OP_NC(nt, a != DType(0) ? DType(0) : DType(1));

MXNET_UNARY_LOGIC_OP_NC(np_logical_not, !static_cast<bool>(a));

// Comparisons returning DType 0/1 masks (legacy API)...
MXNET_BINARY_MATH_OP_NC(ge, a >= b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP_NC(gt, a > b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP_NC(lt, a < b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP_NC(le, a <= b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP_NC(eq, a == b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP_NC(ne, a != b ? DType(1) : DType(0));

// ...and the numpy variants returning bool.
MXNET_BINARY_LOGIC_OP_NC(np_greater_equal, a >= b ? true : false);

MXNET_BINARY_LOGIC_OP_NC(np_greater, a > b ? true : false);

MXNET_BINARY_LOGIC_OP_NC(np_less, a < b ? true : false);

MXNET_BINARY_LOGIC_OP_NC(np_less_equal, a <= b ? true : false);

MXNET_BINARY_LOGIC_OP_NC(np_equal, a == b ? true : false);

MXNET_BINARY_LOGIC_OP_NC(np_not_equal, a != b ? true : false);

MXNET_BINARY_MATH_OP(logical_and, a && b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP(logical_or, a || b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0));

// Bitwise ops are computed on the int64 conversion of the operands.
MXNET_BINARY_MATH_OP(bitwise_and, static_cast<int64_t>(a) & static_cast<int64_t>(b));

MXNET_BINARY_MATH_OP(bitwise_xor, static_cast<int64_t>(a) ^ static_cast<int64_t>(b));

MXNET_BINARY_MATH_OP(bitwise_or, static_cast<int64_t>(a) | static_cast<int64_t>(b));
623 
MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));

// square_root_grad takes the forward output: d/dx sqrt(x) = 0.5 / sqrt(x).
MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));
MXNET_UNARY_MATH_OP(reciprocal_square_root, 1.0f / math::sqrt(a));

MXNET_UNARY_MATH_OP(reciprocal_square_root_grad, -0.5f / (math::sqrt(a) * math::id(a)));

MXNET_UNARY_MATH_OP(cube_root, math::cbrt(a));

// cube_root_grad takes the forward output: d/dx cbrt(x) = 1 / (3 cbrt(x)^2).
MXNET_UNARY_MATH_OP(cube_root_grad, 1.0f / (3.0f * math::sqr(a)));

MXNET_UNARY_MATH_OP(reciprocal_cube_root, 1.0f / math::cbrt(a));

MXNET_UNARY_MATH_OP(reciprocal_cube_root_grad, -1.0f / (3.0f * math::cbrt(a) * math::id(a)));

/*! \brief used for generate element of ldexp: a * 2^b */
MXNET_BINARY_MATH_OP(ldexp, math::id(a) * math::pow(2.0f, b));

MXNET_BINARY_MATH_OP(ldexp_grad, math::pow(2.0f, b));

MXNET_BINARY_MATH_OP(ldexp_rgrad, math::id(a) * math::pow(2.0f, b) * math::log(2.0f));

MXNET_BINARY_MATH_OP(rldexp, math::id(b) * math::pow(2.0f, a));  // swap a and b if a is scalar.

MXNET_BINARY_MATH_OP(rldexp_grad, math::id(b) * math::pow(2.0f, a) * math::log(2.0f));

/*! \brief used for generate element of round */
MXNET_SIMPLE_UNARY_MATH_OP(round);

/*! \brief used for generate element of ceil */
MXNET_SIMPLE_UNARY_MATH_OP(ceil);

/*! \brief used for generate element of floor */
MXNET_SIMPLE_UNARY_MATH_OP(floor);

/*! \brief used to round towards zero */
MXNET_SIMPLE_UNARY_MATH_OP(trunc);
661 
662 /*! \brief used to round number to nearest integer */
663 struct rint : public mxnet_op::tunable {
664   template<typename DType>
Maprint665   MSHADOW_XINLINE static DType Map(DType a) {
666     auto floor = math::floor(a);
667     auto ceil = math::ceil(a);
668     auto af = math::id(a);
669     return DType((af - floor) <= (ceil - af) ? floor : ceil);
670   }
671 };
672 
673 /*! \brief used to round number to integer nearest to 0 */
674 struct fix : public mxnet_op::tunable {
675   template<typename DType>
Mapfix676   MSHADOW_XINLINE static DType Map(DType a) {
677     auto floor = math::floor(a);
678     auto ceil = math::ceil(a);
679     return DType((floor > 0 ? floor : -floor) < (ceil > 0 ? ceil : -ceil) ? floor : ceil);
680   }
681 };
682 
683 /*! \brief used to determine whether a number is Not A Number*/
/*! \brief used to determine whether a number is Not A Number*/
struct isnan : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static bool Map(DType a) {
    return IsNan(a);
  }
};

/*! \brief used to determine whether a number is infinite*/
struct isinf : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static bool Map(DType a) {
    return IsInf(a);
  }
};

/*! \brief used to determine whether a number is finite (neither NaN nor infinite)*/
struct isfinite : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static bool Map(DType a) {
    return !IsNan(a) && !IsInf(a);
  }
};

/*! \brief used to determine whether a number is positive infinity*/
struct isposinf : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static bool Map(DType a) {
    return IsInf(a) && a > 0;
  }
};

/*! \brief used to determine whether a number is negative infinity*/
struct isneginf : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static bool Map(DType a) {
    return IsInf(a) && a < 0;
  }
};
722 
/*! \brief used for generate gradient of MAE loss*/
MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? DType(1) : -DType(1));

MXNET_BINARY_MATH_OP(rminus, b - a);

// Constant-gradient helpers for binary ops.
MXNET_BINARY_MATH_OP_NC(posone, 1);

MXNET_BINARY_MATH_OP_NC(negone, -1);

// d/da (a / b) = 1 / b; half2 specialization avoids the float round-trip.
MXNET_BINARY_MATH_OP(div_grad, 1.0f / math::id(b));

template<>
MSHADOW_XINLINE mshadow::half::half2_t div_grad::Map<mshadow::half::half2_t>
                                               (mshadow::half::half2_t a,
                                                mshadow::half::half2_t b) {
  return mshadow::half::half2_t(1) / b;
}

// d/db (a / b) = -a / b^2; half2 specialization below.
MXNET_BINARY_MATH_OP(div_rgrad, -math::id(a) / math::sqr(b));

template<>
MSHADOW_XINLINE mshadow::half::half2_t div_rgrad::Map<mshadow::half::half2_t>
                                               (mshadow::half::half2_t a,
                                                mshadow::half::half2_t b) {
  return -a / (b * b);
}

MXNET_BINARY_MATH_OP(rdiv, math::id(b) / math::id(a));

MXNET_BINARY_MATH_OP(rdiv_grad, -math::id(b) / math::sqr(a));

// copysign(a, b): magnitude of a, sign of b (zero treated as positive here).
MXNET_BINARY_MATH_OP(copysign, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? a : -a);

MXNET_BINARY_MATH_OP(copysign_grad, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? 1: -1);

MXNET_BINARY_MATH_OP(copysign_rgrad, 0);

MXNET_BINARY_MATH_OP(rcopysign, (b >= 0 && a >= 0) || (b < 0 && a < 0) ? b : -b);

MXNET_BINARY_MATH_OP(rcopysign_grad, 0);
763 
764 struct mod : public mxnet_op::tunable {
765   template<typename DType>
766   MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
Mapmod767   Map(DType a, DType b) {
768     if (b == DType(0)) {
769       return DType(0);
770     } else if (b < DType(0)) {
771       if (a < DType(0)) {
772         return DType(-::fmod(-static_cast<double>(a), -static_cast<double>(b)));
773       } else {
774         return DType(::fmod(static_cast<double>(a), -static_cast<double>(b)) +
775                      (::fmod(static_cast<double>(a), -static_cast<double>(b)) != DType(0)
776                       ? b : DType(0)));
777       }
778     } else {
779       if (a < DType(0)) {
780         return DType(-::fmod(-static_cast<double>(a), static_cast<double>(b)) +
781                      (::fmod(-static_cast<double>(a), static_cast<double>(b)) != DType(0)
782                       ? b : DType(0)));
783       } else {
784         return DType(::fmod(static_cast<double>(a), static_cast<double>(b)));
785       }
786     }
787   }
788   template<typename DType>
789   MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type
Mapmod790   Map(DType a, DType b) {
791     if (b == DType(0)) {
792       return DType(0);
793     } else {
794       return DType(::fmod(static_cast<double>(a), static_cast<double>(b)));
795     }
796   }
797 };
798 
/*! \brief Mixed-precision mod: `a` is promoted to b's (wider) type, then
 *  delegates to mod::Map. */
struct mixed_mod {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return mod::Map(static_cast<mshadow::half::half_t>(a), b);
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return mod::Map(static_cast<float>(a), b);
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return mod::Map(static_cast<double>(a), b);
  }
};

/*! \brief Mixed-precision reverse mod: computes mod(b, a) after promoting
 *  `a` to b's (wider) type. */
struct mixed_rmod {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return mod::Map(b, static_cast<mshadow::half::half_t>(a));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return mod::Map(b, static_cast<float>(a));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return mod::Map(b, static_cast<double>(a));
  }
};
844 
845 struct fmod : public mxnet_op::tunable {
846   template<typename DType>
Mapfmod847   MSHADOW_XINLINE static DType Map(DType a, DType b) {
848     if (b == DType(0)) {
849       return DType(0);
850     } else {
851         return DType(::fmod(static_cast<double>(a), static_cast<double>(b)));
852     }
853   }
854 };
855 
856 struct rfmod : public mxnet_op::tunable {
857   template<typename DType>
Maprfmod858   MSHADOW_XINLINE static DType Map(DType a, DType b) {
859     if (a == DType(0)) {
860       return DType(0);
861     } else  {
862       return DType(::fmod(static_cast<double>(b), static_cast<double>(a)));
863     }
864   }
865 };
866 
// Vectorized half2 specialization: mod maps directly onto half2's
// elementwise operator%.
template<>
MSHADOW_XINLINE mshadow::half::half2_t mod::Map<mshadow::half::half2_t>
                                               (mshadow::half::half2_t a,
                                                mshadow::half::half2_t b) {
  return a%b;
}
873 
// Gradient of (a mod b) w.r.t. a. The generic template returns 0 (used
// for integral types); the floating-point specializations below return 1.
struct mod_grad : public mxnet_op::tunable  {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return DType(0);
  }
};
// d(a mod b)/da == 1 for floating-point types.
template<>
MSHADOW_XINLINE double mod_grad::Map<double>(double a, double b) {
  return 1.0;
}
template<>
MSHADOW_XINLINE float mod_grad::Map<float>(float a, float b) {
  return 1.0f;
}

// Same unit gradient for half precision.
template<>
MSHADOW_XINLINE mshadow::half::half_t mod_grad::Map<mshadow::half::half_t>
                                                   (mshadow::half::half_t a,
                                                    mshadow::half::half_t b) {
  return mshadow::half::half_t(1.0f);
}
895 template<>
896 MSHADOW_XINLINE mshadow::half::half2_t mod_grad::Map<mshadow::half::half2_t>
897                                                     (mshadow::half::half2_t a,
898                                                      mshadow::half::half2_t b) {
899   mshadow::half::half2_t result = mshadow::half::half2_t();
900 #if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2)
901   result.half2_ = ::__float2half2_rn(1.0f);
902 #else
903   result.half_t2[0] = mshadow::half::half_t(0.0f);
904   result.half_t2[1] = mshadow::half::half_t(1.0f);
905 #endif
906   return result;
907 }
908 
// Gradient of (a mod b) w.r.t. b. The generic template returns 0 (used
// for integral types); floating-point specializations return -floor(a/b),
// the derivative of a - b*floor(a/b) with respect to b.
struct mod_rgrad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return DType(0);
  }
};
template<>
MSHADOW_XINLINE double mod_rgrad::Map<double>(double a, double b) {
  return -::floor(a/b);
}
template<>
MSHADOW_XINLINE float mod_rgrad::Map<float>(float a, float b) {
  return -::floorf(a/b);
}

// Half precision: compute the floor in float, then narrow back.
template<>
MSHADOW_XINLINE mshadow::half::half_t mod_rgrad::Map<mshadow::half::half_t>
                                                    (mshadow::half::half_t a,
                                                     mshadow::half::half_t b) {
  return mshadow::half::half_t(-::floorf(static_cast<float>(a/b)));
}
template<>
MSHADOW_XINLINE mshadow::half::half2_t mod_rgrad::Map<mshadow::half::half2_t>
                                                     (mshadow::half::half2_t a,
                                                      mshadow::half::half2_t b) {
#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2)
  // Native half2 path: -floor(a/b) per lane via device intrinsics.
  return mshadow::half::half2_t(__hneg2(::h2floor((a/b).half2_)));
#else
  // Fallback: per-lane computation in float.
  return mshadow::half::half2_t(mshadow::half::half_t(-::floorf(
                                  static_cast<float>(a.half_t2[0]/b.half_t2[0]))),
                                mshadow::half::half_t(-::floorf(
                                  static_cast<float>(a.half_t2[1]/b.half_t2[1]))));
#endif
}
943 
// Reverse mod: computes b mod a with floor-division semantics, i.e. a
// non-zero result takes the sign of the divisor a (Python-style '%'),
// implemented on top of C's truncating ::fmod.
struct rmod : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
  Map(DType a, DType b) {
    if (a == DType(0)) {
      // Divisor 0: define b mod 0 as 0 instead of a domain error.
      return DType(0);
    } else if (a < DType(0)) {
      if (b < DType(0)) {
        // Both negative: -( -b mod -a ), result in (a, 0].
        return DType(-::fmod(-static_cast<double>(b), -static_cast<double>(a)));
      } else {
        // b >= 0, a < 0: shift any non-zero truncated remainder by a
        // so the result lands in (a, 0].
        return DType(::fmod(static_cast<double>(b), -static_cast<double>(a)) +
                     (::fmod(static_cast<double>(b), -static_cast<double>(a)) != DType(0)
                      ? a : DType(0)));
      }
    } else {
      if (b < DType(0)) {
        // b < 0, a > 0: shift any non-zero truncated remainder by a
        // so the result lands in [0, a).
        return DType(-::fmod(-static_cast<double>(b), static_cast<double>(a)) +
                     (::fmod(-static_cast<double>(b), static_cast<double>(a)) != DType(0)
                      ? a : DType(0)));
      } else {
        // Both non-negative: plain truncating fmod already agrees with
        // floor-division semantics.
        return DType(::fmod(static_cast<double>(b), static_cast<double>(a)));
      }
    }
  }
  // Unsigned overload: no negative operands, so no sign fix-up needed.
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type
  Map(DType a, DType b) {
    if (a == DType(0)) {
      return DType(0);
    } else {
      return DType(::fmod(static_cast<double>(b), static_cast<double>(a)));
    }
  }
};
978 
// Vectorized half2 specialization: elementwise b % a.
template<>
MSHADOW_XINLINE mshadow::half::half2_t rmod::Map<mshadow::half::half2_t>
                                                (mshadow::half::half2_t a,
                                                 mshadow::half::half2_t b) {
  return b%a;
}
985 
// Gradient of rmod (b mod a) w.r.t. the divisor a. The generic template
// returns 0 (integral types); floating-point specializations return
// -floor(b/a), mirroring mod_rgrad with the operands swapped.
struct rmod_grad {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return DType(0);
  }
};
template<>
MSHADOW_XINLINE double rmod_grad::Map<double>(double a, double b) {
  return -::floor(b/a);
}
template<>
MSHADOW_XINLINE float rmod_grad::Map<float>(float a, float b) {
  return -::floorf(b/a);
}

// Half precision: compute the floor in float, then narrow back.
template<>
MSHADOW_XINLINE mshadow::half::half_t rmod_grad::Map<mshadow::half::half_t>
                                                   (mshadow::half::half_t a,
                                                    mshadow::half::half_t b) {
  return mshadow::half::half_t(-::floorf(static_cast<float>(b/a)));
}
template<>
MSHADOW_XINLINE mshadow::half::half2_t rmod_grad::Map<mshadow::half::half2_t>
                                                     (mshadow::half::half2_t a,
                                                      mshadow::half::half2_t b) {
#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2)
  // Native half2 path: -floor(b/a) per lane via device intrinsics.
  return mshadow::half::half2_t(::__hneg2(::h2floor((b/a).half2_)));
#else
  // Fallback: per-lane computation in float.
  return mshadow::half::half2_t(mshadow::half::half_t(-::floorf(
                                  static_cast<float>(b.half_t2[0]/a.half_t2[0]))),
                                mshadow::half::half_t(-::floorf(
                                  static_cast<float>(b.half_t2[1]/a.half_t2[1]))));
#endif
}
1020 
1021 struct clip : public mxnet_op::tunable {
1022   template<typename DType>
Mapclip1023   MSHADOW_XINLINE static DType Map(DType x, DType bound) {
1024     if (x > bound) {
1025       return bound;
1026     } else if (x < -bound) {
1027       return -bound;
1028     } else {
1029       return x;
1030     }
1031   }
1032   template<typename DType>
Mapclip1033   MSHADOW_XINLINE static DType Map(DType x, DType lower_bound, DType upper_bound) {
1034     if (x > upper_bound) {
1035       return upper_bound;
1036     } else if (x < lower_bound) {
1037       return lower_bound;
1038     }
1039     return x;
1040   }
1041 };
1042 
/***** gamma ******/

// gamma(a): elementwise Gamma function via tgamma.
MXNET_UNARY_MATH_OP(gamma, math::tgamma(a));

// d/da tgamma(a) = Gamma(a) * psi(a) (digamma). The generic template
// computes in float; the double specialization keeps full precision.
struct gamma_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    // default implementation using floating precision
    float af(static_cast<float>(a));
    return DType(math::tgamma(af) * special_functions::cephes::psi<float>(af));
  }
};

template<>
MSHADOW_XINLINE double gamma_grad::Map<double>(double a) {
  return math::tgamma(a) * special_functions::cephes::psi<double>(a);
}
1060 
/***** gammaln ******/

// gammaln(a): elementwise log-Gamma via lgamma.
MXNET_UNARY_MATH_OP(gammaln, math::lgamma(a));

// d/da lgamma(a) = psi(a) (digamma). Generic template computes in
// float; the double specialization keeps full precision.
struct gammaln_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    // default implementation using floating precision
    return DType(special_functions::cephes::psi<float>(a));
  }
};

template<>
MSHADOW_XINLINE double gammaln_grad::Map<double>(double a) {
  return special_functions::cephes::psi<double>(a);
}
1077 
1078 /* Smooth L1 Loss is a loss specific for R-CNN franchise training
1079  * Smooth L1 Loss function:
1080  * f(x) = 0.5 * (sigma * x) ^ 2,     |x| < 1 / sigma^2
1081  *      = |x| - 0.5 / sigma / sigma, otherwise
1082  * When sigma = 1, it is equivalent to the Huber loss, evaluated at
1083  * delta = 1.
1084  * smooth_l1_loss = w_out * f(w_in * x)
1085  * with w_in, w_out provided by input_data.
1086  */
1087 struct smooth_l1_loss : public mxnet_op::tunable {
1088   // a is x, b is sigma
1089   template<typename DType>
Mapsmooth_l1_loss1090   MSHADOW_XINLINE static DType Map(DType a, DType b) {
1091     auto bsq = math::sqr(b);
1092     auto ibsq = 1.0f / bsq;
1093     auto af = math::id(a);
1094     if (af > ibsq) {
1095       return DType(af - 0.5f * ibsq);
1096     } else if (af < -ibsq) {
1097       return DType(-af - 0.5f * ibsq);
1098     } else {
1099       return DType(0.5f * af * af * bsq);
1100     }
1101   }
1102 };  // struct smooth_l1_loss
1103 
1104 /* The derivative of smooth l1 loss is
1105  * f'(x) = sigma^2 * x, |x| < 1 / sigma^2
1106  *       = sign(x),     otherwise
1107  */
1108 struct smooth_l1_gradient : public mxnet_op::tunable {
1109   // a is x, b is sigma2
1110   template<typename DType>
Mapsmooth_l1_gradient1111   MSHADOW_XINLINE static DType Map(DType a, DType b) {
1112     auto bsq = math::sqr(b);
1113     auto ibsq = 1.0f / bsq;
1114     auto af = math::id(a);
1115     if (af > ibsq) {
1116       return DType(1);
1117     } else if (af < -ibsq) {
1118       return DType(-1);
1119     } else {
1120       return DType(bsq * af);
1121     }
1122   }
1123 };  // struct smooth_l1_derivative
1124 
/*! \brief product reducer */
struct product {
  /*! \brief do reduction into dst */
  template<typename DType>
  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
    dst *= src;
  }
  /*! \brief do reduction into dst (the extra accumulator is unused) */
  template<typename DType>
  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*)
    Reduce(dst, src);
  }
  /*! \brief combine the results of two reducers */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
    Reduce(dst_val, src_val);
  }
  /*! \brief combine the results of two reducers (residuals unused) */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
    Reduce(dst_val, src_val);
  }
  /*! \brief finalize reduction (product needs no post-processing) */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
  /*! \brief finalize reduction */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
  /*!
  *\brief calculate gradient of redres with respect to redsrc,
  * redres: reduced result, redsrc: one of reduction element
  * (d(prod)/d(x_i) = prod / x_i)
  */
  template<typename DType>
  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
    return redres / redsrc;
  }
  /*!
  *\brief set the initial value during reduction (multiplicative identity)
  */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
    initv = 1;
  }
  /*!
  *\brief set the initial value during reduction
  */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
    SetInitValue(initv);
  }
};
1176 
// relu(a): pass-through for a > 0 and for NaN (NaN propagates), else 0.
MXNET_UNARY_MATH_OP_NC(relu, IsNan(a) || (a > DType(0)) ? a : DType(0));
1178 
1179 /*! \brief used for computing gradient of relu operator */
1180 struct relu_grad : public mxnet_op::tunable {
1181   template<typename DType>
Maprelu_grad1182   MSHADOW_XINLINE static DType Map(DType a) {
1183     if (IsNan(a)) {
1184       return a;
1185     } else {
1186       return a > DType(0) ? DType(1) : DType(0);
1187     }
1188   }
1189 };
1190 
1191 /*! \brief used for computing binary operator maximum */
1192 struct maximum : public mxnet_op::tunable {
1193   template<typename DType>
Mapmaximum1194   MSHADOW_XINLINE static DType Map(DType a, DType b) {
1195     if (IsNan(a)) {
1196       return a;
1197     } else {
1198       return (a > b ? a : b);
1199     }
1200   }
1201 };
1202 
1203 /*! \brief used for computing binary operator minimum */
1204 struct minimum : public mxnet_op::tunable {
1205   template<typename DType>
Mapminimum1206   MSHADOW_XINLINE static DType Map(DType a, DType b) {
1207     if (IsNan(a)) {
1208       return a;
1209     } else {
1210       return DType(a < b ? a : b);
1211     }
1212   }
1213 };
1214 
1215 /*! \brief boolean any/all kernel that determines whether elem is NonZero */
1216 struct NonZero {
1217   template<typename DType>
MapNonZero1218   MSHADOW_XINLINE static bool Map(DType a) {
1219     return (a != DType(0));
1220   }
1221 };
1222 
1223 /*! \brief sum reducer that ignores NaN values in the input */
1224 struct nansum {
1225   /*! \brief do reduction into dst */
1226   template<typename DType>
Reducenansum1227   MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
1228     if (IsNan(src)) return;
1229     dst += src;
1230   }
1231   /*! \brief do reduction into dst */
1232   template<typename DType>
Reducenansum1233   MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
1234     if (IsNan(src)) return;
1235     DType y = src - residual;
1236     DType t = dst + y;
1237     residual = (t - dst) - y;
1238     dst = t;
1239   }
1240   /*! \brief combine the results of two reducers */
1241   template<typename DType>
Mergenansum1242   MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
1243     Reduce(dst_val, src_val);
1244   }
1245   /*! \brief combine the results of two reducers */
1246   template<typename DType>
Mergenansum1247   MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
1248     DType t1 = dst_val + src_val;
1249     DType e = t1 - src_val;
1250     DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
1251     dst_val = t1 + t2;
1252     dst_residual = t2 - (dst_val - t1);
1253   }
1254   /*! \brief finalize reduction */
1255   template<typename DType>
Finalizenansum1256   MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
1257   /*! \brief finalize reduction */
1258   template<typename DType>
Finalizenansum1259   MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
1260   /*!
1261   *\brief set the initial value during reduction
1262   */
1263   template<typename DType>
SetInitValuenansum1264   MSHADOW_XINLINE static void SetInitValue(DType & initv) { // NOLINT(*)
1265       initv = 0;
1266   }
1267   /*!
1268    *\brief set the initial value during reduction
1269    */
1270   template<typename DType>
SetInitValuenansum1271   MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*)
1272     SetInitValue(initv);
1273     residual = 0;
1274   }
1275 };
1276 
1277 struct nansum_grad : public mxnet_op::tunable {
1278   template<typename DType>
Mapnansum_grad1279   MSHADOW_XINLINE static DType Map(DType a, DType b) {
1280     return IsNan(a) ? DType(0) : DType(1);
1281   }
1282 };
1283 
/*! \brief product reducer that ignores NaN values in the input */
struct nanprod {
  /*! \brief do reduction into dst, skipping NaN inputs */
  template<typename DType>
  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
    if (IsNan(src)) return;
    dst *= src;
  }
  /*! \brief do reduction into dst (the extra accumulator is unused) */
  template<typename DType>
  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*)
    Reduce(dst, src);
  }
  /*! \brief combine the results of two reducers */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
    Reduce(dst_val, src_val);
  }
  /*! \brief combine the results of two reducers (residuals unused) */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
    Reduce(dst_val, src_val);
  }
  /*! \brief finalize reduction (product needs no post-processing) */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
  /*! \brief finalize reduction */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
  /*!
  *\brief set the initial value during reduction (multiplicative identity)
  */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType & initv) { // NOLINT(*)
    initv = 1;
  }

  /*!
  *\brief set the initial value during reduction
  */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
    SetInitValue(initv);
  }
};
1329 
/*! \brief compute l2 norm */
struct nrm2 {
  /*! \brief do reduction into dst (naive sum of squares; Finalize takes sqrt) */
  template<typename AType, typename DType>
  MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, volatile DType src) { // NOLINT(*)
    sum_of_squares += src * src;
  }
  /*! \brief do stable reduction into dst: maintains the invariant
   *  norm^2 == scale^2 * sum_of_squares, where scale tracks the largest
   *  |src| seen so far, avoiding overflow/underflow of the squares */
  template<typename AType, typename DType>
  MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares,  volatile DType src, volatile DType& scale) { // NOLINT(*)
    if (src != 0) {
      DType abs = mshadow_op::abs::Map(src);
      if (scale < abs) {
        // New maximum: rescale the accumulated sum to the new scale.
        sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs);
        scale = abs;
      } else {
        sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale);
      }
    }
  }
  /*! \brief combine the results of two reducers (unscaled variant) */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
    dst_val += src_val;
  }
  /*! \brief combine the results of two reducers, rescaling the smaller-scaled
   *  partial sum onto the larger scale */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale, volatile DType& src_ssq, volatile DType& src_scale) { // NOLINT(*)
    if (dst_scale != 0 && dst_scale >= src_scale) {
      dst_ssq = dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale);
    } else if (src_scale != 0 && dst_scale < src_scale) {
      dst_ssq = src_ssq + dst_ssq * (dst_scale / src_scale) * (dst_scale / src_scale);
      dst_scale = src_scale;
    }
  }
  /*! \brief finalize reduction result: norm = sqrt(sum of squares) */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares) { // NOLINT(*)
    sum_of_squares = math::sqrt(sum_of_squares);
  }
  /*! \brief finalize reduction result: norm = scale * sqrt(scaled ssq) */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares, volatile DType& scale) { // NOLINT(*)
    sum_of_squares = scale * math::sqrt(sum_of_squares);
  }
  /*!
   *\brief calculate gradient of redres with respect to redsrc,
   * redres: reduced result, redsrc: one of reduction element
   * (d||x||/dx_i = x_i / ||x||)
   */
  template<typename DType>
  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
    return redsrc / redres;
  }
  /*!
   *\brief set the initial value during reduction
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares) { // NOLINT(*)
    sum_of_squares = 0;
  }
  /*!
   *\brief set the initial value during reduction (scale starts at 0)
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares, DType &scale) { // NOLINT(*)
    SetInitValue(sum_of_squares);
    scale = 0;
  }
};
1399 
/*! \brief sum reducer */
struct sum {
  /*! \brief do reduction into dst */
  template<typename AType, typename DType>
  MSHADOW_XINLINE static void Reduce(volatile AType& dst,  volatile DType src) { // NOLINT(*)
    dst += src;
  }
  /*! \brief do stable reduction into dst (Kahan compensated summation;
   *  residual carries the running rounding error) */
  template<typename AType, typename DType>
  MSHADOW_XINLINE static void Reduce(volatile AType& dst,  volatile DType src, volatile DType& residual) { // NOLINT(*)
    DType y = src - residual;
    DType t = dst + y;
    residual = (t - dst) - y;
    dst = t;
  }
  /*! \brief combine the results of two reducers */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
    Reduce(dst_val, src_val);
  }
  /*! \brief combine the results of two reducers using two-sum, carrying
   *  both Kahan residuals into the merged residual */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
    DType t1 = dst_val + src_val;
    DType e = t1 - dst_val;
    DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
    dst_val = t1 + t2;
    dst_residual = t2 - (dst_val - t1);
  }
  /*! \brief finalize reduction (no post-processing needed) */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
  /*! \brief finalize reduction */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
  /*!
   *\brief calculate gradient of redres with respect to redsrc,
   * redres: reduced result, redsrc: one of reduction element
   * (the gradient of a sum w.r.t. any element is 1)
   */
  template<typename DType>
  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
    return 1;
  }
  /*!
   *\brief set the initial value during reduction
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
    initv = 0;
  }
  /*!
   *\brief set the initial value during reduction (residual starts at 0)
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*)
    SetInitValue(initv);
    residual = 0;
  }
};
1459 
1460 struct nanprod_grad : public mxnet_op::tunable {
1461   template<typename DType>
Mapnanprod_grad1462   MSHADOW_XINLINE static DType Map(DType a, DType b) {
1463     return IsNan(a) ? DType(0) : b / a;
1464   }
1465 };
1466 
1467 /*! \brief used for computing binary lowest common multiple */
1468 struct lcm : public mxnet_op::tunable {
1469   template<typename DType>
1470   MSHADOW_XINLINE static typename enable_if<is_integral<DType>::value, DType>::type
Maplcm1471   Map(DType a, DType b) {
1472     // minus cases.
1473     if (a < 0) {
1474       a = -a;
1475     }
1476     if (b < 0) {
1477       b = -b;
1478     }
1479     // handle zero-valued cases.
1480     DType c;
1481     if (a == 0 || b == 0) {
1482       c = 0;
1483     } else {
1484       DType tmp;
1485       DType tmp_a = a;
1486       DType tmp_b = b;
1487       if (a < b) {
1488         tmp = a;
1489         a = b;
1490         b = tmp;
1491       }
1492       while (a % b != 0) {
1493         a = a % b;
1494         tmp = a;
1495         a = b;
1496         b = tmp;
1497       }
1498       c = tmp_a / b * tmp_b;
1499     }
1500     return c;
1501   }
1502   template<typename DType>
1503   MSHADOW_XINLINE static typename enable_if<!is_integral<DType>::value, DType>::type
Maplcm1504   Map(DType a, DType b) {
1505     return DType(0.0f);
1506   }
1507 };
1508 
1509 }  // namespace mshadow_op
1510 }  // namespace op
1511 }  // namespace mxnet
1512 #endif  // MXNET_OPERATOR_MSHADOW_OP_H_
1513