1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16/*
17 Convert to metal by MNN.
18 Copyright © 2018, Alibaba Group Holding Limited
19 */
20
21#include <metal_stdlib>
22#include "MetalDefine.metal"
23
24using namespace metal;
25
26namespace MNN {
27    // Part 1: Low-level integer-arithmetic primitives.
28    template <typename tIntegerType>
29    struct FixedPointRawTypeTraits {};
30
31    template <>
32    struct FixedPointRawTypeTraits<int32_t> {
33        typedef int32_t ScalarRawType;
34        static constant int kLanes = 1;
35    };
36
37    template <>
38    struct FixedPointRawTypeTraits<int16_t> {
39        typedef int16_t ScalarRawType;
40        static constant int kLanes = 1;
41    };
42
43    // Returns a SIMD value duplicating a scalar value across all lanes.
44    template <typename tRawType>
45    tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
46        return x;
47    }
48
49    // Plain bit-wise AND
50    template <typename tIntegerType>
51    tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
52        return a & b;
53    }
54
55    // Plain bit-wise OR
56    template <typename tIntegerType>
57    tIntegerType BitOr(tIntegerType a, tIntegerType b) {
58        return a | b;
59    }
60
61    // Plain bit-wise XOR
62    template <typename tIntegerType>
63    tIntegerType BitXor(tIntegerType a, tIntegerType b) {
64        return a ^ b;
65    }
66
67    // Plain bit-wise NOT
68    template <typename tIntegerType>
69    tIntegerType BitNot(tIntegerType a) {
70        return ~a;
71    }
72
73    // Integer addition. Not saturating. Overflow is undefined behavior.
74    template <typename tIntegerType>
75    tIntegerType Add(tIntegerType a, tIntegerType b) {
76        return a + b;
77    }
78
79    // Integer subtraction. Not saturating. Overflow is undefined behavior.
80    template <typename tIntegerType>
81    tIntegerType Mul(tIntegerType a, tIntegerType b) {
82        return a * b;
83    }
84
85    template <typename tIntegerType>
86    tIntegerType Sub(tIntegerType a, tIntegerType b) {
87        return a - b;
88    }
89
90    // Integer unary negative. Not saturating. Overflow is undefined behavior.
91    template <typename tIntegerType>
92    tIntegerType Neg(tIntegerType a) {
93        return -a;
94    }
95
96    // Integer arithmetic left-shift, equivalent to multiplying with a power of two.
97    // Not saturating. Negative inputs do not necessarily invoke undefined
98    // behaviour. Overflow is undefined behavior.
99    template <typename tIntegerType>
100    tIntegerType ShiftLeft(tIntegerType a, int offset) {
101        return a * (static_cast<tIntegerType>(1) << offset);
102    }
103
104    // Integer arithmetic right-shift. Not rounding.
105    // Relying on implementation-defined, but in-practice-consistent,
106    // C++ compiler behavior.
107    template <typename tIntegerType>
108    tIntegerType ShiftRight(tIntegerType a, int offset) {
109        return a >> offset;
110    }
111
112    // Each bit of the result is set to the corresponding bit of either then_val or
113    // else_val depending on whether the corresponding bit of if_mask is set.
114    // Equivalent to the VBSL instruction in ARM NEON.
115    template <typename tIntegerType>
116    tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
117                                 tIntegerType else_val) {
118        return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val));
119    }
120
121    // For each input scalar, the corresponding bits of the result are set if the
122    // input scalar is non-zero.
123    template <typename tIntegerType>
124    tIntegerType MaskIfNonZero(tIntegerType a) {
125        constexpr tIntegerType zero = 0;
126        return a ? BitNot(zero) : zero;
127    }
128
129    // For each input scalar, the corresponding bits of the result are set if the
130    // input scalar is zero.
131    template <typename tIntegerType>
132    tIntegerType MaskIfZero(tIntegerType a) {
133        return MaskIfNonZero<tIntegerType>(!a);
134    }
135
136    // For each pair of input scalars, the corresponding bits of the result are
137    // set if the input scalars are equal.
138    template <typename tIntegerType>
139    tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
140        return MaskIfNonZero<tIntegerType>(a == b);
141    }
142
143    // For each pair of input scalars, the corresponding bits of the result are
144    // set if the input scalars are not equal.
145    template <typename tIntegerType>
146    tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
147        return MaskIfNonZero<tIntegerType>(a != b);
148    }
149
150    // For each pair of input scalars, the corresponding bits of the result are
151    // set if the input scalars a, b satisfy a > b.
152    template <typename tIntegerType>
153    tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
154        return MaskIfNonZero<tIntegerType>(a > b);
155    }
156
157    // For each pair of input scalars, the corresponding bits of the result are
158    // set if the input scalars a, b satisfy a >= b.
159    template <typename tIntegerType>
160    tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
161        return MaskIfNonZero<tIntegerType>(a >= b);
162    }
163
164    // For each pair of input scalars, the corresponding bits of the result are
165    // set if the input scalars a, b satisfy a < b.
166    template <typename tIntegerType>
167    tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
168        return MaskIfNonZero<tIntegerType>(a < b);
169    }
170
171    // For each pair of input scalars, the corresponding bits of the result are
172    // set if the input scalars a, b satisfy a <= b.
173    template <typename tIntegerType>
174    tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
175        return MaskIfNonZero<tIntegerType>(a <= b);
176    }
177
178    // Returns true if all of the input scalars are nonzero.
179    // This function may currently assume that each of the input scalars has either
180    // all or none of its bits set. Otherwise, its behavior is currently undefined.
181    template <typename tIntegerType>
182    bool All(tIntegerType a) {
183        return a;
184    }
185
186    // Returns true if any of the input scalars are nonzero.
187    // This function may currently assume that each of the input scalars has either
188    // all or none of its bits set. Otherwise, its behavior is currently undefined.
189    template <typename tIntegerType>
190    bool Any(tIntegerType a) {
191        return a;
192    }
193
194    // Returns (a+b)/2, rounded to the nearest integer.
195    // Equivalent to VRHADD in the ARM NEON instruction set.
196    template <typename IntegerType>
197    IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
198//        static_assert(is_same<IntegerType, void>::value, "unimplemented");
199        return a;
200    }
201
202    template <>
203    inline int32_t RoundingHalfSum(int32_t a, int32_t b) {
204        return hadd(a, b);    }
205
206    template <>
207    inline int16_t RoundingHalfSum(int16_t a, int16_t b) {
208        return hadd(a, b);
209    }
210
211    template <typename IntegerType>
212    IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
213//        static_assert(is_same<IntegerType, void>::value, "unimplemented");
214        return a;
215    }
216
217    // So far this is only needed for int16.
218    template <>
219    inline int16_t SaturatingAdd(int16_t a, int16_t b) {
220        int32_t a32 = a;
221        int32_t b32 = b;
222        int32_t sum = a32 + b32;
223        return static_cast<int16_t>(min(32767, max(-32768, sum)));
224    }
225
226    // Returns a+b, saturating if the integers are 16bit or narrower,
227    // otherwise just a plain addition.
228    template <typename IntegerType, bool Is16Bit>
229    struct AddSaturatingIf16BitImpl {
230        static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
231    };
232    template <typename IntegerType>
233    struct AddSaturatingIf16BitImpl<IntegerType, true> {
234        static IntegerType Run(IntegerType a, IntegerType b) {
235            return SaturatingAdd(a, b);
236        }
237    };
238    template <typename IntegerType>
239    IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
240        using ScalarType =
241        typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
242        return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
243                                                                                   b);
244    }
245
246    // Returns the product of a run-time integer value by a compile-time power
247    // of two, with either a positive exponent (equivalent to an arithmetic
248    // left shift, saturating) or a negative exponent (equivalent to an arithmetic
249    // right shift, rounding to nearest).
250    template <int Exponent, typename IntegerType,
251    int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
252    struct ImplSaturatingRoundingMultiplyByPOT {};
253
254    template <int Exponent, typename IntegerType>
255    struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
256        static IntegerType eval(IntegerType x) { return x; }
257    };
258
259    template <int Exponent, typename IntegerType>
260    struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
261        static IntegerType eval(IntegerType x) {
262            using ScalarIntegerType =
263            typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
264            const IntegerType min = Dup<IntegerType>(num_limits<ScalarIntegerType>::min());
265            const IntegerType max = Dup<IntegerType>(num_limits<ScalarIntegerType>::max());
266            const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
267
268            const int32_t threshold = ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
269            const IntegerType positive_mask = MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
270            const IntegerType negative_mask = MaskIfLessThan(x, Dup<IntegerType>(-threshold));
271
272            IntegerType result = ShiftLeft(x, Exponent);
273            result = SelectUsingMask(positive_mask, max, result);
274            result = SelectUsingMask(negative_mask, min, result);
275            return result;
276        }
277    };
278
279    template <int Exponent, typename IntegerType>
280    struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
281        static IntegerType eval(IntegerType x) {
282            return round_divide_by_pot<IntegerType>(x, -Exponent);
283        }
284    };
285
286    template <int Exponent, typename IntegerType>
287    IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
288        return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
289    }
290
291    // Part 2: the FixedPoint class.
292    template <typename tRawType, int tIntegerBits>
293    class FixedPoint {
294    public:
295        typedef tRawType RawType;
296
297        typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
298        typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
299
300        static constant int kTotalBits = 8 * sizeof(ScalarRawType);
301        static constant int kIntegerBits = tIntegerBits;
302        static constant int kFractionalBits = kTotalBits - 1 - kIntegerBits;
303//        static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, "bad IntegerBits");
304
305        typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
306
307        static const ScalarRawType ScalarRawMin() {
308            return num_limits<ScalarRawType>::min();
309        }
310
311        static const ScalarRawType ScalarRawMax() {
312            return num_limits<ScalarRawType>::max();
313        }
314
315        static const ScalarRawType RawMin() {
316            return VectorFromScalar(ScalarRawMin());
317        }
318
319        static const ScalarRawType RawMax() {
320            return VectorFromScalar(ScalarRawMax());
321        }
322
323        static FixedPoint FromRaw(RawType x) {
324            FixedPoint retval;
325            retval.raw() = x;
326            return retval;
327        }
328
329        static FixedPoint FromScalarRaw(ScalarRawType x) {
330            FixedPoint retval;
331            retval.raw() = Dup<RawType>(x);
332            return retval;
333        }
334
335        static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
336            return FromScalarRaw(x.raw());
337        }
338
339        template <int Exponent>
340        static FixedPoint ConstantPOT() {
341            constexpr int kOffset = kFractionalBits + Exponent;
342//            static_assert(kOffset < 31, "Constant not exactly representable in this fixed-point format");
343            return FromScalarRaw(ScalarRawType(1) << kOffset);
344        }
345
346        static FixedPoint Zero() { return FromScalarRaw(0); }
347
348        static FixedPoint One() {
349            return FromScalarRaw(kIntegerBits == 0
350                                 ? ScalarRawMax()
351                                 : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
352        }
353
354
355        RawType raw() const { return i_; }
356        thread RawType& raw() { return i_; }
357
358    private:
359        RawType i_;
360    };
361
362    // Part 3: implementation of arithmetic operators for the
363    // FixedPoint class, and a few related functions.
364
365    // A FixedPoint multiplication is just a
366    // saturate_round_x2_high_mul operation on the underlying
367    // raw integer values. The IntegerBits simply add up, as is obvious
368    // from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
369    template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
370    FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
371                                                                    FixedPoint<tRawType, tIntegerBits_a> a,
372                                                                    FixedPoint<tRawType, tIntegerBits_b> b) {
373        FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
374        c.raw() = saturate_round_x2_high_mul(a.raw(), b.raw());
375        return c;
376    }
377
378    // Tweaking IntegerBits gives exact multiplication by a power of two.
379    template <int tExponent, typename tRawType, int tIntegerBits>
380    FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
381                                                                 FixedPoint<tRawType, tIntegerBits> a) {
382        FixedPoint<tRawType, tExponent + tIntegerBits> c;
383        c.raw() = a.raw();
384        return c;
385    }
386
387    // If we want to leave IntegerBits fixed, then multiplication
388    // by a power of two has to be saturating/rounding, not exact anymore.
389    template <int tExponent, typename tRawType, int tIntegerBits>
390    FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
391                                                                       FixedPoint<tRawType, tIntegerBits> a) {
392        return FixedPoint<tRawType, tIntegerBits>::FromRaw(
393                                                           SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
394    }
395
396    // Generic arithmetic operators.
397
// Defines FuncName(FixedPoint) by applying ImplFuncName to the raw value and
// wrapping the result back into the same fixed-point type.
#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)              \
    template <typename tRawType, int tIntegerBits>                      \
    FixedPoint<tRawType, tIntegerBits> FuncName(                        \
        FixedPoint<tRawType, tIntegerBits> a) {                         \
        typedef FixedPoint<tRawType, tIntegerBits> F;                   \
        return F::FromRaw(ImplFuncName(a.raw()));                       \
    }

// Defines FuncName(FixedPoint, FixedPoint) by applying ImplFuncName to the two
// raw values and wrapping the result back into the same fixed-point type.
#define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName)             \
    template <typename tRawType, int tIntegerBits>                      \
    FixedPoint<tRawType, tIntegerBits> FuncName(                        \
        FixedPoint<tRawType, tIntegerBits> a,                           \
        FixedPoint<tRawType, tIntegerBits> b) {                         \
        typedef FixedPoint<tRawType, tIntegerBits> F;                   \
        return F::FromRaw(ImplFuncName(a.raw(), b.raw()));              \
    }

    // Arithmetic and bit-wise operators, plus RoundingHalfSum, for FixedPoint
    // operands that share the same raw type and integer-bit count.
    MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
    MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
    MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
    MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
    MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
    MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
    MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
    MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)

