1
2 /*
3 * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 #include "math_common.h"
20 #include "sleef_common.h"
21
__ldexpf_scalar_kernel(float a,int scale)22 static float INLINE __ldexpf_scalar_kernel(float a, int scale)
23 {
24 #if (defined __AVX512F__)
25 float res = _mm_cvtss_f32(_mm_scalef_ss(_mm_set_ss(a), _mm_set_ss((float)scale)));
26 return res;
27 #else
28 PRINT(a); PRINT(scale);
29 // Input is allowed to be such that signed |scale| < 256,
30 // |a| may be in {+-0} or +-[2^-149, 2^0] as it comes from sin/cos,
31 // but we took precaution outside this routine and normalized a,
32 // so that it is within +-[2^-149 + 64, 2^64] or zero.
33
34 // Zeros and Inf/NaNs are handled separately.
35 // Input denormals end up here too and yield incorrect result.
36 // FIXME: assert(this function assumes no denormals on input !!!);
37 unsigned exp_bits = F2I(a) & FL_EXP_MASK;
38 unsigned zeroinfnan_mask = ((exp_bits == FL_EXP_MASK) || (exp_bits == 0))
39 ? 0xffffffff : 0; PRINT(zeroinfnan_mask);
40 // Preserve sign of input, quiet NaN
41 float zeroinfnan_res = a + a; PRINT(zeroinfnan_res);
42
43 // biased exponent bits, shifted to least significant position
44 unsigned getexp_a = exp_bits >> (FL_PREC_BITS-1); PRINT(getexp_a);
45
46 // For a * 2^scale to fit in floats we need getexp(a) + scale
47 // to fit in exponents range of floats: bias + (FL_EXP_MIN-1, FL_EXP_MAX).
48 // FL_EXP_MIN-1 is less than the smallest denormal, but it may round up.
49 int sumexp = getexp_a + scale; PRINT(sumexp);
50
51 // Return Inf of correct sign if overflow
52 unsigned ovf_mask = ((sumexp > (signed int)(FL_EXP_MAX + FL_EXP_BIAS))
53 ? 0xffffffff : 0);
54 unsigned sign_a = F2I(a) & FL_SIGN_BIT; PRINT(sign_a);
55 unsigned ovf_res = (sign_a | FL_EXP_MASK); PRINT(ovf_res);
56
57 // If underflow, return zero of correct sign
58 unsigned udf_mask = (sumexp < (signed int)(FL_EXP_MIN-1 + FL_EXP_BIAS) )
59 ? 0xffffffff : 0;
60 unsigned udf_res = sign_a; PRINT(udf_res);
61
62 // Check if result is within denormalized numbers range
63 // and doesn't completely underflow
64 unsigned den_mask = ~udf_mask &
65 (((signed int)(sumexp) <= 0) ? 0xffffffff : 0);
66
67 // If scaling leads to denormals: we shall do it via FP multiplication
68 // 2^scale * a. But 2^scale alone may not be representable in FP, while
69 // the product is OK. Thus we would like the sum of exponents sumexp in
70 // range for FP. Since sumexp already contains the value of biased exponent
71 // of a, we will first compensate a by reducing its exponent to biased zero:
72 // a = a * 2^(-(getexp_a - bias)), or set exponent bits of a to FL_EXP_BIAS.
73 // Now we would like sumexp become positive, for that we may add as little
74 // as -(FL_EXP_MIN-2 + FL_EXP_BIAS). We'd have to compensate exponent of a
75 // by this same quantity, so in the end we'll be setting exponent of a to
76 // FL_EXP_BIAS + (FL_EXP_MIN-2 + FL_EXP_BIAS) = 2*FL_EXP_BIAS + FL_EXP_MIN-2
77 int new_scale = ((unsigned int)(sumexp -(FL_EXP_MIN-2 + FL_EXP_BIAS)))
78 << (FL_PREC_BITS-1); PRINT(new_scale);
79 float new_a = I2F((F2I(a) & (~FL_EXP_MASK)) |
80 ((2*FL_EXP_BIAS + FL_EXP_MIN-2) << (FL_PREC_BITS-1))); PRINT(new_a);
81 float den_res = new_a * I2F(new_scale); PRINT(den_res);
82
83 // normal case, just add scale to exponent bits
84 unsigned gen_res = F2I(a) + (((unsigned int)scale) << (FL_PREC_BITS-1)); PRINT(gen_res);
85 unsigned gen_mask = ~(ovf_mask | udf_mask | den_mask);
86
87 float result = I2F((F2I(zeroinfnan_res) & zeroinfnan_mask) |
88 ((~zeroinfnan_mask) & ((ovf_res & ovf_mask) |
89 (udf_res & udf_mask) |
90 (F2I(den_res) & den_mask) |
91 (gen_res & gen_mask)))); PRINT(result);
92
93 return result;
94 #endif //#if (defined __AVX512F__)
95 }
96
97 static vfloat INLINE
98 //static vfloat __attribute__((noinline))
__vldexpf_manual(vfloat va,vfloat vscale)99 __vldexpf_manual(vfloat va, vfloat vscale)
100 {
101 PRINT(va); PRINT(vscale);
102 // Input is allowed to be such that signed |scale| < 256,
103 // |a| may be in {+-0} or +-[2^-149, 2^0] as it comes from sin/cos,
104 // but we took precaution outside this routine and normalized a,
105 // so that it is within +-[2^-149 + 64, 2^64] or zero.
106
107 // Zeros and Inf/NaNs are handled separately.
108 // Input denormals end up here too and yield incorrect result.
109 // FIXME: assert(this function assumes no denormals on input !!!);
110 vint2 exp_bits = vand_vi2_vi2_vi2(vF2I(va), vSETi(FL_EXP_MASK));
111 vopmask zero_mask = veq_vo_vi2_vi2(exp_bits, vSETi(0));
112 vopmask infnan_mask = veq_vo_vi2_vi2(exp_bits, vSETi(FL_EXP_MASK));
113 vopmask zeroinfnan_mask = vor_vo_vo_vo(zero_mask, infnan_mask); PRINT(zeroinfnan_mask);
114
115 // Preserve sign of input, quiet NaN
116 vfloat zeroinfnan_res = vadd_vf_vf_vf(va, va); PRINT(zeroinfnan_res);
117
118 // biased exponent bits, shifted to least significant position
119 vint2 getexp_a = vsrl_vi2_vi2_i(exp_bits, FL_PREC_BITS-1); PRINT(getexp_a);
120
121 // For a * 2^scale to fit in floats we need getexp(a) + scale
122 // to fit in exponents range of floats: bias + (FL_EXP_MIN-1, FL_EXP_MAX).
123 // FL_EXP_MIN-1 is less than the smallest denormal, but it may round up.
124 vint2 sumexp = vadd_vi2_vi2_vi2(getexp_a, vF2I(vscale)); PRINT(sumexp);
125
126 // Return Inf of correct sign if overflow
127 vopmask ovf_mask = vgt_vo_vi2_vi2(sumexp, vSETi(FL_EXP_MAX + FL_EXP_BIAS));
128 vint2 sign_a = vand_vi2_vi2_vi2(vF2I(va), vSETi(FL_SIGN_BIT)); PRINT(sign_a);
129 vint2 ovf_res = vor_vi2_vi2_vi2(sign_a, vSETi(FL_EXP_MASK)); PRINT(ovf_res);
130
131 // If underflow, return zero of correct sign
132 vopmask udf_mask = vgt_vo_vi2_vi2(vSETi(FL_EXP_MIN-1 + FL_EXP_BIAS), sumexp);
133 vint2 udf_res = sign_a; PRINT(udf_res);
134
135 // Check if result is within denormalized numbers range
136 // and doesn't completely underflow
137 vopmask den_mask = vandnot_vo_vo_vo(udf_mask, vgt_vo_vi2_vi2(vSETi(1), sumexp));
138
139 // If scaling leads to denormals: we shall do it via FP multiplication
140 // 2^scale * a. But 2^scale alone may not be representable in FP, while
141 // the product is OK. Thus we would like the sum of exponents sumexp in
142 // range for FP. Since sumexp already contains the value of biased exponent
143 // of a, we will first compensate a by reducing its exponent to biased zero:
144 // a = a * 2^(-(getexp_a - bias)), or set exponent bits of a to FL_EXP_BIAS.
145 // Now we would like sumexp become positive, for that we may add as little
146 // as -(FL_EXP_MIN-2 + FL_EXP_BIAS). We'd have to compensate exponent of a
147 // by this same quantity, so in the end we'll be setting exponent of a to
148 // FL_EXP_BIAS + (FL_EXP_MIN-2 + FL_EXP_BIAS) = 2*FL_EXP_BIAS + FL_EXP_MIN-2
149 vint2 new_scale =
150 vsll_vi2_vi2_i(
151 vadd_vi2_vi2_vi2(sumexp, vSETi(-(FL_EXP_MIN-2 + FL_EXP_BIAS))),
152 FL_PREC_BITS-1); PRINT(new_scale);
153 vfloat new_a = vI2F(vor_vi2_vi2_vi2(
154 vand_vi2_vi2_vi2(vF2I(va), vSETi(~FL_EXP_MASK)),
155 vSETi((2*FL_EXP_BIAS + FL_EXP_MIN-2) << (FL_PREC_BITS-1)))); PRINT(new_a);
156 vfloat den_res = vmul_vf_vf_vf(new_a, vI2F(new_scale)); PRINT(den_res);
157
158 // normal case, just add scale to exponent bits
159 vint2 gen_res = vadd_vi2_vi2_vi2(vF2I(va),
160 vsll_vi2_vi2_i( vF2I(vscale), FL_PREC_BITS-1)); PRINT(gen_res);
161 vopmask ngen_mask =
162 vor_vo_vo_vo(vor_vo_vo_vo(ovf_mask, udf_mask), den_mask);
163
164 vfloat result = vI2F(
165 vor_vi2_vi2_vi2(
166 vand_vi2_vo_vi2(zeroinfnan_mask, vF2I(zeroinfnan_res)),
167 vandnot_vi2_vo_vi2(zeroinfnan_mask,
168 vor_vi2_vi2_vi2(
169 vand_vi2_vo_vi2(ovf_mask, ovf_res),
170 vor_vi2_vi2_vi2(
171 vand_vi2_vo_vi2(udf_mask, udf_res),
172 vor_vi2_vi2_vi2(
173 vand_vi2_vo_vi2(den_mask, vF2I(den_res)),
174 vandnot_vi2_vo_vi2(ngen_mask, gen_res))))))); PRINT(result);
175
176 return result;
177 }
178
179 static vfloat INLINE
180 //static vfloat __attribute__((noinline))
__vldexpf_kernel(vfloat va,vfloat vscale)181 __vldexpf_kernel(vfloat va, vfloat vscale)
182 {
183 PRINT(va); PRINT(vscale);
184 #if (defined __AVX512F__) && ((defined __AVX512VL__) || (_VL == 8))
185 // use AVX512VL instruction for _VL < 8
186 // use AVX512F instruction in case of a full width
187 vfloat vfres = JOIN(__SIMD_TYPE,_scalef_ps)(va, vcast_vf_vi2(vF2I(vscale))); PRINT(vfres);
188 return vfres;
189 #elif (defined __AVX512F__)
190 // AVX512VL not supported and _VL < 8
191 vfloat vscale_converted = vcast_vf_vi2(vF2I(vscale)); PRINT(vscale_converted);
192 __mmask16 mask = (__mmask16)((1 << (2*_VL)) - 1); PRINT(mask);
193 __m512 fullwidth_va = JOIN3(_mm512_castps,__SIMD_BITS,_ps512)(va); PRINT(fullwidth_va);
194 __m512 fullwidth_vscale = JOIN3(_mm512_castps,__SIMD_BITS,_ps512)(vscale_converted); PRINT(fullwidth_vscale);
195 __m512 fullwidth_vfres = _mm512_maskz_scalef_ps(mask, fullwidth_va, fullwidth_vscale); PRINT(fullwidth_vfres);
196 vfloat vfres = JOIN(_mm512_castps512_ps,__SIMD_BITS)(fullwidth_vfres); PRINT(vfres);
197 return vfres;
198 #else
199 return __vldexpf_manual(va, vscale);
200 #endif
201 }
202