/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

/*
 Convert to metal by MNN.
 Copyright © 2018, Alibaba Group Holding Limited
 */

#include <metal_stdlib>
#include "MetalDefine.metal"

using namespace metal;

namespace MNN {
// Part 1: Low-level integer-arithmetic primitives.

// Traits describing the raw integer storage behind a fixed-point value.
// The primary template is deliberately empty: only the specializations below
// supply ScalarRawType (the scalar type of one lane) and kLanes (the number of
// lanes held by the raw type).
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

// 32-bit raw storage: one int32_t lane.
template <>
struct FixedPointRawTypeTraits<int32_t> {
    using ScalarRawType = int32_t;
    static constant int kLanes = 1;
};

// 16-bit raw storage: one int16_t lane.
template <>
struct FixedPointRawTypeTraits<int16_t> {
    using ScalarRawType = int16_t;
    static constant int kLanes = 1;
};

// Returns a SIMD value duplicating a scalar value across all lanes.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
    // The scalar is simply converted to the raw type; nothing else happens here.
    const tRawType duplicated = x;
    return duplicated;
}

// Plain bit-wise AND
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
    const tIntegerType result = a & b;
    return result;
}

// Plain bit-wise OR
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
    const tIntegerType result = a | b;
    return result;
}

// Plain bit-wise XOR
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
    const tIntegerType result = a ^ b;
    return result;
}

// Plain bit-wise NOT
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
    const tIntegerType result = ~a;
    return result;
}

// Integer addition. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
    const tIntegerType sum = a + b;
    return sum;
}

// Integer multiplication. Not saturating.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
    const tIntegerType product = a * b;
    return product;
}

// Integer subtraction. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
    const tIntegerType difference = a - b;
    return difference;
}

// Integer unary negative. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
    const tIntegerType negated = -a;
    return negated;
}

// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
// Not saturating. Negative inputs do not necessarily invoke undefined
// behaviour. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType ShiftLeft(tIntegerType a, int offset) {
    // Implemented as a multiplication by 2^offset rather than as `a << offset`.
    const auto power_of_two = static_cast<tIntegerType>(1) << offset;
    return a * power_of_two;
}

// Integer arithmetic right-shift. Not rounding.
// Relying on implementation-defined, but in-practice-consistent,
// C++ compiler behavior.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
    const tIntegerType shifted = a >> offset;
    return shifted;
}

// Each bit of the result is set to the corresponding bit of either then_val or
// else_val depending on whether the corresponding bit of if_mask is set.
// Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
    // The two masked halves never share a set bit, so XOR merges them.
    const tIntegerType from_then = BitAnd(if_mask, then_val);
    const tIntegerType from_else = BitAnd(BitNot(if_mask), else_val);
    return BitXor(from_then, from_else);
}

// For each input scalar, the corresponding bits of the result are set if the
// input scalar is non-zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
    constexpr tIntegerType kNoBits = 0;
    if (a) {
        // All bits set: the bit-wise complement of zero.
        return static_cast<tIntegerType>(~kNoBits);
    }
    return kNoBits;
}

// For each input scalar, the corresponding bits of the result are set if the
// input scalar is zero.
template <typename tIntegerType>
tIntegerType MaskIfZero(tIntegerType a) {
    const bool is_zero = !a;
    return MaskIfNonZero<tIntegerType>(is_zero);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are equal.
template <typename tIntegerType>
tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
    const bool holds = (a == b);
    return MaskIfNonZero<tIntegerType>(holds);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are not equal.
template <typename tIntegerType>
tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
    const bool holds = (a != b);
    return MaskIfNonZero<tIntegerType>(holds);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a > b.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
    const bool holds = (a > b);
    return MaskIfNonZero<tIntegerType>(holds);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a >= b.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
    const bool holds = (a >= b);
    return MaskIfNonZero<tIntegerType>(holds);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a < b.
template <typename tIntegerType>
tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
    const bool holds = (a < b);
    return MaskIfNonZero<tIntegerType>(holds);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a <= b.
template <typename tIntegerType>
tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
    const bool holds = (a <= b);
    return MaskIfNonZero<tIntegerType>(holds);
}