#undef MAKE_FIXEDPOINT_UNARY_FUNC
#undef MAKE_FIXEDPOINT_BINARY_FUNC
425
// Defines FuncName(FixedPoint) that forwards to the raw-integer FuncName and
// returns its raw result (a mask) without wrapping it.
#define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)              \
    template <typename tRawType, int tIntegerBits>                      \
    tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) {           \
        const tRawType raw_a = a.raw();                                 \
        return FuncName(raw_a);                                         \
    }

// Defines FuncName(FixedPoint, FixedPoint) that forwards to the raw-integer
// FuncName and returns its raw result (a mask) without wrapping it.
#define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName)             \
    template <typename tRawType, int tIntegerBits>                      \
    tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,             \
                      FixedPoint<tRawType, tIntegerBits> b) {           \
        const tRawType raw_a = a.raw();                                 \
        const tRawType raw_b = b.raw();                                 \
        return FuncName(raw_a, raw_b);                                  \
    }

    // Mask-producing predicates lifted from raw integers to FixedPoint.
    MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
    MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
    MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
    MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
    MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
    MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
    MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
    MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)

#undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
#undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
450
451    template <typename tRawType, int tIntegerBits>
452    FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
453                                                       tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
454                                                       FixedPoint<tRawType, tIntegerBits> else_val) {
455        return FixedPoint<tRawType, tIntegerBits>::FromRaw(
456                                                           SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
457    }
458
459    template <typename tRawType, int tIntegerBits>
460    bool operator==(FixedPoint<tRawType, tIntegerBits> a,
461                    FixedPoint<tRawType, tIntegerBits> b) {
462        return All(MaskIfEqual(a.raw(), b.raw()));
463    }
464
465    template <typename tRawType, int tIntegerBits>
466    bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
467                    FixedPoint<tRawType, tIntegerBits> b) {
468        return !(a == b);
469    }
470
471    template <typename tRawType, int tIntegerBits>
472    FixedPoint<tRawType, tIntegerBits> SaturatingAdd(FixedPoint<tRawType, tIntegerBits> a,
473                                                     FixedPoint<tRawType, tIntegerBits> b) {
474        return FixedPoint<tRawType, tIntegerBits>::FromRaw(SaturatingAdd(a.raw(), b.raw()));
475    }
476
477    template <typename tRawType, int tIntegerBits>
478    FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(FixedPoint<tRawType, tIntegerBits> a,
479                                                            FixedPoint<tRawType, tIntegerBits> b) {
480        return FixedPoint<tRawType, tIntegerBits>::FromRaw(AddSaturatingIf16Bit(a.raw(), b.raw()));
481    }
482
483    // Rescale changes the number of IntegerBits and updates the underlying
484    // raw integer value accordingly.
485    template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
486    FixedPoint<tRawType, tIntegerBitsDst> Rescale(
487                                                  FixedPoint<tRawType, tIntegerBitsSrc> x) {
488        constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
489        FixedPoint<tRawType, tIntegerBitsDst> result;
490        result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
491        return result;
492    }
493
494    // CheckedFixedPointConstant allows to specify fixed-point constants
495    // initialized as real numbers, in a way that does not compile floating-point
496    // arithmetic in production code, yet still checks agreement with the
497    // floating-point expressions when asserts are enabled.
498    //
499    // The raw integer value provided is always a int32, encoding a 32-bit
500    // fixed-point value, regardless of the actual Scalar type. This allows
501    // writing generic code that applies just as well to the 32-bit and 16-bit
502    // cases. In the 16-bit case, the raw integer value is internally
503    // rounding-shifted by 16 bits to the right.
504    template <typename FixedPointType>
505    inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(int32_t int32_value) {
506        typedef typename FixedPointType::ScalarRawType ScalarRawType;
507        constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
508        return static_cast<ScalarRawType>(round_divide_by_pot<int32_t>(int32_value, 32 - ScalarTypeBits));
509    }
510
// Builds a FixedPointType constant from its 32-bit raw encoding.
// DoubleValue documents the intended real value at the call site; the
// expansion does not use it.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawInt32Value, DoubleValue) \
    (FixedPointType::FromScalarRaw(                                                            \
        RescaleConstantInitializer<FixedPointType>(ScalarRawInt32Value)))
