1 /*
2  * Copyright © 2020 Collabora Ltd.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #ifndef NIR_CONVERSION_BUILDER_H
25 #define NIR_CONVERSION_BUILDER_H
26 
27 #include "util/u_math.h"
28 #include "nir_builder.h"
29 #include "nir_builtin_builder.h"
30 
31 #ifdef __cplusplus
32 extern "C" {
33 #endif
34 
35 static inline nir_ssa_def *
nir_round_float_to_int(nir_builder * b,nir_ssa_def * src,nir_rounding_mode round)36 nir_round_float_to_int(nir_builder *b, nir_ssa_def *src,
37                        nir_rounding_mode round)
38 {
39    switch (round) {
40    case nir_rounding_mode_ru:
41       return nir_fceil(b, src);
42 
43    case nir_rounding_mode_rd:
44       return nir_ffloor(b, src);
45 
46    case nir_rounding_mode_rtne:
47       return nir_fround_even(b, src);
48 
49    case nir_rounding_mode_undef:
50    case nir_rounding_mode_rtz:
51       break;
52    }
53    unreachable("unexpected rounding mode");
54 }
55 
56 static inline nir_ssa_def *
nir_round_float_to_float(nir_builder * b,nir_ssa_def * src,unsigned dest_bit_size,nir_rounding_mode round)57 nir_round_float_to_float(nir_builder *b, nir_ssa_def *src,
58                          unsigned dest_bit_size,
59                          nir_rounding_mode round)
60 {
61    unsigned src_bit_size = src->bit_size;
62    if (dest_bit_size > src_bit_size)
63       return src; /* No rounding is needed for an up-convert */
64 
65    nir_op low_conv = nir_type_conversion_op(nir_type_float | src_bit_size,
66                                             nir_type_float | dest_bit_size,
67                                             nir_rounding_mode_undef);
68    nir_op high_conv = nir_type_conversion_op(nir_type_float | dest_bit_size,
69                                              nir_type_float | src_bit_size,
70                                              nir_rounding_mode_undef);
71 
72    switch (round) {
73    case nir_rounding_mode_ru: {
74       /* If lower-precision conversion results in a lower value, push it
75       * up one ULP. */
76       nir_ssa_def *lower_prec =
77          nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
78       nir_ssa_def *roundtrip =
79          nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
80       nir_ssa_def *cmp = nir_flt(b, roundtrip, src);
81       nir_ssa_def *inf = nir_imm_floatN_t(b, INFINITY, dest_bit_size);
82       return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, inf), lower_prec);
83    }
84    case nir_rounding_mode_rd: {
85       /* If lower-precision conversion results in a higher value, push it
86       * down one ULP. */
87       nir_ssa_def *lower_prec =
88          nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
89       nir_ssa_def *roundtrip =
90          nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
91       nir_ssa_def *cmp = nir_flt(b, src, roundtrip);
92       nir_ssa_def *neg_inf = nir_imm_floatN_t(b, -INFINITY, dest_bit_size);
93       return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, neg_inf), lower_prec);
94    }
95    case nir_rounding_mode_rtz:
96       return nir_bcsel(b, nir_flt(b, src, nir_imm_zero(b, 1, src->bit_size)),
97                           nir_round_float_to_float(b, src, dest_bit_size,
98                                                    nir_rounding_mode_ru),
99                           nir_round_float_to_float(b, src, dest_bit_size,
100                                                    nir_rounding_mode_rd));
101    case nir_rounding_mode_rtne:
102    case nir_rounding_mode_undef:
103       break;
104    }
105    unreachable("unexpected rounding mode");
106 }
107 
108 static inline nir_ssa_def *
nir_round_int_to_float(nir_builder * b,nir_ssa_def * src,nir_alu_type src_type,unsigned dest_bit_size,nir_rounding_mode round)109 nir_round_int_to_float(nir_builder *b, nir_ssa_def *src,
110                        nir_alu_type src_type,
111                        unsigned dest_bit_size,
112                        nir_rounding_mode round)
113 {
114    /* We only care whether or not its signed */
115    src_type = nir_alu_type_get_base_type(src_type);
116 
117    unsigned mantissa_bits;
118    switch (dest_bit_size) {
119    case 16:
120       mantissa_bits = 10;
121       break;
122    case 32:
123       mantissa_bits = 23;
124       break;
125    case 64:
126       mantissa_bits = 52;
127       break;
128    default: unreachable("Unsupported bit size");
129    }
130 
131    if (src->bit_size < mantissa_bits)
132       return src;
133 
134    if (src_type == nir_type_int) {
135       nir_ssa_def *sign =
136          nir_i2b1(b, nir_ishr(b, src, nir_imm_int(b, src->bit_size - 1)));
137       nir_ssa_def *abs = nir_iabs(b, src);
138       nir_ssa_def *positive_rounded =
139          nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, round);
140       nir_ssa_def *max_positive =
141          nir_imm_intN_t(b, (1ull << (src->bit_size - 1)) - 1, src->bit_size);
142       switch (round) {
143       case nir_rounding_mode_rtz:
144          return nir_bcsel(b, sign, nir_ineg(b, positive_rounded),
145                                    positive_rounded);
146          break;
147       case nir_rounding_mode_ru:
148          return nir_bcsel(b, sign,
149                           nir_ineg(b, nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_rd)),
150                           nir_umin(b, positive_rounded, max_positive));
151          break;
152       case nir_rounding_mode_rd:
153          return nir_bcsel(b, sign,
154                           nir_ineg(b,
155                                    nir_umin(b, max_positive,
156                                             nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_ru))),
157                           positive_rounded);
158       case nir_rounding_mode_rtne:
159       case nir_rounding_mode_undef:
160          break;
161       }
162       unreachable("unexpected rounding mode");
163    } else {
164       nir_ssa_def *mantissa_bit_size = nir_imm_int(b, mantissa_bits);
165       nir_ssa_def *msb = nir_imax(b, nir_ufind_msb(b, src), mantissa_bit_size);
166       nir_ssa_def *bits_to_lose = nir_isub(b, msb, mantissa_bit_size);
167       nir_ssa_def *one = nir_imm_intN_t(b, 1, src->bit_size);
168       nir_ssa_def *adjust = nir_ishl(b, one, bits_to_lose);
169       nir_ssa_def *mask = nir_inot(b, nir_isub(b, adjust, one));
170       nir_ssa_def *truncated = nir_iand(b, src, mask);
171       switch (round) {
172       case nir_rounding_mode_rtz:
173       case nir_rounding_mode_rd:
174          return truncated;
175          break;
176       case nir_rounding_mode_ru:
177          return nir_bcsel(b, nir_ieq(b, src, truncated),
178                              src, nir_uadd_sat(b, truncated, adjust));
179       case nir_rounding_mode_rtne:
180       case nir_rounding_mode_undef:
181          break;
182       }
183       unreachable("unexpected rounding mode");
184    }
185 }
186 
187 /** Returns true if the representable range of a contains the representable
188  * range of b.
189  */
190 static inline bool
nir_alu_type_range_contains_type_range(nir_alu_type a,nir_alu_type b)191 nir_alu_type_range_contains_type_range(nir_alu_type a, nir_alu_type b)
192 {
193    /* Split types from bit sizes */
194    nir_alu_type a_base_type = nir_alu_type_get_base_type(a);
195    nir_alu_type b_base_type = nir_alu_type_get_base_type(b);
196    unsigned a_bit_size = nir_alu_type_get_type_size(a);
197    unsigned b_bit_size = nir_alu_type_get_type_size(b);
198 
199    /* This requires sized types */
200    assert(a_bit_size > 0 && b_bit_size > 0);
201 
202    if (a_base_type == b_base_type && a_bit_size >= b_bit_size)
203       return true;
204 
205    if (a_base_type == nir_type_int && b_base_type == nir_type_uint &&
206        a_bit_size > b_bit_size)
207       return true;
208 
209    /* 16-bit floats fit in 32-bit integers */
210    if (a_base_type == nir_type_int && a_bit_size >= 32 &&
211        b == nir_type_float16)
212       return true;
213 
214    /* All signed or unsigned ints can fit in float or above. A uint8 can fit
215     * in a float16.
216     */
217    if (a_base_type == nir_type_float && b_base_type != nir_type_float &&
218        (a_bit_size >= 32 || b_bit_size == 8))
219       return true;
220 
221    return false;
222 }
223 
224 /**
225  * Retrieves limits used for clamping a value of the src type into
226  * the widest representable range of the dst type via cmp + bcsel
227  */
228 static inline void
nir_get_clamp_limits(nir_builder * b,nir_alu_type src_type,nir_alu_type dest_type,nir_ssa_def ** low,nir_ssa_def ** high)229 nir_get_clamp_limits(nir_builder *b,
230                      nir_alu_type src_type,
231                      nir_alu_type dest_type,
232                      nir_ssa_def **low, nir_ssa_def **high)
233 {
234    /* Split types from bit sizes */
235    nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
236    nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
237    unsigned src_bit_size = nir_alu_type_get_type_size(src_type);
238    unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
239    assert(dest_bit_size != 0 && src_bit_size != 0);
240 
241    *low = NULL;
242    *high = NULL;
243 
244    /* limits of the destination type, expressed in the source type */
245    switch (dest_base_type) {
246    case nir_type_int: {
247       int64_t ilow, ihigh;
248       if (dest_bit_size == 64) {
249          ilow = INT64_MIN;
250          ihigh = INT64_MAX;
251       } else {
252          ilow = -(1ll << (dest_bit_size - 1));
253          ihigh = (1ll << (dest_bit_size - 1)) - 1;
254       }
255 
256       if (src_base_type == nir_type_int) {
257          *low = nir_imm_intN_t(b, ilow, src_bit_size);
258          *high = nir_imm_intN_t(b, ihigh, src_bit_size);
259       } else if (src_base_type == nir_type_uint) {
260          assert(src_bit_size >= dest_bit_size);
261          *high = nir_imm_intN_t(b, ihigh, src_bit_size);
262       } else {
263          *low = nir_imm_floatN_t(b, ilow, src_bit_size);
264          *high = nir_imm_floatN_t(b, ihigh, src_bit_size);
265       }
266       break;
267    }
268    case nir_type_uint: {
269       uint64_t uhigh = dest_bit_size == 64 ?
270          ~0ull : (1ull << dest_bit_size) - 1;
271       if (src_base_type != nir_type_float) {
272          *low = nir_imm_intN_t(b, 0, src_bit_size);
273          if (src_base_type == nir_type_uint || src_bit_size > dest_bit_size)
274             *high = nir_imm_intN_t(b, uhigh, src_bit_size);
275       } else {
276          *low = nir_imm_floatN_t(b, 0.0f, src_bit_size);
277          *high = nir_imm_floatN_t(b, uhigh, src_bit_size);
278       }
279       break;
280    }
281    case nir_type_float: {
282       double flow, fhigh;
283       switch (dest_bit_size) {
284       case 16:
285          flow = -65504.0f;
286          fhigh = 65504.0f;
287          break;
288       case 32:
289          flow = -FLT_MAX;
290          fhigh = FLT_MAX;
291          break;
292       case 64:
293          flow = -DBL_MAX;
294          fhigh = DBL_MAX;
295          break;
296       default:
297          unreachable("Unhandled bit size");
298       }
299 
300       switch (src_base_type) {
301       case nir_type_int: {
302          int64_t src_ilow, src_ihigh;
303          if (src_bit_size == 64) {
304             src_ilow = INT64_MIN;
305             src_ihigh = INT64_MAX;
306          } else {
307             src_ilow = -(1ll << (src_bit_size - 1));
308             src_ihigh = (1ll << (src_bit_size - 1)) - 1;
309          }
310          if (src_ilow < flow)
311             *low = nir_imm_intN_t(b, flow, src_bit_size);
312          if (src_ihigh > fhigh)
313             *high = nir_imm_intN_t(b, fhigh, src_bit_size);
314          break;
315       }
316       case nir_type_uint: {
317          uint64_t src_uhigh = src_bit_size == 64 ?
318             ~0ull : (1ull << src_bit_size) - 1;
319          if (src_uhigh > fhigh)
320             *high = nir_imm_intN_t(b, fhigh, src_bit_size);
321          break;
322       }
323       case nir_type_float:
324          *low = nir_imm_floatN_t(b, flow, src_bit_size);
325          *high = nir_imm_floatN_t(b, fhigh, src_bit_size);
326          break;
327       default:
328          unreachable("Clamping from unknown type");
329       }
330       break;
331    }
332    default:
333       unreachable("clamping to unknown type");
334       break;
335    }
336 }
337 
338 /**
339  * Clamp the value into the widest representatble range of the
340  * destination type with cmp + bcsel.
341  *
342  * val/val_type: The variables used for bcsel
343  * src/src_type: The variables used for comparison
344  * dest_type: The type which determines the range used for comparison
345  */
346 static inline nir_ssa_def *
nir_clamp_to_type_range(nir_builder * b,nir_ssa_def * val,nir_alu_type val_type,nir_ssa_def * src,nir_alu_type src_type,nir_alu_type dest_type)347 nir_clamp_to_type_range(nir_builder *b,
348                         nir_ssa_def *val, nir_alu_type val_type,
349                         nir_ssa_def *src, nir_alu_type src_type,
350                         nir_alu_type dest_type)
351 {
352    assert(nir_alu_type_get_type_size(src_type) == 0 ||
353           nir_alu_type_get_type_size(src_type) == src->bit_size);
354    src_type |= src->bit_size;
355    if (nir_alu_type_range_contains_type_range(dest_type, src_type))
356       return val;
357 
358    /* limits of the destination type, expressed in the source type */
359    nir_ssa_def *low = NULL, *high = NULL;
360    nir_get_clamp_limits(b, src_type, dest_type, &low, &high);
361 
362    nir_ssa_def *low_cond = NULL, *high_cond = NULL;
363    switch (nir_alu_type_get_base_type(src_type)) {
364    case nir_type_int:
365       low_cond = low ? nir_ilt(b, src, low) : NULL;
366       high_cond = high ? nir_ilt(b, high, src) : NULL;
367       break;
368    case nir_type_uint:
369       low_cond = low ? nir_ult(b, src, low) : NULL;
370       high_cond = high ? nir_ult(b, high, src) : NULL;
371       break;
372    case nir_type_float:
373       low_cond = low ? nir_fge(b, low, src) : NULL;
374       high_cond = high ? nir_fge(b, src, high) : NULL;
375       break;
376    default:
377       unreachable("clamping from unknown type");
378    }
379 
380    nir_ssa_def *val_low = low, *val_high = high;
381    if (val_type != src_type) {
382       nir_get_clamp_limits(b, val_type, dest_type, &val_low, &val_high);
383    }
384 
385    nir_ssa_def *res = val;
386    if (low_cond && val_low)
387       res = nir_bcsel(b, low_cond, val_low, res);
388    if (high_cond && val_high)
389       res = nir_bcsel(b, high_cond, val_high, res);
390 
391    return res;
392 }
393 
394 static inline nir_rounding_mode
nir_simplify_conversion_rounding(nir_alu_type src_type,nir_alu_type dest_type,nir_rounding_mode rounding)395 nir_simplify_conversion_rounding(nir_alu_type src_type,
396                                  nir_alu_type dest_type,
397                                  nir_rounding_mode rounding)
398 {
399    nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
400    nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
401    unsigned src_bit_size = nir_alu_type_get_type_size(src_type);
402    unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
403    assert(src_bit_size > 0 && dest_bit_size > 0);
404 
405    if (rounding == nir_rounding_mode_undef)
406       return rounding;
407 
408    /* Pure integer conversion doesn't have any rounding */
409    if (src_base_type != nir_type_float &&
410        dest_base_type != nir_type_float)
411       return nir_rounding_mode_undef;
412 
413    /* Float down-casts don't round */
414    if (src_base_type == nir_type_float &&
415        dest_base_type == nir_type_float &&
416        dest_bit_size >= src_bit_size)
417       return nir_rounding_mode_undef;
418 
419    /* Regular float to int conversions are RTZ */
420    if (src_base_type == nir_type_float &&
421        dest_base_type != nir_type_float &&
422        rounding == nir_rounding_mode_rtz)
423       return nir_rounding_mode_undef;
424 
425    /* The CL spec requires regular conversions to float to be RTNE */
426    if (dest_base_type == nir_type_float &&
427        rounding == nir_rounding_mode_rtne)
428       return nir_rounding_mode_undef;
429 
430    /* Couldn't simplify */
431    return rounding;
432 }
433 
434 static inline nir_ssa_def *
nir_convert_with_rounding(nir_builder * b,nir_ssa_def * src,nir_alu_type src_type,nir_alu_type dest_type,nir_rounding_mode round,bool clamp)435 nir_convert_with_rounding(nir_builder *b,
436                           nir_ssa_def *src, nir_alu_type src_type,
437                           nir_alu_type dest_type,
438                           nir_rounding_mode round,
439                           bool clamp)
440 {
441    /* Some stuff wants sized types */
442    assert(nir_alu_type_get_type_size(src_type) == 0 ||
443           nir_alu_type_get_type_size(src_type) == src->bit_size);
444    src_type |= src->bit_size;
445 
446    /* Split types from bit sizes */
447    nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
448    nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
449    unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
450 
451    /* Try to simplify the conversion if we can */
452    clamp = clamp &&
453       !nir_alu_type_range_contains_type_range(dest_type, src_type);
454    round = nir_simplify_conversion_rounding(src_type, dest_type, round);
455 
456    /* For float -> int/uint conversions, we might not be able to represent
457     * the destination range in the source float accurately. For these cases,
458     * do the comparison in float range, but the bcsel in the destination range.
459     */
460    bool clamp_after_conversion = clamp &&
461       src_base_type == nir_type_float &&
462       dest_base_type != nir_type_float;
463 
464    /*
465     * If we don't care about rounding and clamping, we can just use NIR's
466     * built-in ops. There is also a special case for SPIR-V in shaders, where
467     * f32/f64 -> f16 conversions can have one of two rounding modes applied,
468     * which NIR has built-in opcodes for.
469     *
470     * For the rest, we have our own implementation of rounding and clamping.
471     */
472    bool trivial_convert;
473    if (!clamp && round == nir_rounding_mode_undef) {
474       trivial_convert = true;
475    } else if (!clamp && src_type == nir_type_float32 &&
476                         dest_type == nir_type_float16 &&
477                         (round == nir_rounding_mode_rtne ||
478                          round == nir_rounding_mode_rtz)) {
479       trivial_convert = true;
480    } else {
481       trivial_convert = false;
482    }
483    if (trivial_convert) {
484       nir_op op = nir_type_conversion_op(src_type, dest_type, round);
485       return nir_build_alu(b, op, src, NULL, NULL, NULL);
486    }
487 
488    nir_ssa_def *dest = src;
489 
490    /* clamp the result into range */
491    if (clamp && !clamp_after_conversion)
492       dest = nir_clamp_to_type_range(b, src, src_type, src, src_type, dest_type);
493 
494    /* round with selected rounding mode */
495    if (!trivial_convert && round != nir_rounding_mode_undef) {
496       if (src_base_type == nir_type_float) {
497          if (dest_base_type == nir_type_float) {
498             dest = nir_round_float_to_float(b, dest, dest_bit_size, round);
499          } else {
500             dest = nir_round_float_to_int(b, dest, round);
501          }
502       } else {
503          dest = nir_round_int_to_float(b, dest, src_type, dest_bit_size, round);
504       }
505 
506       round = nir_rounding_mode_undef;
507    }
508 
509    /* now we can convert the value */
510    nir_op op = nir_type_conversion_op(src_type, dest_type, round);
511    dest = nir_build_alu(b, op, dest, NULL, NULL, NULL);
512 
513    if (clamp_after_conversion)
514       dest = nir_clamp_to_type_range(b, dest, dest_type, src, src_type, dest_type);
515 
516    return dest;
517 }
518 
519 #ifdef __cplusplus
520 }
521 #endif
522 
523 #endif /* NIR_CONVERSION_BUILDER_H */
524