// Returns true if all of the input scalars are nonzero.
// This function may currently assume that each of the input scalars has either
// all or none of its bits set. Otherwise, its behavior is currently undefined.
template <typename tIntegerType>
bool All(tIntegerType a) {
    return a != 0;
}

// Returns true if any of the input scalars are nonzero.
// This function may currently assume that each of the input scalars has either
// all or none of its bits set. Otherwise, its behavior is currently undefined.
template <typename tIntegerType>
bool Any(tIntegerType a) {
    return a != 0;
}

// Returns (a+b)/2, rounded to the nearest integer.
// Equivalent to VRHADD in the ARM NEON instruction set.
// Fallback for raw types that have no specialization below. The static_assert
// that would reject such instantiations is commented out, so this placeholder
// just hands back `a`; only the int32_t / int16_t specializations compute a
// rounding half sum.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
//  static_assert(is_same<IntegerType, void>::value, "unimplemented");
    return a;
}

// 32-bit rounding half sum: (a + b + 1) >> 1.
// rhadd() is the rounding variant and is what the documented contract
// ("rounded to the nearest integer", VRHADD) asks for; hadd() computes
// (a + b) >> 1 and would drop the rounding bit. Neither lets the intermediate
// sum overflow.
template <>
inline int32_t RoundingHalfSum(int32_t a, int32_t b) {
    return rhadd(a, b);
}

// 16-bit rounding half sum: (a + b + 1) >> 1, see the 32-bit overload.
template <>
inline int16_t RoundingHalfSum(int16_t a, int16_t b) {
    return rhadd(a, b);
}

// Fallback for raw types that have no saturating-add specialization. The
// static_assert that would reject such instantiations is commented out, so
// this placeholder just hands back `a`.
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
//  static_assert(is_same<IntegerType, void>::value, "unimplemented");
    return a;
}

// So far this is only needed for int16.
// Adds in 32-bit precision, then clamps the sum to [-32768, 32767].
template <>
inline int16_t SaturatingAdd(int16_t a, int16_t b) {
    int32_t a32 = a;
    int32_t b32 = b;
    int32_t sum = a32 + b32;
    return static_cast<int16_t>(min(32767, max(-32768, sum)));
}

// Returns a+b, saturating if the integers are 16bit or narrower,
// otherwise just a plain addition.
// Primary template (Is16Bit == false): plain, non-saturating Add.
template <typename IntegerType, bool Is16Bit>
struct AddSaturatingIf16BitImpl {
    static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
};
// Is16Bit == true: route through SaturatingAdd.
template <typename IntegerType>
struct AddSaturatingIf16BitImpl<IntegerType, true> {
    static IntegerType Run(IntegerType a, IntegerType b) {
        return SaturatingAdd(a, b);
    }
};
// Picks the implementation at compile time: the saturating one is selected
// when the scalar type behind IntegerType is exactly 2 bytes wide.
template <typename IntegerType>
IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
    using ScalarType =
        typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
    return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a, b);
}

// Returns the product of a run-time integer value by a compile-time power
// of two, with either a positive exponent (equivalent to an arithmetic
// left shift, saturating) or a negative exponent (equivalent to an arithmetic
// right shift, rounding to nearest).
// Dispatches on the sign of Exponent (ExponentSign is 1, 0 or -1); the primary
// template is empty and only the three specializations provide eval().
template <int Exponent, typename IntegerType,
          int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
struct ImplSaturatingRoundingMultiplyByPOT {};

// Exponent == 0: the value is returned untouched.
template <int Exponent, typename IntegerType>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
    static IntegerType eval(IntegerType x) {
        return x;
    }
};