513
514    // Implementation of exponential function.
515
516    // Returns exp(x) for x in [-1/4, 0).
517    template <typename tRawType>
518    FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(FixedPoint<tRawType, 0> a) {
519        typedef FixedPoint<tRawType, 0> F;
520        const F constant_term =
521        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, exp(-1.0 / 8.0));
522        const F constant_1_over_3 =
523        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
524        // We're evaluating a Taylor expansion around -1/8, so we do the change of
525        // variable: x = a + 1/8.
526        // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
527        F x = a + F::template ConstantPOT<-3>();
528        F x2 = x * x;
529        F x3 = x2 * x;
530        F x4 = x2 * x2;
531        F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
532        F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
533            SaturatingRoundingMultiplyByPOT<-1>(((x4_over_4 + x3) * constant_1_over_3) + x2);
534        return AddSaturatingIf16Bit(constant_term,
535                                    constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
536    }
537
538    // Returns exp(x) for x < 0.
539    template <typename tRawType, int tIntegerBits>
540    FixedPoint<tRawType, 0> exp_on_negative_values(FixedPoint<tRawType, tIntegerBits> a) {
541        typedef FixedPoint<tRawType, tIntegerBits> InputF;
542        typedef FixedPoint<tRawType, 0> ResultF;
543        constexpr int kFractionalBits = InputF::kFractionalBits;
544        constexpr int kIntegerBits = InputF::kIntegerBits;
545        const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
546        InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
547        InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
548        ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(Rescale<0>(a_mod_quarter_minus_one_quarter));
549        tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();
550
551#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
552if (kIntegerBits > Exponent) {                                            \
553    const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, FixedPointMultiplier, exp(-pow(2.0, Exponent))); \
554    constexpr int kShiftAmount = kIntegerBits > Exponent ? kFractionalBits + Exponent : 0; \
555    result = SelectUsingMask(                                               \
556        MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
557        result * kMultiplier, result);                                      \
558}
559
560        GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
561        GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
562        GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
563        GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
564        GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
565        GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
566        GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
567
568#undef GEMMLOWP_EXP_BARREL_SHIFTER
569
570        constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
571        if (kIntegerBits > 5) {
572            const InputF clamp =
573            GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
574            result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
575        }
576
577        result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
578        return result;
579    }
580
581    // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).
582
583    // Returns (1 - x) / (1 + x) for x in (0, 1).
584    template <typename tRawType>
585    FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(FixedPoint<tRawType, 0> a) {
586        typedef FixedPoint<tRawType, 0> F0;
587        typedef FixedPoint<tRawType, 2> F2;
588        F0 half_denominator = RoundingHalfSum(a, F0::One());
589        // Newton-Raphson division
590        // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
591        // Refer to that page for the logic behind the 48/17 and 32/17 constants.
592        const F2 constant_48_over_17     = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
593        const F2 constant_neg_32_over_17 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
594        F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
595        for (int i = 0; i < 3; i++) {
596            F2 half_denominator_times_x = half_denominator * x;
597            F2 one_minus_half_denominator_times_x =
598            F2::One() - half_denominator_times_x;
599            x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
600        }
601        return Rescale<0>(x - F2::One());
602    }
603
604    // Returns -tanh(x) for x < 0.
605    template <typename tRawType, int tIntegerBits>
606    FixedPoint<tRawType, 0> neg_tanh_on_negative_values(FixedPoint<tRawType, tIntegerBits> a) {
607        return one_minus_x_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(ExactMulByPot<1>(a)));
608    }
609
610    // Returns tanh(x) for any x.
611    template <typename tRawType, int tIntegerBits>
612    FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
613        typedef FixedPoint<tRawType, tIntegerBits> InputF;
614        typedef FixedPoint<tRawType, 0> ResultF;
615        tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
616        tRawType mask_if_zero = MaskIfZero(a);
617        InputF n = SelectUsingMask(mask_if_negative, a, -a);
618        ResultF t = neg_tanh_on_negative_values(n);
619        return SelectUsingMask(mask_if_zero, ResultF::Zero(),
620                               SelectUsingMask(mask_if_negative, -t, t));
621    }
622
623    // Implementation of logistic function.
624
625    // Returns 1 / (1 + x) for x in (0, 1).
626    template <typename tRawType>
627    FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(FixedPoint<tRawType, 0> a) {
628        typedef FixedPoint<tRawType, 0> F0;
629        typedef FixedPoint<tRawType, 2> F2;
630        F0 half_denominator = RoundingHalfSum(a, F0::One());
631        // Newton-Raphson division
632        // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
633        // Refer to that page for the logic behind the 48/17 and 32/17 constants.
634        const F2 constant_48_over_17     = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
635        const F2 constant_neg_32_over_17 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
636        F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
637        for (int i = 0; i < 3; i++) {
638            F2 half_denominator_times_x = half_denominator * x;
639            F2 one_minus_half_denominator_times_x =
640            F2::One() - half_denominator_times_x;
641            x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
642        }
643        return Rescale<0>(ExactMulByPot<-1>(x));
644    }
645
646    // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
647    template <typename tRawType, int tIntegerBits>
648    FixedPoint<tRawType, 0> logistic_on_positive_values(FixedPoint<tRawType, tIntegerBits> a) {
649        return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
650    }
651
652    // Returns logistic(x) = 1 / (1 + exp(-x)) for any x.
653    template <typename tRawType, int tIntegerBits>
654    FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
655        typedef FixedPoint<tRawType, tIntegerBits> InputF;
656        typedef FixedPoint<tRawType, 0> ResultF;
657        tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero());
658        tRawType mask_if_zero = MaskIfZero(a);
659        InputF abs_input = SelectUsingMask(mask_if_positive, a, -a);
660        ResultF result_if_positive = logistic_on_positive_values(abs_input);
661        ResultF result_if_negative = ResultF::One() - result_if_positive;
662        const ResultF one_half =
663        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
664        return SelectUsingMask(mask_if_zero, one_half, SelectUsingMask(mask_if_positive, result_if_positive, result_if_negative));
665    }
666
667    inline int MultiplyByQuantizedMultiplierSmallerThanOneExp(int x, int quantized_multiplier, int left_shift) {
668        return round_divide_by_pot(saturate_round_x2_high_mul(x, quantized_multiplier), -left_shift);
669    }
670
671    inline int MultiplyByQuantizedMultiplier(int x, int quantized_multiplier, int shift) {
672        int left_shift = shift > 0 ? shift : 0;
673        int right_shift = shift > 0 ? 0 : -shift;
674        return round_divide_by_pot(saturate_round_x2_high_mul(x * (1 << left_shift), quantized_multiplier), right_shift);
675    }
676
677    inline int MultiplyByQuantizedMultiplierGreaterThanOne(int x, int quantized_multiplier, int left_shift) {
678        return saturate_round_x2_high_mul(x * (1 << left_shift), quantized_multiplier);
679    }
680}
681