/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file mshadow_op.h
 * \brief elementwise unary/binary operators and reducers used by MXNet operators
 * \author Bing Xu
 */
#ifndef MXNET_OPERATOR_MSHADOW_OP_H_
#define MXNET_OPERATOR_MSHADOW_OP_H_

#include <mxnet/base.h>
#include <mshadow/base.h>
#include "math.h"
#include "math_functions-inl.h"
#include "special_functions-inl.h"
#include "./operator_tune.h"
#include "./contrib/erfinv-inl.h"

#ifdef __CUDACC__
#include <cuda_fp16.h>
#endif

namespace mxnet {
namespace op {
namespace mshadow_op {

using mshadow::isnan_typed::IsNan;
using mshadow::isinf_typed::IsInf;

#ifdef __CUDA_ARCH__
__constant__ const float PI = 3.14159265358979323846;
__constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
__constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
__constant__ const float SQRT_2 = 1.4142135623730950488016887242096;
#else
const float PI = 3.14159265358979323846;
const float SELU_ALPHA = 1.6732632423543772848170429916717;
const float SELU_LAMBDA = 1.0507009873554804934193349852946;
const float SQRT_2 = 1.4142135623730950488016887242096;
#endif
using std::enable_if;
using std::is_unsigned;
using std::is_integral;

#define MXNET_UNARY_MATH_OP(name, expr) \
  struct name : public mxnet_op::tunable { \
    template<typename DType> \
    MSHADOW_XINLINE static DType Map(DType a) { \
      return DType(expr); \
    } \
  }

#define MXNET_UNARY_MATH_OP_NC(name, expr) \
  struct name : public mxnet_op::tunable { \
    template<typename DType> \
    MSHADOW_XINLINE static DType Map(DType a) { \
      return (expr); \
    } \
  }

#define MXNET_UNARY_LOGIC_OP_NC(name, expr) \
  struct name : public mxnet_op::tunable { \
    template<typename DType> \
    MSHADOW_XINLINE static bool Map(DType a) { \
      return (expr); \
    } \
  }

#define MXNET_BINARY_MATH_OP(name, expr) \
  struct name : public mxnet_op::tunable { \
    template<typename DType> \
    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
      return DType(expr); \
    } \
  }

#define MXNET_BINARY_MATH_OP_NC(name, expr) \
  struct name : public mxnet_op::tunable { \
    template<typename DType> \
    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
      return (expr); \
    } \
  }

#define MXNET_BINARY_MATH_OP_NC_WITH_BOOL(name, expr) \
  struct name : public mxnet_op::tunable { \
    template<typename DType, \
             typename std::enable_if<!std::is_same<DType, bool>::value, int>::type = 0> \
    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
      return (expr); \
    } \
    MSHADOW_XINLINE static bool Map(bool a, bool b) { \
      return (expr); \
    } \
  }
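
// A note on the macro families above (editorial sketch, not part of the
// upstream header): MXNET_*_MATH_OP wraps the expression in DType(...) so
// that float-valued math such as math::sqr() is cast back to the tensor
// dtype, while the *_NC ("no cast") variants return the expression as-is and
// the *_LOGIC_* variants return bool. For instance,
//
//   MXNET_UNARY_MATH_OP(square, math::sqr(a));
//
// expands to roughly:
//
//   struct square : public mxnet_op::tunable {
//     template<typename DType>
//     MSHADOW_XINLINE static DType Map(DType a) {
//       return DType(math::sqr(a));
//     }
//   };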
#define MXNET_BINARY_LOGIC_OP_NC(name, expr) \
  struct name : public mxnet_op::tunable { \
    template<typename DType> \
    MSHADOW_XINLINE static bool Map(DType a, DType b) { \
      return (expr); \
    } \
  }

#define MXNET_SIMPLE_UNARY_MATH_OP(name) MXNET_UNARY_MATH_OP(name, math::name(a))

#define MXNET_SIMPLE_BINARY_MATH_OP(name) MXNET_BINARY_MATH_OP(name, math::name(a, b))

MXNET_UNARY_MATH_OP_NC(identity, a);

MXNET_UNARY_MATH_OP(identity_grad, 1);

struct identity_with_cast {
  template<typename DTypeIn, typename DTypeOut>
  MSHADOW_XINLINE static void Map(index_t i, DTypeOut *out, DTypeIn *in) {
    out[i] = DTypeOut(in[i]);
  }
};

struct true_divide : public mxnet_op::tunable {
  template<typename DType,
           typename std::enable_if<!std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return a / b;
  }

  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, DType b) {
    return static_cast<float>(a) / static_cast<float>(b);
  }

  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return static_cast<mshadow::half::half_t>(a) / b;
  }

  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return static_cast<float>(a) / b;
  }

  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return static_cast<double>(a) / b;
  }
};

struct rtrue_divide : public mxnet_op::tunable {
  template<typename DType,
           typename std::enable_if<!std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return b / a;
  }

  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, DType b) {
    return static_cast<float>(b) / static_cast<float>(a);
  }

  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return b / static_cast<mshadow::half::half_t>(a);
  }

  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return b / static_cast<float>(a);
  }

  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return b / static_cast<double>(a);
  }
};

MXNET_BINARY_MATH_OP_NC(left, a);

MXNET_BINARY_MATH_OP_NC(right, b);
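
// Usage sketch (editorial, not part of the upstream header): true_divide
// promotes integral inputs to floating point, mirroring Python's `/`:
//
//   float f  = true_divide::Map(3, 4);        // int / int    -> 0.75f
//   double d = true_divide::Map(3, 2.0);      // int / double -> 1.5
//   float g  = true_divide::Map(1.0f, 4.0f);  // float stays float -> 0.25f
//
// The enable_if guards select the overload by whether DType is integral.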
struct mixed_plus {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return static_cast<mshadow::half::half_t>(a) + b;
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return static_cast<float>(a) + b;
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return static_cast<double>(a) + b;
  }
};

struct mixed_minus {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return static_cast<mshadow::half::half_t>(a) - b;
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return static_cast<float>(a) - b;
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return static_cast<double>(a) - b;
  }
};

struct mixed_rminus {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return b - static_cast<mshadow::half::half_t>(a);
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return b - static_cast<float>(a);
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return b - static_cast<double>(a);
  }
};

struct mixed_mul {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return static_cast<mshadow::half::half_t>(a) * b;
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return static_cast<float>(a) * b;
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return static_cast<double>(a) * b;
  }
};
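
// The mixed_* structs (plus/minus/rminus/mul here, plus the power and mod
// variants below) implement dtype promotion for binary ops whose operands
// differ: integral op half_t -> half_t; {integral, half_t} op float -> float;
// {integral, half_t, float} op double -> double. A hedged illustration:
//
//   float r  = mixed_plus::Map(2, 1.5f);   // int + float    -> 3.5f
//   double s = mixed_mul::Map(2.0f, 0.5);  // float * double -> 1.0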
struct mixed_power {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return static_cast<mshadow::half::half_t>(math::pow(a, b));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return static_cast<float>(math::pow(a, b));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return static_cast<double>(math::pow(a, b));
  }
};

struct mixed_rpower {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return static_cast<mshadow::half::half_t>(math::pow(b, a));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return static_cast<float>(math::pow(b, a));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return static_cast<double>(math::pow(b, a));
  }
};

#pragma GCC diagnostic push
#if __GNUC__ >= 7
#pragma GCC diagnostic ignored "-Wint-in-bool-context"
#pragma GCC diagnostic ignored "-Wbool-compare"
#endif
MXNET_BINARY_MATH_OP_NC_WITH_BOOL(mul, a * b);

MXNET_BINARY_MATH_OP_NC_WITH_BOOL(div, a / b);

MXNET_BINARY_MATH_OP_NC_WITH_BOOL(plus, a + b);

MXNET_BINARY_MATH_OP_NC_WITH_BOOL(minus, a - b);

MXNET_UNARY_MATH_OP(negation, -a);

MXNET_UNARY_MATH_OP(reciprocal, 1.0f / math::id(a));

struct bitwise_not : public mxnet_op::tunable {
  template<typename DType,
           typename std::enable_if<!std::is_same<DType, bool>::value, int>::type = 0>
  MSHADOW_XINLINE static DType Map(DType a) {
    return ~static_cast<int64_t>(a);
  }

  MSHADOW_XINLINE static bool Map(bool a) {
    return !a;
  }
};

MXNET_UNARY_MATH_OP(reciprocal_grad, -1.0f / math::sqr(a));

MXNET_UNARY_MATH_OP(sigmoid, 1.0f / (1.0f + math::exp(-a)));

MXNET_UNARY_MATH_OP(sigmoid_grad, math::id(a) * (1.0f - math::id(a)));

MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));

MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));

MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
                       (a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));

MXNET_UNARY_MATH_OP_NC(selu_grad,
                       DType(SELU_LAMBDA) * (a > DType(0) ? DType(1) : DType(SELU_ALPHA + a)));

MXNET_BINARY_MATH_OP_NC(prelu_grad, a > DType(0) ? DType(0) : a);

MXNET_BINARY_MATH_OP_NC(xelu, a > DType(0) ? a :
                        DType(static_cast<float>(a) * static_cast<float>(b)));

MXNET_BINARY_MATH_OP_NC(xelu_grad, a > DType(0) ? DType(1) : b);
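
// Editorial note: xelu(a, b) = (a > 0 ? a : a * b) generalizes leaky ReLU,
// with b supplying the negative-region slope (as in PReLU). Also note that
// several gradients here are written in terms of the forward *output* rather
// than the input: sigmoid_grad(y) = y * (1 - y) and tanh_grad(y) = 1 - y^2,
// where y = sigmoid(x) and y = tanh(x) respectively.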
MXNET_BINARY_MATH_OP_NC(elu, a > DType(0) ? a :
                        DType(math::id(b) * math::expm1(a)));

MXNET_BINARY_MATH_OP_NC(elu_grad, a > DType(0) ? DType(1) : DType(b + a));

MXNET_SIMPLE_UNARY_MATH_OP(tanh);

MXNET_UNARY_MATH_OP(tanh_grad, 1.0f - math::sqr(a));

/*! \brief SoftReLU, also known as softplus activation */
struct softrelu : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    // Avoid overflow of exp for large inputs.
    // The threshold 20.0 is chosen such that softrelu(a) == a
    // for a > 20 in single precision.
    if (a > DType(20.0f)) {
      return a;
    } else {
      return DType(math::log1p(math::exp(a)));
    }
  }
};

// With y = softrelu(x), dy/dx = sigmoid(x) = 1 - exp(-y) = -expm1(-y);
// the gradient is expressed in terms of the forward output.
MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a));

MXNET_UNARY_MATH_OP(erfinv_grad, 0.5 * math::sqrt(PI) * math::exp(math::sqr(a)));

MXNET_UNARY_MATH_OP(erf_grad, 2.0 / math::sqrt(PI) * math::exp(-(a * a)));

MXNET_SIMPLE_UNARY_MATH_OP(erf);

MXNET_UNARY_MATH_OP(gelu,
  DType(0.5f * static_cast<float>(a) * (1.0f + math::erf(static_cast<float>(a) / SQRT_2))));

MXNET_BINARY_MATH_OP_NC(gelu_grad,
  DType(0.5f * (1.0f + math::erf(static_cast<float>(a) / SQRT_2) +
                static_cast<float>(a) * erf_grad::Map(static_cast<float>(a) / SQRT_2) / SQRT_2)));
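
// Derivation sketch for the two ops above: gelu(a) = 0.5 * a * (1 + erf(a / sqrt(2)))
// equals a * Phi(a), where Phi is the standard normal CDF. By the product rule,
//   d/da [a * Phi(a)] = Phi(a) + a * phi(a),
// with phi the standard normal PDF; gelu_grad encodes a * phi(a) through
// erf_grad(a / SQRT_2) / SQRT_2, since erf_grad(x) = 2/sqrt(pi) * exp(-x^2).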
MXNET_SIMPLE_UNARY_MATH_OP(exp);

MXNET_SIMPLE_UNARY_MATH_OP(expm1);

MXNET_SIMPLE_UNARY_MATH_OP(log);

MXNET_UNARY_MATH_OP(log_grad, 1.0f / math::id(a));

MXNET_SIMPLE_UNARY_MATH_OP(log10);

// Constant is 1 / log(10)
struct log10_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    return DType(0.4342944819f / static_cast<float>(a));
  }
};

template<>
MSHADOW_XINLINE double log10_grad::Map<double>(double a) {
  return 0.43429448190325182765 / a;
}

MXNET_SIMPLE_UNARY_MATH_OP(log2);

// Constant is 1 / log(2)
struct log2_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    return DType(1.442695041f / static_cast<float>(a));
  }
};

template<>
MSHADOW_XINLINE double log2_grad::Map<double>(double a) {
  return 1.44269504088896340737 / a;
}

MXNET_SIMPLE_UNARY_MATH_OP(sin);

MXNET_UNARY_MATH_OP(sin_grad, math::cos(a));

MXNET_SIMPLE_UNARY_MATH_OP(log1p);

MXNET_UNARY_MATH_OP(log1p_grad, 1.0f / (1.0f + math::id(a)));

MXNET_SIMPLE_UNARY_MATH_OP(cos);

MXNET_UNARY_MATH_OP(cos_grad, -math::sin(a));

MXNET_SIMPLE_UNARY_MATH_OP(tan);

MXNET_UNARY_MATH_OP(tan_grad, math::sqr(a) + 1.0f);

MXNET_UNARY_MATH_OP(arcsin, math::asin(a));

MXNET_UNARY_MATH_OP(arcsin_grad, 1.0f / math::sqrt(1.0f - math::sqr(a)));

MXNET_UNARY_MATH_OP(arccos, math::acos(a));

MXNET_UNARY_MATH_OP(arccos_grad, -1.0f / math::sqrt(1.0f - math::sqr(a)));

MXNET_UNARY_MATH_OP(arctan, math::atan(a));

MXNET_UNARY_MATH_OP(arctan_grad, 1.0f / (math::sqr(a) + 1.0f));

MXNET_SIMPLE_BINARY_MATH_OP(hypot);

MXNET_BINARY_MATH_OP(hypot_grad_left, math::id(a) / math::hypot(a, b));

MXNET_BINARY_MATH_OP(hypot_grad_right, math::id(b) / math::hypot(a, b));

MXNET_UNARY_MATH_OP(degrees, 180.0f / PI * math::id(a));

MXNET_UNARY_MATH_OP(degrees_grad, 180.0f / PI);

MXNET_UNARY_MATH_OP(radians, PI / 180.0f * math::id(a));

MXNET_UNARY_MATH_OP(radians_grad, PI / 180.0f);

MXNET_SIMPLE_UNARY_MATH_OP(sinh);

MXNET_UNARY_MATH_OP(sinh_grad, math::cosh(a));

MXNET_SIMPLE_UNARY_MATH_OP(cosh);

MXNET_UNARY_MATH_OP(cosh_grad, math::sinh(a));

MXNET_UNARY_MATH_OP(arcsinh, math::asinh(a));

MXNET_UNARY_MATH_OP(arcsinh_grad, 1.0f / math::hypot(a, DType(1)));

MXNET_UNARY_MATH_OP(arccosh, math::acosh(a));

MXNET_UNARY_MATH_OP(arccosh_grad, 1.0f / math::sqrt(math::sqr(a) - 1.0f));

MXNET_UNARY_MATH_OP(arctanh, math::atanh(a));

MXNET_UNARY_MATH_OP(arctanh_grad, 1.0f / (1.0f - math::sqr(a)));

MXNET_UNARY_MATH_OP(square, math::sqr(a));

MXNET_UNARY_MATH_OP(square_grad, 2.0f * math::id(a));

/*! \brief used to generate a Bernoulli mask */
MXNET_BINARY_MATH_OP_NC(threshold, a < b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(threshold_eq, a <= b ? DType(1) : DType(0));

/*! \brief elementwise absolute value */
MXNET_UNARY_MATH_OP(abs, math::fabs(a)); // NOLINT(*)

/*! \brief elementwise sign */
struct sign : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
  Map(DType a) {
    if (a < DType(0)) return DType(-DType(1));
    if (a > DType(0)) return DType(1);
    return DType(0);
  }
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type
  Map(DType a) {
    if (a > DType(0)) return DType(1);
    return DType(0);
  }
};

MXNET_UNARY_MATH_OP_NC(sign_grad, DType(0));

/*! \brief elementwise power */
MXNET_BINARY_MATH_OP(power, math::pow(a, b));

MXNET_BINARY_MATH_OP(power_grad, math::pow(a, b - DType(1)) * math::id(b));

MXNET_BINARY_MATH_OP(power_rgrad, math::pow(a, b) * math::log(a));

MXNET_BINARY_MATH_OP(rpower, math::pow(b, a));

MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b));

MXNET_BINARY_MATH_OP(arctan2, math::atan2(a, b));

MXNET_BINARY_MATH_OP(arctan2_grad, math::id(b) / (math::id(a * a + b * b)));

MXNET_BINARY_MATH_OP(arctan2_rgrad, -math::id(a) / (math::id(a * a + b * b)));

MXNET_BINARY_MATH_OP(rarctan2, math::atan2(b, a));

MXNET_BINARY_MATH_OP(rarctan2_grad, math::id(a) / (math::id(a * a + b * b)));

MXNET_UNARY_MATH_OP_NC(nt, a != DType(0) ? DType(0) : DType(1));

MXNET_UNARY_LOGIC_OP_NC(np_logical_not, !static_cast<bool>(a));

MXNET_BINARY_MATH_OP_NC(ge, a >= b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP_NC(gt, a > b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP_NC(lt, a < b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP_NC(le, a <= b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP_NC(eq, a == b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP_NC(ne, a != b ? DType(1) : DType(0));
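
// Editorial note: the np_* comparison ops below return a genuine bool (for
// numpy-compatible boolean outputs), whereas ge/gt/lt/le/eq/ne above encode
// the result as DType(1) / DType(0) in the operand dtype.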
MXNET_BINARY_LOGIC_OP_NC(np_greater_equal, a >= b);

MXNET_BINARY_LOGIC_OP_NC(np_greater, a > b);

MXNET_BINARY_LOGIC_OP_NC(np_less, a < b);

MXNET_BINARY_LOGIC_OP_NC(np_less_equal, a <= b);

MXNET_BINARY_LOGIC_OP_NC(np_equal, a == b);

MXNET_BINARY_LOGIC_OP_NC(np_not_equal, a != b);

MXNET_BINARY_MATH_OP(logical_and, a && b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP(logical_or, a || b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP(bitwise_and, static_cast<int64_t>(a) & static_cast<int64_t>(b));

MXNET_BINARY_MATH_OP(bitwise_xor, static_cast<int64_t>(a) ^ static_cast<int64_t>(b));

MXNET_BINARY_MATH_OP(bitwise_or, static_cast<int64_t>(a) | static_cast<int64_t>(b));

MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));

// gradient in terms of the forward output y = sqrt(x): dy/dx = 1 / (2y)
MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));

MXNET_UNARY_MATH_OP(reciprocal_square_root, 1.0f / math::sqrt(a));

MXNET_UNARY_MATH_OP(reciprocal_square_root_grad, -0.5f / (math::sqrt(a) * math::id(a)));

MXNET_UNARY_MATH_OP(cube_root, math::cbrt(a));

// gradient in terms of the forward output y = cbrt(x): dy/dx = 1 / (3y^2)
MXNET_UNARY_MATH_OP(cube_root_grad, 1.0f / (3.0f * math::sqr(a)));

MXNET_UNARY_MATH_OP(reciprocal_cube_root, 1.0f / math::cbrt(a));

MXNET_UNARY_MATH_OP(reciprocal_cube_root_grad, -1.0f / (3.0f * math::cbrt(a) * math::id(a)));

/*! \brief elementwise ldexp: a * 2^b */
MXNET_BINARY_MATH_OP(ldexp, math::id(a) * math::pow(2.0f, b));

MXNET_BINARY_MATH_OP(ldexp_grad, math::pow(2.0f, b));

MXNET_BINARY_MATH_OP(ldexp_rgrad, math::id(a) * math::pow(2.0f, b) * math::log(2.0f));

MXNET_BINARY_MATH_OP(rldexp, math::id(b) * math::pow(2.0f, a));  // swap a and b if a is scalar.

MXNET_BINARY_MATH_OP(rldexp_grad, math::id(b) * math::pow(2.0f, a) * math::log(2.0f));

/*! \brief elementwise round */
MXNET_SIMPLE_UNARY_MATH_OP(round);

/*! \brief elementwise ceil */
MXNET_SIMPLE_UNARY_MATH_OP(ceil);

/*! \brief elementwise floor */
MXNET_SIMPLE_UNARY_MATH_OP(floor);

/*! \brief used to round towards zero */
MXNET_SIMPLE_UNARY_MATH_OP(trunc);

/*! \brief used to round a number to the nearest integer */
struct rint : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    auto floor = math::floor(a);
    auto ceil = math::ceil(a);
    auto af = math::id(a);
    return DType((af - floor) <= (ceil - af) ? floor : ceil);
  }
};

/*! \brief used to round a number to the integer nearest to 0 */
struct fix : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    auto floor = math::floor(a);
    auto ceil = math::ceil(a);
    return DType((floor > 0 ? floor : -floor) < (ceil > 0 ? ceil : -ceil) ? floor : ceil);
  }
};
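
// Hedged examples for the two rounding ops above: rint resolves exact ties
// toward floor, so rint(3.5) -> 3 (std::round would give 4), while fix
// truncates toward zero, so fix(-2.7) -> -2 and fix(2.7) -> 2.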
/*! \brief used to determine whether a number is Not A Number */
struct isnan : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static bool Map(DType a) {
    return IsNan(a);
  }
};

/*! \brief used to determine whether a number is infinite */
struct isinf : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static bool Map(DType a) {
    return IsInf(a);
  }
};

/*! \brief used to determine whether a number is finite */
struct isfinite : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static bool Map(DType a) {
    return !IsNan(a) && !IsInf(a);
  }
};

/*! \brief used to determine whether a number is positive infinity */
struct isposinf : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static bool Map(DType a) {
    return IsInf(a) && a > 0;
  }
};

/*! \brief used to determine whether a number is negative infinity */
struct isneginf : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static bool Map(DType a) {
    return IsInf(a) && a < 0;
  }
};

/*! \brief used for generating the gradient of MAE loss */
MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? DType(1) : -DType(1));

MXNET_BINARY_MATH_OP(rminus, b - a);

MXNET_BINARY_MATH_OP_NC(posone, 1);

MXNET_BINARY_MATH_OP_NC(negone, -1);

MXNET_BINARY_MATH_OP(div_grad, 1.0f / math::id(b));

template<>
MSHADOW_XINLINE mshadow::half::half2_t div_grad::Map<mshadow::half::half2_t>
                                                    (mshadow::half::half2_t a,
                                                     mshadow::half::half2_t b) {
  return mshadow::half::half2_t(1) / b;
}

MXNET_BINARY_MATH_OP(div_rgrad, -math::id(a) / math::sqr(b));

template<>
MSHADOW_XINLINE mshadow::half::half2_t div_rgrad::Map<mshadow::half::half2_t>
                                                     (mshadow::half::half2_t a,
                                                      mshadow::half::half2_t b) {
  return -a / (b * b);
}

MXNET_BINARY_MATH_OP(rdiv, math::id(b) / math::id(a));

MXNET_BINARY_MATH_OP(rdiv_grad, -math::id(b) / math::sqr(a));

MXNET_BINARY_MATH_OP(copysign, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? a : -a);

MXNET_BINARY_MATH_OP(copysign_grad, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? 1 : -1);

MXNET_BINARY_MATH_OP(copysign_rgrad, 0);

MXNET_BINARY_MATH_OP(rcopysign, (b >= 0 && a >= 0) || (b < 0 && a < 0) ? b : -b);

MXNET_BINARY_MATH_OP(rcopysign_grad, 0);
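
// Editorial note: mod below follows Python/NumPy remainder semantics -- the
// result takes the sign of the divisor b -- built from C's fmod (whose result
// takes the sign of the dividend) plus a correction term. For example,
// mod(-5, 3) == 1 and mod(5, -3) == -1, whereas ::fmod(-5, 3) == -2.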
struct mod : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
  Map(DType a, DType b) {
    if (b == DType(0)) {
      return DType(0);
    } else if (b < DType(0)) {
      if (a < DType(0)) {
        return DType(-::fmod(-static_cast<double>(a), -static_cast<double>(b)));
      } else {
        return DType(::fmod(static_cast<double>(a), -static_cast<double>(b)) +
                     (::fmod(static_cast<double>(a), -static_cast<double>(b)) != DType(0)
                      ? b : DType(0)));
      }
    } else {
      if (a < DType(0)) {
        return DType(-::fmod(-static_cast<double>(a), static_cast<double>(b)) +
                     (::fmod(-static_cast<double>(a), static_cast<double>(b)) != DType(0)
                      ? b : DType(0)));
      } else {
        return DType(::fmod(static_cast<double>(a), static_cast<double>(b)));
      }
    }
  }
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type
  Map(DType a, DType b) {
    if (b == DType(0)) {
      return DType(0);
    } else {
      return DType(::fmod(static_cast<double>(a), static_cast<double>(b)));
    }
  }
};

struct mixed_mod {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return mod::Map(static_cast<mshadow::half::half_t>(a), b);
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return mod::Map(static_cast<float>(a), b);
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return mod::Map(static_cast<double>(a), b);
  }
};

struct mixed_rmod {
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
    return mod::Map(b, static_cast<mshadow::half::half_t>(a));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static float Map(DType a, float b) {
    return mod::Map(b, static_cast<float>(a));
  }

  template<typename DType,
           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
                                   std::is_same<DType, float>::value ||
                                   std::is_integral<DType>::value, int>::type = 0>
  MSHADOW_XINLINE static double Map(DType a, double b) {
    return mod::Map(b, static_cast<double>(a));
  }
};

struct fmod : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    if (b == DType(0)) {
      return DType(0);
    } else {
      return DType(::fmod(static_cast<double>(a), static_cast<double>(b)));
    }
  }
};

struct rfmod : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    if (a == DType(0)) {
      return DType(0);
    } else {
      return DType(::fmod(static_cast<double>(b), static_cast<double>(a)));
    }
  }
};

template<>
MSHADOW_XINLINE mshadow::half::half2_t mod::Map<mshadow::half::half2_t>
                                               (mshadow::half::half2_t a,
                                                mshadow::half::half2_t b) {
  return a % b;
}
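
// Derivation sketch for the gradients that follow: writing
// a mod b = a - b * floor(a / b), almost everywhere
//   d(a mod b)/da = 1   and   d(a mod b)/db = -floor(a / b),
// which is what mod_grad / mod_rgrad (and rmod_grad, with the roles of a and
// b swapped) return for floating-point types; the generic overloads return 0,
// covering the integral types where the result is piecewise constant.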
struct mod_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return DType(0);
  }
};
template<>
MSHADOW_XINLINE double mod_grad::Map<double>(double a, double b) {
  return 1.0;
}
template<>
MSHADOW_XINLINE float mod_grad::Map<float>(float a, float b) {
  return 1.0f;
}

template<>
MSHADOW_XINLINE mshadow::half::half_t mod_grad::Map<mshadow::half::half_t>
                                                    (mshadow::half::half_t a,
                                                     mshadow::half::half_t b) {
  return mshadow::half::half_t(1.0f);
}
template<>
MSHADOW_XINLINE mshadow::half::half2_t mod_grad::Map<mshadow::half::half2_t>
                                                     (mshadow::half::half2_t a,
                                                      mshadow::half::half2_t b) {
  mshadow::half::half2_t result = mshadow::half::half2_t();
#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2)
  result.half2_ = ::__float2half2_rn(1.0f);
#else
  // both lanes are 1, matching the scalar specializations and the CUDA path
  result.half_t2[0] = mshadow::half::half_t(1.0f);
  result.half_t2[1] = mshadow::half::half_t(1.0f);
#endif
  return result;
}

struct mod_rgrad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return DType(0);
  }
};
template<>
MSHADOW_XINLINE double mod_rgrad::Map<double>(double a, double b) {
  return -::floor(a/b);
}
template<>
MSHADOW_XINLINE float mod_rgrad::Map<float>(float a, float b) {
  return -::floorf(a/b);
}

template<>
MSHADOW_XINLINE mshadow::half::half_t mod_rgrad::Map<mshadow::half::half_t>
                                                     (mshadow::half::half_t a,
                                                      mshadow::half::half_t b) {
  return mshadow::half::half_t(-::floorf(static_cast<float>(a/b)));
}
template<>
MSHADOW_XINLINE mshadow::half::half2_t mod_rgrad::Map<mshadow::half::half2_t>
                                                      (mshadow::half::half2_t a,
                                                       mshadow::half::half2_t b) {
#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2)
  return mshadow::half::half2_t(__hneg2(::h2floor((a/b).half2_)));
#else
  return mshadow::half::half2_t(mshadow::half::half_t(-::floorf(
                                  static_cast<float>(a.half_t2[0]/b.half_t2[0]))),
                                mshadow::half::half_t(-::floorf(
                                  static_cast<float>(a.half_t2[1]/b.half_t2[1]))));
#endif
}
struct rmod : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
  Map(DType a, DType b) {
    if (a == DType(0)) {
      return DType(0);
    } else if (a < DType(0)) {
      if (b < DType(0)) {
        return DType(-::fmod(-static_cast<double>(b), -static_cast<double>(a)));
      } else {
        return DType(::fmod(static_cast<double>(b), -static_cast<double>(a)) +
                     (::fmod(static_cast<double>(b), -static_cast<double>(a)) != DType(0)
                      ? a : DType(0)));
      }
    } else {
      if (b < DType(0)) {
        return DType(-::fmod(-static_cast<double>(b), static_cast<double>(a)) +
                     (::fmod(-static_cast<double>(b), static_cast<double>(a)) != DType(0)
                      ? a : DType(0)));
      } else {
        return DType(::fmod(static_cast<double>(b), static_cast<double>(a)));
      }
    }
  }
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<is_unsigned<DType>::value, DType>::type
  Map(DType a, DType b) {
    if (a == DType(0)) {
      return DType(0);
    } else {
      return DType(::fmod(static_cast<double>(b), static_cast<double>(a)));
    }
  }
};

template<>
MSHADOW_XINLINE mshadow::half::half2_t rmod::Map<mshadow::half::half2_t>
                                                (mshadow::half::half2_t a,
                                                 mshadow::half::half2_t b) {
  return b % a;
}

struct rmod_grad {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return DType(0);
  }
};
template<>
MSHADOW_XINLINE double rmod_grad::Map<double>(double a, double b) {
  return -::floor(b/a);
}
template<>
MSHADOW_XINLINE float rmod_grad::Map<float>(float a, float b) {
  return -::floorf(b/a);
}

template<>
MSHADOW_XINLINE mshadow::half::half_t rmod_grad::Map<mshadow::half::half_t>
                                                     (mshadow::half::half_t a,
                                                      mshadow::half::half_t b) {
  return mshadow::half::half_t(-::floorf(static_cast<float>(b/a)));
}
template<>
MSHADOW_XINLINE mshadow::half::half2_t rmod_grad::Map<mshadow::half::half2_t>
                                                      (mshadow::half::half2_t a,
                                                       mshadow::half::half2_t b) {
#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2)
  return mshadow::half::half2_t(::__hneg2(::h2floor((b/a).half2_)));
#else
  return mshadow::half::half2_t(mshadow::half::half_t(-::floorf(
                                  static_cast<float>(b.half_t2[0]/a.half_t2[0]))),
                                mshadow::half::half_t(-::floorf(
                                  static_cast<float>(b.half_t2[1]/a.half_t2[1]))));
#endif
}

struct clip : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType x, DType bound) {
    if (x > bound) {
      return bound;
    } else if (x < -bound) {
      return -bound;
    } else {
      return x;
    }
  }
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType x, DType lower_bound, DType upper_bound) {
    if (x > upper_bound) {
      return upper_bound;
    } else if (x < lower_bound) {
      return lower_bound;
    }
    return x;
  }
};

/***** gamma ******/

MXNET_UNARY_MATH_OP(gamma, math::tgamma(a));

struct gamma_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    // default implementation in single precision
    float af(static_cast<float>(a));
    return DType(math::tgamma(af) * special_functions::cephes::psi<float>(af));
  }
};

template<>
MSHADOW_XINLINE double gamma_grad::Map<double>(double a) {
  return math::tgamma(a) * special_functions::cephes::psi<double>(a);
}

/***** gammaln ******/

MXNET_UNARY_MATH_OP(gammaln, math::lgamma(a));

struct gammaln_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    // default implementation in single precision
    return DType(special_functions::cephes::psi<float>(a));
  }
};

template<>
MSHADOW_XINLINE double gammaln_grad::Map<double>(double a) {
  return special_functions::cephes::psi<double>(a);
}
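
// Derivation sketch for the gamma gradients above: with psi the digamma
// function, d/da ln Gamma(a) = psi(a), hence gammaln_grad returns psi(a),
// and d/da Gamma(a) = Gamma(a) * psi(a), which is what gamma_grad computes.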
/* Smooth L1 Loss is a loss specific to the R-CNN family of detectors.
 * Smooth L1 Loss function:
 * f(x) = 0.5 * (sigma * x) ^ 2,  |x| < 1 / sigma^2
 *      = |x| - 0.5 / sigma^2,    otherwise
 * When sigma = 1, it is equivalent to the Huber loss evaluated at
 * delta = 1.
 * smooth_l1_loss = w_out * f(w_in * x)
 * with w_in, w_out provided by input_data.
 */
struct smooth_l1_loss : public mxnet_op::tunable {
  // a is x, b is sigma
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    auto bsq = math::sqr(b);
    auto ibsq = 1.0f / bsq;
    auto af = math::id(a);
    if (af > ibsq) {
      return DType(af - 0.5f * ibsq);
    } else if (af < -ibsq) {
      return DType(-af - 0.5f * ibsq);
    } else {
      return DType(0.5f * af * af * bsq);
    }
  }
};  // struct smooth_l1_loss

/* The derivative of smooth L1 loss is
 * f'(x) = sigma^2 * x,  |x| < 1 / sigma^2
 *       = sign(x),      otherwise
 */
struct smooth_l1_gradient : public mxnet_op::tunable {
  // a is x, b is sigma2
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    auto bsq = math::sqr(b);
    auto ibsq = 1.0f / bsq;
    auto af = math::id(a);
    if (af > ibsq) {
      return DType(1);
    } else if (af < -ibsq) {
      return DType(-1);
    } else {
      return DType(bsq * af);
    }
  }
};  // struct smooth_l1_gradient
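
// Worked example (editorial): with sigma = 1 the loss reduces to the Huber
// loss at delta = 1, so smooth_l1_loss::Map(0.5f, 1.0f) = 0.5 * 0.25 = 0.125
// (quadratic region, |x| < 1) and smooth_l1_loss::Map(2.0f, 1.0f) = 2 - 0.5
// = 1.5 (linear region); the gradient is x inside the quadratic region and
// sign(x) outside it.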
/*! \brief product reducer */
struct product {
  /*! \brief do reduction into dst */
  template<typename DType>
  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
    dst *= src;
  }
  /*! \brief do reduction into dst */
  template<typename DType>
  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*)
    Reduce(dst, src);
  }
  /*! \brief combine the results of two reducers */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
    Reduce(dst_val, src_val);
  }
  /*! \brief combine the results of two reducers */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
    Reduce(dst_val, src_val);
  }
  /*! \brief finalize reduction */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
  /*! \brief finalize reduction */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
  /*!
   * \brief calculate gradient of redres with respect to redsrc,
   * redres: reduced result, redsrc: one of reduction element
   */
  template<typename DType>
  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
    return redres / redsrc;
  }
  /*!
   * \brief set the initial value during reduction
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
    initv = 1;
  }
  /*!
   * \brief set the initial value during reduction
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
    SetInitValue(initv);
  }
};

MXNET_UNARY_MATH_OP_NC(relu, IsNan(a) || (a > DType(0)) ? a : DType(0));

/*! \brief used for computing gradient of relu operator */
struct relu_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    if (IsNan(a)) {
      return a;
    } else {
      return a > DType(0) ? DType(1) : DType(0);
    }
  }
};

/*! \brief used for computing binary operator maximum */
struct maximum : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    if (IsNan(a)) {
      return a;
    } else {
      return (a > b ? a : b);
    }
  }
};

/*! \brief used for computing binary operator minimum */
struct minimum : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    if (IsNan(a)) {
      return a;
    } else {
      return DType(a < b ? a : b);
    }
  }
};
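
// Editorial note on NaN handling above: relu, maximum, and minimum all
// propagate NaN. If a is NaN it is returned directly; if only b is NaN, the
// comparison a > b (or a < b) is false, so b -- the NaN -- is returned. This
// matches numpy.maximum / numpy.minimum semantics rather than fmax / fmin.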
/*! \brief boolean any/all kernel that determines whether elem is NonZero */
struct NonZero {
  template<typename DType>
  MSHADOW_XINLINE static bool Map(DType a) {
    return (a != DType(0));
  }
};

/*! \brief sum reducer that ignores NaN values in the input */
struct nansum {
  /*! \brief do reduction into dst */
  template<typename DType>
  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
    if (IsNan(src)) return;
    dst += src;
  }
  /*! \brief do stable reduction into dst (Kahan summation: residual carries
   *  the low-order bits lost when adding src to dst) */
  template<typename DType>
  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
    if (IsNan(src)) return;
    DType y = src - residual;
    DType t = dst + y;
    residual = (t - dst) - y;
    dst = t;
  }
  /*! \brief combine the results of two reducers */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
    Reduce(dst_val, src_val);
  }
  /*! \brief combine the results of two reducers (merging two compensated partial sums) */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
    DType t1 = dst_val + src_val;
    DType e = t1 - src_val;
    DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
    dst_val = t1 + t2;
    dst_residual = t2 - (dst_val - t1);
  }
  /*! \brief finalize reduction */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
  /*! \brief finalize reduction */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
  /*!
   * \brief set the initial value during reduction
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
    initv = 0;
  }
  /*!
   * \brief set the initial value during reduction
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*)
    SetInitValue(initv);
    residual = 0;
  }
};

struct nansum_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return IsNan(a) ? DType(0) : DType(1);
  }
};

/*! \brief product reducer that ignores NaN values in the input */
struct nanprod {
  /*! \brief do reduction into dst */
  template<typename DType>
  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
    if (IsNan(src)) return;
    dst *= src;
  }
  /*! \brief do reduction into dst */
  template<typename DType>
  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& none) { // NOLINT(*)
    Reduce(dst, src);
  }
  /*! \brief combine the results of two reducers */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
    Reduce(dst_val, src_val);
  }
  /*! \brief combine the results of two reducers */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
    Reduce(dst_val, src_val);
  }
  /*! \brief finalize reduction */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
  /*! \brief finalize reduction */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& none) {} // NOLINT(*)
  /*!
   * \brief set the initial value during reduction
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
    initv = 1;
  }

  /*!
   * \brief set the initial value during reduction
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
    SetInitValue(initv);
  }
};
/*! \brief compute l2 norm */
struct nrm2 {
  /*! \brief do reduction into dst */
  template<typename AType, typename DType>
  MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, volatile DType src) { // NOLINT(*)
    sum_of_squares += src * src;
  }
  /*! \brief do stable reduction into dst: maintain the invariant
   *  norm = scale * sqrt(sum_of_squares), rescaling whenever |src| exceeds scale */
  template<typename AType, typename DType>
  MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*)
    if (src != 0) {
      DType abs = mshadow_op::abs::Map(src);
      if (scale < abs) {
        sum_of_squares = 1 + sum_of_squares * (scale / abs) * (scale / abs);
        scale = abs;
      } else {
        sum_of_squares = sum_of_squares + (abs / scale) * (abs / scale);
      }
    }
  }
  /*! \brief combine the results of two reducers */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
    dst_val += src_val;
  }
  /*! \brief combine the results of two reducers */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_ssq, volatile DType& dst_scale, volatile DType& src_ssq, volatile DType& src_scale) { // NOLINT(*)
    if (dst_scale != 0 && dst_scale >= src_scale) {
      dst_ssq = dst_ssq + src_ssq * (src_scale / dst_scale) * (src_scale / dst_scale);
    } else if (src_scale != 0 && dst_scale < src_scale) {
      dst_ssq = src_ssq + dst_ssq * (dst_scale / src_scale) * (dst_scale / src_scale);
      dst_scale = src_scale;
    }
  }
  /*! \brief finalize reduction result */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares) { // NOLINT(*)
    sum_of_squares = math::sqrt(sum_of_squares);
  }
  /*! \brief finalize reduction result */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& sum_of_squares, volatile DType& scale) { // NOLINT(*)
    sum_of_squares = scale * math::sqrt(sum_of_squares);
  }
  /*!
   * \brief calculate gradient of redres with respect to redsrc,
   * redres: reduced result, redsrc: one of reduction element
   */
  template<typename DType>
  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
    return redsrc / redres;
  }
  /*!
   * \brief set the initial value during reduction
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares) { // NOLINT(*)
    sum_of_squares = 0;
  }
  /*!
   * \brief set the initial value during reduction
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &sum_of_squares, DType &scale) { // NOLINT(*)
    SetInitValue(sum_of_squares);
    scale = 0;
  }
};
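
// Editorial note: the two-argument Reduce above is the classic scaled
// sum-of-squares scheme from BLAS *nrm2. It keeps norm = scale *
// sqrt(sum_of_squares) as an invariant, rescaling whenever a new |src|
// exceeds the current scale, so the intermediate squares neither overflow
// nor underflow even when the final norm is representable.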
/*! \brief sum reducer */
struct sum {
  /*! \brief do reduction into dst */
  template<typename AType, typename DType>
  MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src) { // NOLINT(*)
    dst += src;
  }
  /*! \brief do stable reduction into dst (Kahan summation: residual carries
   *  the low-order bits lost when adding src to dst) */
  template<typename AType, typename DType>
  MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
    DType y = src - residual;
    DType t = dst + y;
    residual = (t - dst) - y;
    dst = t;
  }
  /*! \brief combine the results of two reducers */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
    Reduce(dst_val, src_val);
  }
  /*! \brief combine the results of two reducers (merging two compensated partial sums) */
  template<typename DType>
  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
    DType t1 = dst_val + src_val;
    DType e = t1 - dst_val;
    DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
    dst_val = t1 + t2;
    dst_residual = t2 - (dst_val - t1);
  }
  /*! \brief finalize reduction */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
  /*! \brief finalize reduction */
  template<typename DType>
  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
  /*!
   * \brief calculate gradient of redres with respect to redsrc,
   * redres: reduced result, redsrc: one of reduction element
   */
  template<typename DType>
  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
    return 1;
  }
  /*!
   * \brief set the initial value during reduction
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
    initv = 0;
  }
  /*!
   * \brief set the initial value during reduction
   */
  template<typename DType>
  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*)
    SetInitValue(initv);
    residual = 0;
  }
};

struct nanprod_grad : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return IsNan(a) ? DType(0) : b / a;
  }
};

/*! \brief used for computing the binary lowest common multiple */
struct lcm : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<is_integral<DType>::value, DType>::type
  Map(DType a, DType b) {
    // take absolute values of negative inputs.
    if (a < 0) {
      a = -a;
    }
    if (b < 0) {
      b = -b;
    }
    // handle zero-valued cases.
    DType c;
    if (a == 0 || b == 0) {
      c = 0;
    } else {
      DType tmp;
      DType tmp_a = a;
      DType tmp_b = b;
      if (a < b) {
        tmp = a;
        a = b;
        b = tmp;
      }
      // Euclid's algorithm: on exit, b holds gcd(|a|, |b|).
      while (a % b != 0) {
        a = a % b;
        tmp = a;
        a = b;
        b = tmp;
      }
      // lcm(a, b) = |a| / gcd(a, b) * |b|
      c = tmp_a / b * tmp_b;
    }
    return c;
  }
  template<typename DType>
  MSHADOW_XINLINE static typename enable_if<!is_integral<DType>::value, DType>::type
  Map(DType a, DType b) {
    return DType(0.0f);
  }
};

}  // namespace mshadow_op
}  // namespace op
}  // namespace mxnet
#endif  // MXNET_OPERATOR_MSHADOW_OP_H_