// Exponent > 0: left shift, saturating to the scalar type's min/max.
template <int Exponent, typename IntegerType>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
    static IntegerType eval(IntegerType x) {
        typedef typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType Scalar;
        const int kScalarBits = 8 * sizeof(Scalar);

        // Inputs above `limit` or below `-limit` are replaced by the
        // saturated bounds instead of the shifted value.
        const int32_t limit = ((1 << (kScalarBits - 1 - Exponent)) - 1);
        const IntegerType lowest = Dup<IntegerType>(num_limits<Scalar>::min());
        const IntegerType highest = Dup<IntegerType>(num_limits<Scalar>::max());
        const IntegerType too_large = MaskIfGreaterThan(x, Dup<IntegerType>(limit));
        const IntegerType too_small = MaskIfLessThan(x, Dup<IntegerType>(-limit));

        const IntegerType shifted = ShiftLeft(x, Exponent);
        const IntegerType capped_above = SelectUsingMask(too_large, highest, shifted);
        return SelectUsingMask(too_small, lowest, capped_above);
    }
};

// Exponent < 0: delegates to round_divide_by_pot with the positive shift
// amount (round_divide_by_pot is not defined in this file).
template <int Exponent, typename IntegerType>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
    static IntegerType eval(IntegerType x) {
        return round_divide_by_pot<IntegerType>(x, -Exponent);
    }
};

// Entry point: multiplies x by 2^Exponent using the specialization that
// matches the sign of Exponent.
template <int Exponent, typename IntegerType>
IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
    typedef ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType> Impl;
    return Impl::eval(x);
}

// Part 2: the FixedPoint class.
// Fixed-point number stored in a raw integer of type tRawType.
// One bit is the sign, tIntegerBits bits hold the integer part and the
// remaining kFractionalBits bits hold the fraction, so the raw value i_
// stands for i_ * 2^-kFractionalBits.
template <typename tRawType, int tIntegerBits>
class FixedPoint {
public:
    typedef tRawType RawType;

    typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
    typedef typename RawTypeTraits::ScalarRawType ScalarRawType;

    static constant int kTotalBits = 8 * sizeof(ScalarRawType);
    static constant int kIntegerBits = tIntegerBits;
    static constant int kFractionalBits = kTotalBits - 1 - kIntegerBits;
//  static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, "bad IntegerBits");

    typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;

    // Smallest / largest value representable by the scalar raw type.
    static const ScalarRawType ScalarRawMin() {
        return num_limits<ScalarRawType>::min();
    }

    static const ScalarRawType ScalarRawMax() {
        return num_limits<ScalarRawType>::max();
    }

    // Same bounds expressed as a RawType. These previously called
    // VectorFromScalar, which is not declared anywhere in this file; Dup is
    // the file's scalar-to-raw helper. For the raw types specialized in this
    // file RawType and ScalarRawType are the same type, so the return type is
    // unchanged for them.
    static const RawType RawMin() {
        return Dup<RawType>(ScalarRawMin());
    }

    static const RawType RawMax() {
        return Dup<RawType>(ScalarRawMax());
    }

    // Wraps an already-scaled raw value without any conversion.
    static FixedPoint FromRaw(RawType x) {
        FixedPoint retval;
        retval.raw() = x;
        return retval;
    }

    // Wraps a scalar raw value, duplicating it into the raw type via Dup.
    static FixedPoint FromScalarRaw(ScalarRawType x) {
        FixedPoint retval;
        retval.raw() = Dup<RawType>(x);
        return retval;
    }

    static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
        return FromScalarRaw(x.raw());
    }

    // Returns the fixed-point constant 2^Exponent, i.e. the raw value
    // 1 << (kFractionalBits + Exponent). The range check is commented out, so
    // callers must pick an Exponent that fits the format.
    template <int Exponent>
    static FixedPoint ConstantPOT() {
        constexpr int kOffset = kFractionalBits + Exponent;
//      static_assert(kOffset < 31, "Constant not exactly representable in this fixed-point format");
        return FromScalarRaw(ScalarRawType(1) << kOffset);
    }

    static FixedPoint Zero() { return FromScalarRaw(0); }

    // Returns 1.0. With zero integer bits 1.0 is not representable, so the
    // largest raw value is returned instead. The inner conditional keeps the
    // shift amount at 0 in that (unused) branch.
    static FixedPoint One() {
        return FromScalarRaw(kIntegerBits == 0
                                 ? ScalarRawMax()
                                 : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
    }

    // Read-only copy of, and mutable reference to, the raw storage.
    RawType raw() const { return i_; }
    thread RawType& raw() { return i_; }

private:
    RawType i_;
};

