1/*
2 * Copyright (c) 2014 Advanced Micro Devices, Inc.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice shall be included in
12 * all copies or substantial portions of the Software.
13 *
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 * THE SOFTWARE.
21 */
22
23#include <clc/clc.h>
24
25#include "config.h"
26#include "math.h"
27#include "tables.h"
28#include "../clcmacro.h"
29
// compute rootn(x, n) = x^(1/n) the same way as pow, using log and exp:
// x^y = exp(y * log(x)), with y = 1/n
32//
33// we take care not to lose precision in the intermediate steps
34//
// When computing log, calculate it in splits,
//
// r = f * (p_inv_head + p_inv_tail)
// r = rh + rt
//
// calculate the log polynomial using r; in the final addition, do
// poly = poly + ((rh-r) + rt)
42//
43// lth = -r
44// ltt = ((xexp * log2_t) - poly) + logT
45// lt = lth + ltt
46//
47// lh = (xexp * log2_h) + logH
48// l = lh + lt
49//
50// Calculate final log answer as gh and gt,
51// gh = l & higher-half bits
52// gt = (((ltt - (lt - lth)) + ((lh - l) + lt)) + (l - gh))
53//
54// yh = y & higher-half bits
55// yt = y - yh
56//
57// Before entering computation of exp,
58// vs = ((yt*gt + yt*gh) + yh*gt)
59// v = vs + yh*gh
60// vt = ((yh*gh - v) + vs)
61//
62// In calculation of exp, add vt to r that is used for poly
63// At the end of exp, do
64// ((((expT * poly) + expT) + expH*poly) + expH)
65
66_CLC_DEF _CLC_OVERLOAD float __clc_rootn(float x, int ny)
67{
68    float y = MATH_RECIP((float)ny);
69
70    int ix = as_int(x);
71    int ax = ix & EXSIGNBIT_SP32;
72    int xpos = ix == ax;
73
74    int iy = as_int(y);
75    int ay = iy & EXSIGNBIT_SP32;
76    int ypos = iy == ay;
77
78    // Extra precise log calculation
79    // First handle case that x is close to 1
80    float r = 1.0f - as_float(ax);
81    int near1 = fabs(r) < 0x1.0p-4f;
82    float r2 = r*r;
83
84    // Coefficients are just 1/3, 1/4, 1/5 and 1/6
85    float poly = mad(r,
86                     mad(r,
87                         mad(r,
88                             mad(r, 0x1.24924ap-3f, 0x1.555556p-3f),
89                             0x1.99999ap-3f),
90                         0x1.000000p-2f),
91                     0x1.555556p-2f);
92
93    poly *= r2*r;
94
95    float lth_near1 = -r2 * 0.5f;
96    float ltt_near1 = -poly;
97    float lt_near1 = lth_near1 + ltt_near1;
98    float lh_near1 = -r;
99    float l_near1 = lh_near1 + lt_near1;
100
101    // Computations for x not near 1
102    int m = (int)(ax >> EXPSHIFTBITS_SP32) - EXPBIAS_SP32;
103    float mf = (float)m;
104    int ixs = as_int(as_float(ax | 0x3f800000) - 1.0f);
105    float mfs = (float)((ixs >> EXPSHIFTBITS_SP32) - 253);
106    int c = m == -127;
107    int ixn = c ? ixs : ax;
108    float mfn = c ? mfs : mf;
109
110    int indx = (ixn & 0x007f0000) + ((ixn & 0x00008000) << 1);
111
112    // F - Y
113    float f = as_float(0x3f000000 | indx) - as_float(0x3f000000 | (ixn & MANTBITS_SP32));
114
115    indx = indx >> 16;
116    float2 tv = USE_TABLE(log_inv_tbl_ep, indx);
117    float rh = f * tv.s0;
118    float rt = f * tv.s1;
119    r = rh + rt;
120
121    poly = mad(r, mad(r, 0x1.0p-2f, 0x1.555556p-2f), 0x1.0p-1f) * (r*r);
122    poly += (rh - r) + rt;
123
124    const float LOG2_HEAD = 0x1.62e000p-1f;  // 0.693115234
125    const float LOG2_TAIL = 0x1.0bfbe8p-15f; // 0.0000319461833
126    tv = USE_TABLE(loge_tbl, indx);
127    float lth = -r;
128    float ltt = mad(mfn, LOG2_TAIL, -poly) + tv.s1;
129    float lt = lth + ltt;
130    float lh = mad(mfn, LOG2_HEAD, tv.s0);
131    float l = lh + lt;
132
133    // Select near 1 or not
134    lth = near1 ? lth_near1 : lth;
135    ltt = near1 ? ltt_near1 : ltt;
136    lt = near1 ? lt_near1 : lt;
137    lh = near1 ? lh_near1 : lh;
138    l = near1 ? l_near1 : l;
139
140    float gh = as_float(as_int(l) & 0xfffff000);
141    float gt = ((ltt - (lt - lth)) + ((lh - l) + lt)) + (l - gh);
142
143    float yh = as_float(iy & 0xfffff000);
144
145    float fny = (float)ny;
146    float fnyh = as_float(as_int(fny) & 0xfffff000);
147    float fnyt = (float)(ny - (int)fnyh);
148    float yt = MATH_DIVIDE(mad(-fnyt, yh, mad(-fnyh, yh, 1.0f)), fny);
149
150    float ylogx_s = mad(gt, yh, mad(gh, yt, yt*gt));
151    float ylogx = mad(yh, gh, ylogx_s);
152    float ylogx_t = mad(yh, gh, -ylogx) + ylogx_s;
153
154    // Extra precise exp of ylogx
155    const float R_64_BY_LOG2 = 0x1.715476p+6f; // 64/log2 : 92.332482616893657
156    int n = convert_int(ylogx * R_64_BY_LOG2);
157    float nf = (float) n;
158
159    int j = n & 0x3f;
160    m = n >> 6;
161    int m2 = m << EXPSHIFTBITS_SP32;
162
163    const float R_LOG2_BY_64_LD = 0x1.620000p-7f;  // log2/64 lead: 0.0108032227
164    const float R_LOG2_BY_64_TL = 0x1.c85fdep-16f; // log2/64 tail: 0.0000272020388
165    r = mad(nf, -R_LOG2_BY_64_TL, mad(nf, -R_LOG2_BY_64_LD, ylogx)) + ylogx_t;
166
167    // Truncated Taylor series for e^r
168    poly = mad(mad(mad(r, 0x1.555556p-5f, 0x1.555556p-3f), r, 0x1.000000p-1f), r*r, r);
169
170    tv = USE_TABLE(exp_tbl_ep, j);
171
172    float expylogx = mad(tv.s0, poly, mad(tv.s1, poly, tv.s1)) + tv.s0;
173    float sexpylogx = __clc_fp32_subnormals_supported() ? expylogx * as_float(0x1 << (m + 149)) : 0.0f;
174
175    float texpylogx = as_float(as_int(expylogx) + m2);
176    expylogx = m < -125 ? sexpylogx : texpylogx;
177
178    // Result is +-Inf if (ylogx + ylogx_t) > 128*log2
179    expylogx = ((ylogx > 0x1.62e430p+6f) | (ylogx == 0x1.62e430p+6f & ylogx_t > -0x1.05c610p-22f)) ? as_float(PINFBITPATT_SP32) : expylogx;
180
181    // Result is 0 if ylogx < -149*log2
182    expylogx = ylogx <  -0x1.9d1da0p+6f ? 0.0f : expylogx;
183
184    // Classify y:
185    //   inty = 0 means not an integer.
186    //   inty = 1 means odd integer.
187    //   inty = 2 means even integer.
188
189    int inty = 2 - (ny & 1);
190
191    float signval = as_float((as_uint(expylogx) ^ SIGNBIT_SP32));
192    expylogx = ((inty == 1) & !xpos) ? signval : expylogx;
193    int ret = as_int(expylogx);
194
195    // Corner case handling
196    ret = (!xpos & (inty == 2)) ? QNANBITPATT_SP32 : ret;
197    int xinf = xpos ? PINFBITPATT_SP32 : NINFBITPATT_SP32;
198    ret = ((ax == 0) & !ypos & (inty == 1)) ? xinf : ret;
199    ret = ((ax == 0) & !ypos & (inty == 2)) ? PINFBITPATT_SP32 : ret;
200    ret = ((ax == 0) & ypos & (inty == 2)) ? 0 : ret;
201    int xzero = xpos ? 0 : 0x80000000;
202    ret = ((ax == 0) & ypos & (inty == 1)) ? xzero : ret;
203    ret = ((ix == NINFBITPATT_SP32) & ypos & (inty == 1)) ? NINFBITPATT_SP32 : ret;
204    ret = ((ix == NINFBITPATT_SP32) & !ypos & (inty == 1)) ? 0x80000000 : ret;
205    ret = ((ix == PINFBITPATT_SP32) & !ypos) ? 0 : ret;
206    ret = ((ix == PINFBITPATT_SP32) & ypos) ? PINFBITPATT_SP32 : ret;
207    ret = ax > PINFBITPATT_SP32 ? ix : ret;
208    ret = ny == 0 ? QNANBITPATT_SP32 : ret;
209
210    return as_float(ret);
211}
212_CLC_BINARY_VECTORIZE(_CLC_DEF _CLC_OVERLOAD, float, __clc_rootn, float, int)
213
214#ifdef cl_khr_fp64
215_CLC_DEF _CLC_OVERLOAD double __clc_rootn(double x, int ny)
216{
217    const double real_log2_tail = 5.76999904754328540596e-08;
218    const double real_log2_lead = 6.93147122859954833984e-01;
219
220    double dny = (double)ny;
221    double y = 1.0 / dny;
222
223    long ux = as_long(x);
224    long ax = ux & (~SIGNBIT_DP64);
225    int xpos = ax == ux;
226
227    long uy = as_long(y);
228    long ay = uy & (~SIGNBIT_DP64);
229    int ypos = ay == uy;
230
231    // Extended precision log
232    double v, vt;
233    {
234        int exp = (int)(ax >> 52) - 1023;
235        int mask_exp_1023 = exp == -1023;
236        double xexp = (double) exp;
237        long mantissa = ax & 0x000FFFFFFFFFFFFFL;
238
239        long temp_ux = as_long(as_double(0x3ff0000000000000L | mantissa) - 1.0);
240        exp = ((temp_ux & 0x7FF0000000000000L) >> 52) - 2045;
241        double xexp1 = (double) exp;
242        long mantissa1 = temp_ux & 0x000FFFFFFFFFFFFFL;
243
244        xexp = mask_exp_1023 ? xexp1 : xexp;
245        mantissa = mask_exp_1023 ? mantissa1 : mantissa;
246
247        long rax = (mantissa & 0x000ff00000000000) + ((mantissa & 0x0000080000000000) << 1);
248        int index = rax >> 44;
249
250        double F = as_double(rax | 0x3FE0000000000000L);
251        double Y = as_double(mantissa | 0x3FE0000000000000L);
252        double f = F - Y;
253        double2 tv = USE_TABLE(log_f_inv_tbl, index);
254        double log_h = tv.s0;
255        double log_t = tv.s1;
256        double f_inv = (log_h + log_t) * f;
257        double r1 = as_double(as_long(f_inv) & 0xfffffffff8000000L);
258        double r2 = fma(-F, r1, f) * (log_h + log_t);
259        double r = r1 + r2;
260
261        double poly = fma(r,
262                          fma(r,
263                              fma(r,
264                                  fma(r, 1.0/7.0, 1.0/6.0),
265                                  1.0/5.0),
266                              1.0/4.0),
267                          1.0/3.0);
268        poly = poly * r * r * r;
269
270        double hr1r1 = 0.5*r1*r1;
271        double poly0h = r1 + hr1r1;
272        double poly0t = r1 - poly0h + hr1r1;
273        poly = fma(r1, r2, fma(0.5*r2, r2, poly)) + r2 + poly0t;
274
275        tv = USE_TABLE(powlog_tbl, index);
276        log_h = tv.s0;
277        log_t = tv.s1;
278
279        double resT_t = fma(xexp, real_log2_tail, + log_t) - poly;
280        double resT = resT_t - poly0h;
281        double resH = fma(xexp, real_log2_lead, log_h);
282        double resT_h = poly0h;
283
284        double H = resT + resH;
285        double H_h = as_double(as_long(H) & 0xfffffffff8000000L);
286        double T = (resH - H + resT) + (resT_t - (resT + resT_h)) + (H - H_h);
287        H = H_h;
288
289        double y_head = as_double(uy & 0xfffffffff8000000L);
290        double y_tail = y - y_head;
291
292        double fnyh = as_double(as_long(dny) & 0xfffffffffff00000);
293        double fnyt = (double)(ny - (int)fnyh);
294        y_tail = fma(-fnyt, y_head, fma(-fnyh, y_head, 1.0))/ dny;
295
296        double temp = fma(y_tail, H, fma(y_head, T, y_tail*T));
297        v = fma(y_head, H, temp);
298        vt = fma(y_head, H, -v) + temp;
299    }
300
301    // Now calculate exp of (v,vt)
302
303    double expv;
304    {
305        const double max_exp_arg = 709.782712893384;
306        const double min_exp_arg = -745.1332191019411;
307        const double sixtyfour_by_lnof2 = 92.33248261689366;
308        const double lnof2_by_64_head = 0.010830424260348081;
309        const double lnof2_by_64_tail = -4.359010638708991e-10;
310
311        double temp = v * sixtyfour_by_lnof2;
312        int n = (int)temp;
313        double dn = (double)n;
314        int j = n & 0x0000003f;
315        int m = n >> 6;
316
317        double2 tv = USE_TABLE(two_to_jby64_ep_tbl, j);
318        double f1 = tv.s0;
319        double f2 = tv.s1;
320        double f = f1 + f2;
321
322        double r1 = fma(dn, -lnof2_by_64_head, v);
323        double r2 = dn * lnof2_by_64_tail;
324        double r = (r1 + r2) + vt;
325
326        double q = fma(r,
327                       fma(r,
328                           fma(r,
329                               fma(r, 1.38889490863777199667e-03, 8.33336798434219616221e-03),
330                               4.16666666662260795726e-02),
331                           1.66666666665260878863e-01),
332                       5.00000000000000008883e-01);
333        q = fma(r*r, q, r);
334
335        expv = fma(f, q, f2) + f1;
336	      expv = ldexp(expv, m);
337
338        expv = v > max_exp_arg ? as_double(0x7FF0000000000000L) : expv;
339        expv = v < min_exp_arg ? 0.0 : expv;
340    }
341
342    // See whether y is an integer.
343    // inty = 0 means not an integer.
344    // inty = 1 means odd integer.
345    // inty = 2 means even integer.
346
347    int inty = 2 - (ny & 1);
348
349    expv *= ((inty == 1) & !xpos) ? -1.0 : 1.0;
350
351    long ret = as_long(expv);
352
353    // Now all the edge cases
354    ret = (!xpos & (inty == 2)) ? QNANBITPATT_DP64 : ret;
355    long xinf = xpos ? PINFBITPATT_DP64 : NINFBITPATT_DP64;
356    ret = ((ax == 0L) & !ypos & (inty == 1)) ? xinf : ret;
357    ret = ((ax == 0L) & !ypos & (inty == 2)) ? PINFBITPATT_DP64 : ret;
358    ret = ((ax == 0L) & ypos & (inty == 2)) ? 0L : ret;
359    long xzero = xpos ? 0L : 0x8000000000000000L;
360    ret = ((ax == 0L) & ypos & (inty == 1)) ? xzero : ret;
361    ret = ((ux == NINFBITPATT_DP64) & ypos & (inty == 1)) ? NINFBITPATT_DP64 : ret;
362    ret = ((ux == NINFBITPATT_DP64) & !ypos & (inty == 1)) ? 0x8000000000000000L : ret;
363    ret = ((ux == PINFBITPATT_DP64) & !ypos) ? 0L : ret;
364    ret = ((ux == PINFBITPATT_DP64) & ypos) ? PINFBITPATT_DP64 : ret;
365    ret = ax > PINFBITPATT_DP64 ? ux : ret;
366    ret = ny == 0 ? QNANBITPATT_DP64 : ret;
367    return as_double(ret);
368}
369_CLC_BINARY_VECTORIZE(_CLC_DEF _CLC_OVERLOAD, double, __clc_rootn, double, int)
370#endif
371