// Part 3: implementation of arithmetic operators for the
// FixedPoint class, and a few related functions.

// A FixedPoint multiplication is just a
// saturate_round_x2_high_mul operation on the underlying
// raw integer values. The IntegerBits simply add up, as is obvious
// from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
// (saturate_round_x2_high_mul is not defined in this file.)
template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
    FixedPoint<tRawType, tIntegerBits_a> a,
    FixedPoint<tRawType, tIntegerBits_b> b) {
    FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
    c.raw() = saturate_round_x2_high_mul(a.raw(), b.raw());
    return c;
}

// Tweaking IntegerBits gives exact multiplication by a power of two.
// The raw bits are copied as-is; only the type's integer-bit count changes.
template <int tExponent, typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
    FixedPoint<tRawType, tIntegerBits> a) {
    FixedPoint<tRawType, tExponent + tIntegerBits> c;
    c.raw() = a.raw();
    return c;
}

// If we want to leave IntegerBits fixed, then multiplication
// by a power of two has to be saturating/rounding, not exact anymore.
template <int tExponent, typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
    FixedPoint<tRawType, tIntegerBits> a) {
    return FixedPoint<tRawType, tIntegerBits>::FromRaw(
        SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
}

// Generic arithmetic operators.
// Generates FuncName(FixedPoint) by applying ImplFuncName to the raw value and
// re-wrapping the result in the same FixedPoint type.
#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
    template <typename tRawType, int tIntegerBits>                             \
    FixedPoint<tRawType, tIntegerBits> FuncName(                               \
        FixedPoint<tRawType, tIntegerBits> a) {                                \
        typedef FixedPoint<tRawType, tIntegerBits> WrappedType;                \
        return WrappedType::FromRaw(ImplFuncName(a.raw()));                    \
    }

// Generates FuncName(FixedPoint, FixedPoint) by applying ImplFuncName to both
// raw values and re-wrapping the result in the same FixedPoint type.
#define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName)                    \
    template <typename tRawType, int tIntegerBits>                             \
    FixedPoint<tRawType, tIntegerBits> FuncName(                               \
        FixedPoint<tRawType, tIntegerBits> a,                                  \
        FixedPoint<tRawType, tIntegerBits> b) {                                \
        typedef FixedPoint<tRawType, tIntegerBits> WrappedType;                \
        return WrappedType::FromRaw(ImplFuncName(a.raw(), b.raw()));           \
    }

MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)

#undef MAKE_FIXEDPOINT_UNARY_FUNC
#undef MAKE_FIXEDPOINT_BINARY_FUNC

// Generates FuncName(FixedPoint) that forwards to the raw-integer FuncName and
// returns its raw result (a mask) unwrapped.
#define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)                     \
    template <typename tRawType, int tIntegerBits>                             \
    tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) {                  \
        return FuncName(a.raw());                                              \
    }

// Two-argument counterpart of MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW.
#define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName)                    \
    template <typename tRawType, int tIntegerBits>                             \
    tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,                    \
                      FixedPoint<tRawType, tIntegerBits> b) {                  \
        return FuncName(a.raw(), b.raw());                                     \
    }

MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)

#undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
#undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW

// Bit-wise select between two FixedPoint values, driven by a raw mask.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
    tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
    FixedPoint<tRawType, tIntegerBits> else_val) {
    typedef FixedPoint<tRawType, tIntegerBits> ValueType;
    return ValueType::FromRaw(SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
}

// Two FixedPoint values compare equal when their raw values are equal.
template <typename tRawType, int tIntegerBits>
bool operator==(FixedPoint<tRawType, tIntegerBits> a,
                FixedPoint<tRawType, tIntegerBits> b) {
    const tRawType equal_mask = MaskIfEqual(a.raw(), b.raw());
    return All(equal_mask);
}

// Negation of operator==.
template <typename tRawType, int tIntegerBits>
bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
                FixedPoint<tRawType, tIntegerBits> b) {
    const bool same = (a == b);
    return !same;
}

// SaturatingAdd applied to the raw values, re-wrapped in the same format.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> SaturatingAdd(FixedPoint<tRawType, tIntegerBits> a,
                                                 FixedPoint<tRawType, tIntegerBits> b) {
    typedef FixedPoint<tRawType, tIntegerBits> ValueType;
    return ValueType::FromRaw(SaturatingAdd(a.raw(), b.raw()));
}

// AddSaturatingIf16Bit applied to the raw values, re-wrapped in the same format.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(FixedPoint<tRawType, tIntegerBits> a,
                                                        FixedPoint<tRawType, tIntegerBits> b) {
    typedef FixedPoint<tRawType, tIntegerBits> ValueType;
    return ValueType::FromRaw(AddSaturatingIf16Bit(a.raw(), b.raw()));
}

// Rescale changes the number of IntegerBits and updates the underlying
// raw integer value accordingly.
template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
FixedPoint<tRawType, tIntegerBitsDst> Rescale(
    FixedPoint<tRawType, tIntegerBitsSrc> x) {
    // Fewer destination integer bits means more fractional bits, hence a
    // multiplication of the raw value by 2^(src - dst).
    constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
    typedef FixedPoint<tRawType, tIntegerBitsDst> DstType;
    return DstType::FromRaw(SaturatingRoundingMultiplyByPOT<kExponent>(x.raw()));
}

// CheckedFixedPointConstant allows to specify fixed-point constants
// initialized as real numbers, in a way that does not compile floating-point
// arithmetic in production code, yet still checks agreement with the
// floating-point expressions when asserts are enabled.
//
// The raw integer value provided is always a int32, encoding a 32-bit
// fixed-point value, regardless of the actual Scalar type. This allows
// writing generic code that applies just as well to the 32-bit and 16-bit
// cases. In the 16-bit case, the raw integer value is internally
// rounding-shifted by 16 bits to the right.
template <typename FixedPointType>
inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(int32_t int32_value) {
    using Scalar = typename FixedPointType::ScalarRawType;
    constexpr int kScalarBits = 8 * sizeof(Scalar);
    // Number of low bits dropped (with rounding) to fit the scalar width.
    constexpr int kDroppedBits = 32 - kScalarBits;
    const auto narrowed = round_divide_by_pot<int32_t>(int32_value, kDroppedBits);
    return static_cast<Scalar>(narrowed);
}

// Note: as defined here the macro never expands DoubleValue, so the
// floating-point expression only documents the constant; it is not compiled.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawInt32Value, DoubleValue) \
    (FixedPointType::FromScalarRaw(RescaleConstantInitializer<FixedPointType>(ScalarRawInt32Value)))

// Implementation of exponential function.

// Returns exp(x) for x in [-1/4, 0).
template <typename tRawType>
FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(FixedPoint<tRawType, 0> a) {
    typedef FixedPoint<tRawType, 0> F;
    const F exp_minus_one_eighth =
        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, exp(-1.0 / 8.0));
    const F one_third =
        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);

    // We're evaluating a Taylor expansion around -1/8, so we do the change of
    // variable: x = a + 1/8.
    // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28
    // (for a 32-bit raw type).
    const F x = a + F::template ConstantPOT<-3>();
    const F x_pow2 = x * x;
    const F x_pow3 = x_pow2 * x;
    const F x_pow4 = x_pow2 * x_pow2;
    const F quarter_x_pow4 = SaturatingRoundingMultiplyByPOT<-2>(x_pow4);

    // ((x^4/4 + x^3) / 3 + x^2) / 2  ==  x^4/24 + x^3/6 + x^2/2
    const F cubic_and_quartic = (quarter_x_pow4 + x_pow3) * one_third;
    const F higher_order_terms =
        SaturatingRoundingMultiplyByPOT<-1>(cubic_and_quartic + x_pow2);

    // exp(-1/8) * (1 + x + x^2/2 + x^3/6 + x^4/24), with the leading
    // "1 *" term added separately.
    const F scaled_series = exp_minus_one_eighth * (x + higher_order_terms);
    return AddSaturatingIf16Bit(exp_minus_one_eighth, scaled_series);
}

// Returns exp(x) for x < 0.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> exp_on_negative_values(FixedPoint<tRawType, tIntegerBits> a) {
    typedef FixedPoint<tRawType, tIntegerBits> InputF;
    typedef FixedPoint<tRawType, 0> ResultF;
    constexpr int kFractionalBits = InputF::kFractionalBits;
    constexpr int kIntegerBits = InputF::kIntegerBits;

    // Reduce the argument into [-1/4, 0): keep the bits below 1/4 and
    // subtract 1/4.
    const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
    const InputF below_quarter_mask = kOneQuarter - InputF::FromScalarRaw(1);
    const InputF reduced_arg = (a & below_quarter_mask) - kOneQuarter;
    ResultF result =
        exp_on_interval_between_negative_one_quarter_and_0_excl(Rescale<0>(reduced_arg));

    // What was removed from `a` by the reduction; each of its set bits
    // contributes one constant factor exp(-2^Exponent) below.
    const tRawType remainder = (reduced_arg - a).raw();

// For one bit position of `remainder`: if that bit is set, multiply `result`
// by the supplied constant. Skipped when the input format has no such bit.
#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)                                 \
    if (kIntegerBits > Exponent) {                                                                  \
        constexpr int kShiftAmount = kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;      \
        const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(                           \
            ResultF, FixedPointMultiplier, exp(-pow(2.0, Exponent)));                               \
        const tRawType bit_is_set =                                                                 \
            MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount)));                     \
        result = SelectUsingMask(bit_is_set, result * kMultiplier, result);                         \
    }

    GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
    GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
    GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
    GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
    GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
    GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
    GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);

#undef GEMMLOWP_EXP_BARREL_SHIFTER

    // With more than 5 integer bits the input can go below -32; the result
    // is then forced to zero.
    constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
    if (kIntegerBits > 5) {
        const InputF lower_limit =
            GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
        result = SelectUsingMask(MaskIfLessThan(a, lower_limit), ResultF::Zero(), result);
    }

    // exp(0) is exactly one.
    result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
    return result;
}

// Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).

// Returns (1 - x) / (1 + x) for x in (0, 1).
template <typename tRawType>
FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(FixedPoint<tRawType, 0> a) {
    typedef FixedPoint<tRawType, 0> F0;
    typedef FixedPoint<tRawType, 2> F2;
    // (1 + a) / 2
    const F0 half_denominator = RoundingHalfSum(a, F0::One());
    // Newton-Raphson division
    // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
    // Refer to that page for the logic behind the 48/17 and 32/17 constants.
    const F2 constant_48_over_17 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
    const F2 constant_neg_32_over_17 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
    // Initial estimate of 1 / half_denominator, refined three times.
    F2 reciprocal = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
    for (int step = 0; step != 3; ++step) {
        const F2 approx_one = half_denominator * reciprocal;
        const F2 residual = F2::One() - approx_one;
        reciprocal = reciprocal + Rescale<2>(reciprocal * residual);
    }
    // 2 / (1 + a) - 1  ==  (1 - a) / (1 + a)
    return Rescale<0>(reciprocal - F2::One());
}

// Returns -tanh(x) for x < 0.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> neg_tanh_on_negative_values(FixedPoint<tRawType, tIntegerBits> a) {
    // ExactMulByPot<1> reinterprets a as 2a; then exp(2a) feeds (1 - e) / (1 + e).
    const FixedPoint<tRawType, 0> exp_of_two_a = exp_on_negative_values(ExactMulByPot<1>(a));
    return one_minus_x_over_one_plus_x_for_x_in_0_1(exp_of_two_a);
}

// Returns tanh(x) for any x.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
    typedef FixedPoint<tRawType, tIntegerBits> InputF;
    typedef FixedPoint<tRawType, 0> ResultF;
    const tRawType negative_mask = MaskIfLessThan(a, InputF::Zero());
    const tRawType zero_mask = MaskIfZero(a);
    // -|a|: keep a when it is negative, otherwise negate it.
    const InputF neg_abs = SelectUsingMask(negative_mask, a, -a);
    // neg_tanh_on_negative_values yields -tanh(-|a|) == tanh(|a|).
    const ResultF tanh_of_abs = neg_tanh_on_negative_values(neg_abs);
    // Restore the sign for negative inputs.
    const ResultF signed_result = SelectUsingMask(negative_mask, -tanh_of_abs, tanh_of_abs);
    // tanh(0) is exactly zero.
    return SelectUsingMask(zero_mask, ResultF::Zero(), signed_result);
}

// Implementation of logistic function.

// Returns 1 / (1 + x) for x in (0, 1).
template <typename tRawType>
FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(FixedPoint<tRawType, 0> a) {
    typedef FixedPoint<tRawType, 0> F0;
    typedef FixedPoint<tRawType, 2> F2;
    // (1 + a) / 2
    const F0 half_denominator = RoundingHalfSum(a, F0::One());
    // Newton-Raphson division
    // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
    // Refer to that page for the logic behind the 48/17 and 32/17 constants.
    const F2 constant_48_over_17 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
    const F2 constant_neg_32_over_17 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
    // Initial estimate of 1 / half_denominator, refined three times.
    F2 reciprocal = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
    for (int step = 0; step != 3; ++step) {
        const F2 approx_one = half_denominator * reciprocal;
        const F2 residual = F2::One() - approx_one;
        reciprocal = reciprocal + Rescale<2>(reciprocal * residual);
    }
    // Halve 2 / (1 + a) to obtain 1 / (1 + a).
    return Rescale<0>(ExactMulByPot<-1>(reciprocal));
}

// Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> logistic_on_positive_values(FixedPoint<tRawType, tIntegerBits> a) {
    const FixedPoint<tRawType, 0> exp_of_neg_a = exp_on_negative_values(-a);
    return one_over_one_plus_x_for_x_in_0_1(exp_of_neg_a);
}

// Returns logistic(x) = 1 / (1 + exp(-x)) for any x.
template <typename tRawType, int tIntegerBits>
FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
    typedef FixedPoint<tRawType, tIntegerBits> InputF;
    typedef FixedPoint<tRawType, 0> ResultF;
    const tRawType positive_mask = MaskIfGreaterThan(a, InputF::Zero());
    const tRawType zero_mask = MaskIfZero(a);
    // |a|: keep a when it is positive, otherwise negate it.
    const InputF abs_input = SelectUsingMask(positive_mask, a, -a);
    const ResultF for_positive = logistic_on_positive_values(abs_input);
    // logistic(-x) == 1 - logistic(x)
    const ResultF for_negative = ResultF::One() - for_positive;
    const ResultF one_half =
        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
    const ResultF nonzero_result = SelectUsingMask(positive_mask, for_positive, for_negative);
    // logistic(0) is exactly one half.
    return SelectUsingMask(zero_mask, one_half, nonzero_result);
}

// Multiplies x by a quantized multiplier whose shift is a right shift:
// a saturate_round_x2_high_mul followed by a rounding division by
// 2^(-left_shift). Both helpers are defined outside this file.
inline int MultiplyByQuantizedMultiplierSmallerThanOneExp(int x, int quantized_multiplier, int left_shift) {
    const auto high_product = saturate_round_x2_high_mul(x, quantized_multiplier);
    return round_divide_by_pot(high_product, -left_shift);
}

// Multiplies x by a quantized multiplier with a signed shift: a positive
// shift scales x up before the multiplication, a non-positive one is applied
// afterwards as a rounding division by 2^(-shift).
inline int MultiplyByQuantizedMultiplier(int x, int quantized_multiplier, int shift) {
    int left_shift = 0;
    int right_shift = 0;
    if (shift > 0) {
        left_shift = shift;
    } else {
        right_shift = -shift;
    }
    const auto high_product = saturate_round_x2_high_mul(x * (1 << left_shift), quantized_multiplier);
    return round_divide_by_pot(high_product, right_shift);
}

// Multiplies x by a quantized multiplier whose shift is a left shift: x is
// scaled by 2^left_shift before the saturate_round_x2_high_mul.
inline int MultiplyByQuantizedMultiplierGreaterThanOne(int x, int quantized_multiplier, int left_shift) {
    const int scaled_x = x * (1 << left_shift);
    return saturate_round_x2_high_mul(scaled_x, quantized_multiplier);
}
} // namespace MNN