/*
 * Copyright © 2015 Red Hat
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * Authors:
 *    Rob Clark <robclark@freedesktop.org>
 */

#include "nir.h"
#include "nir_builder.h"

/* This pass has two paths.
 * One (nir_lower_idiv_fast) lowers idiv/udiv/umod and is based on
 * NV50LegalizeSSA::handleDIV().
 *
 * Note that this path probably does not have enough precision for
 * compute shaders. Perhaps we want a second, higher-precision (looping)
 * version of this? Or perhaps we assume that if you can do compute shaders
 * you can also branch out to a pre-optimized shader library routine.
 *
 * The other path (nir_lower_idiv_precise) is based on code used by LLVM's
 * AMDGPU target. It should handle 32-bit idiv/irem/imod/udiv/umod exactly.
 */
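
/* For reference, per NIR's opcode definitions: idiv and irem follow C's
 * truncating division (irem takes the sign of the numerator), while imod
 * takes the sign of the denominator, e.g.:
 *
 *    idiv( 7, -2) = -3    irem( 7, -2) =  1    imod( 7, -2) = -1
 *    idiv(-7,  2) = -3    irem(-7,  2) = -1    imod(-7,  2) =  1
 */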

static nir_ssa_def *
convert_instr(nir_builder *bld, nir_op op,
      nir_ssa_def *numer, nir_ssa_def *denom)
{
   nir_ssa_def *af, *bf, *a, *b, *q, *r, *rt;
   bool is_signed;

   is_signed = (op == nir_op_idiv ||
                op == nir_op_imod ||
                op == nir_op_irem);

   if (is_signed) {
      af = nir_i2f32(bld, numer);
      bf = nir_i2f32(bld, denom);
      af = nir_fabs(bld, af);
      bf = nir_fabs(bld, bf);
      a  = nir_iabs(bld, numer);
      b  = nir_iabs(bld, denom);
   } else {
      af = nir_u2f32(bld, numer);
      bf = nir_u2f32(bld, denom);
      a  = numer;
      b  = denom;
   }

   /* get first result: */
   bf = nir_frcp(bld, bf);
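   /* Bias the reciprocal low: the isub below reinterprets the f32 bits as
    * an integer and subtracts 2, i.e. roughly 2 ULPs, so the first quotient
    * estimate errs low rather than overshooting. */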
   bf = nir_isub(bld, bf, nir_imm_int(bld, 2));  /* yes, really */
   q  = nir_fmul(bld, af, bf);

   if (is_signed) {
      q = nir_f2i32(bld, q);
   } else {
      q = nir_f2u32(bld, q);
   }

   /* get error of first result: */
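   /* (the remainder a - q*b, rescaled by the approximate reciprocal,
    * estimates how much quotient is still missing) */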
   r = nir_imul(bld, q, b);
   r = nir_isub(bld, a, r);
   r = nir_u2f32(bld, r);
   r = nir_fmul(bld, r, bf);
   r = nir_f2u32(bld, r);

   /* add quotients: */
   q = nir_iadd(bld, q, r);

   /* correction: if modulus >= divisor, add 1 */
   r = nir_imul(bld, q, b);
   r = nir_isub(bld, a, r);
   rt = nir_uge(bld, r, b);

   if (op == nir_op_umod) {
      q = nir_bcsel(bld, rt, nir_isub(bld, r, b), r);
   } else {
      r = nir_b2i32(bld, rt);

      q = nir_iadd(bld, q, r);
      if (is_signed) {
         /* fix the sign: */
         r = nir_ixor(bld, numer, denom);
         r = nir_ilt(bld, r, nir_imm_int(bld, 0));
         b = nir_ineg(bld, q);
         q = nir_bcsel(bld, r, b, q);

         if (op == nir_op_imod || op == nir_op_irem) {
            q = nir_imul(bld, q, denom);
            q = nir_isub(bld, numer, q);
            if (op == nir_op_imod) {
               q = nir_bcsel(bld, nir_ieq_imm(bld, q, 0),
                             nir_imm_int(bld, 0),
                             nir_bcsel(bld, r, nir_iadd(bld, q, denom), q));
            }
         }
      }
   }

   return q;
}

/* ported from LLVM's AMDGPUTargetLowering::LowerUDIVREM */
static nir_ssa_def *
emit_udiv(nir_builder *bld, nir_ssa_def *numer, nir_ssa_def *denom, bool modulo)
{
   nir_ssa_def *rcp = nir_frcp(bld, nir_u2f32(bld, denom));
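   /* 4294966784.0 is 0xFFFFFE00, the largest float strictly below 2^32;
    * the multiply below turns the f32 reciprocal into a 0.32 fixed-point
    * approximation of 2^32 / denom that errs on the low side. */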
   rcp = nir_f2u32(bld, nir_fmul_imm(bld, rcp, 4294966784.0));

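   /* One Newton-Raphson-style refinement in 0.32 fixed point: the low bits
    * of -rcp * denom are the error term 2^32 - rcp * denom, and adding
    * umul_high(rcp, error), roughly rcp * error / 2^32, nudges rcp toward
    * 2^32 / denom. */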
   nir_ssa_def *neg_rcp_times_denom =
      nir_imul(bld, rcp, nir_ineg(bld, denom));
   rcp = nir_iadd(bld, rcp, nir_umul_high(bld, rcp, neg_rcp_times_denom));

   /* Get an initial estimate for the quotient/remainder; since rcp
    * underestimates, the estimate may fall slightly short, which the two
    * refinement steps below correct. */
   nir_ssa_def *quotient = nir_umul_high(bld, numer, rcp);
   nir_ssa_def *num_s_remainder = nir_imul(bld, quotient, denom);
   nir_ssa_def *remainder = nir_isub(bld, numer, num_s_remainder);

   /* First refinement step */
   nir_ssa_def *remainder_ge_den = nir_uge(bld, remainder, denom);
   if (!modulo) {
      quotient = nir_bcsel(bld, remainder_ge_den,
                           nir_iadd_imm(bld, quotient, 1), quotient);
   }
   remainder = nir_bcsel(bld, remainder_ge_den,
                         nir_isub(bld, remainder, denom), remainder);

   /* Second refinement step */
   remainder_ge_den = nir_uge(bld, remainder, denom);
   if (modulo) {
      return nir_bcsel(bld, remainder_ge_den, nir_isub(bld, remainder, denom),
                       remainder);
   } else {
      return nir_bcsel(bld, remainder_ge_den, nir_iadd_imm(bld, quotient, 1),
                       quotient);
   }
}

/* ported from LLVM's AMDGPUTargetLowering::LowerSDIVREM */
static nir_ssa_def *
emit_idiv(nir_builder *bld, nir_ssa_def *numer, nir_ssa_def *denom, nir_op op)
{
   nir_ssa_def *lh_sign = nir_ilt(bld, numer, nir_imm_int(bld, 0));
   nir_ssa_def *rh_sign = nir_ilt(bld, denom, nir_imm_int(bld, 0));
   lh_sign = nir_bcsel(bld, lh_sign, nir_imm_int(bld, -1), nir_imm_int(bld, 0));
   rh_sign = nir_bcsel(bld, rh_sign, nir_imm_int(bld, -1), nir_imm_int(bld, 0));

   nir_ssa_def *lhs = nir_iadd(bld, numer, lh_sign);
   nir_ssa_def *rhs = nir_iadd(bld, denom, rh_sign);
   lhs = nir_ixor(bld, lhs, lh_sign);
   rhs = nir_ixor(bld, rhs, rh_sign);
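   /* With sign being -1 for negative inputs and 0 otherwise,
    * (x + sign) ^ sign is the usual branch-free two's-complement absolute
    * value: for sign == -1 it computes ~(x - 1) == -x. */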

   if (op == nir_op_idiv) {
      nir_ssa_def *d_sign = nir_ixor(bld, lh_sign, rh_sign);
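      /* The quotient is negative iff the operand signs differ; below,
       * (res ^ s) - s negates res when s == -1 and is the identity when
       * s == 0. */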
      nir_ssa_def *res = emit_udiv(bld, lhs, rhs, false);
      res = nir_ixor(bld, res, d_sign);
      return nir_isub(bld, res, d_sign);
   } else {
      nir_ssa_def *res = emit_udiv(bld, lhs, rhs, true);
      res = nir_ixor(bld, res, lh_sign);
      res = nir_isub(bld, res, lh_sign);
      if (op == nir_op_imod) {
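         /* res currently has irem semantics (sign of the numerator); imod
          * instead takes the sign of the denominator, so shift a nonzero
          * remainder by denom when the operand signs differ. */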
         nir_ssa_def *cond = nir_ieq_imm(bld, res, 0);
         cond = nir_ior(bld, nir_ieq(bld, lh_sign, rh_sign), cond);
         res = nir_bcsel(bld, cond, res, nir_iadd(bld, res, denom));
      }
      return res;
   }
}

static nir_ssa_def *
convert_instr_precise(nir_builder *bld, nir_op op,
      nir_ssa_def *numer, nir_ssa_def *denom)
{
   if (op == nir_op_udiv || op == nir_op_umod)
      return emit_udiv(bld, numer, denom, op == nir_op_umod);
   else
      return emit_idiv(bld, numer, denom, op);
}

static nir_ssa_def *
convert_instr_small(nir_builder *b, nir_op op,
      nir_ssa_def *numer, nir_ssa_def *denom,
      const nir_lower_idiv_options *options)
{
   unsigned sz = numer->bit_size;
   nir_alu_type int_type = nir_op_infos[op].output_type | sz;
   nir_alu_type float_type = nir_type_float | (options->allow_fp16 ? sz * 2 : 32);
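   /* A float of twice the source width can hold the integers exactly
    * (fp16's 11-bit significand covers 8-bit sources); without allow_fp16
    * we fall back to fp32 across the board. */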

   nir_ssa_def *p = nir_type_convert(b, numer, int_type, float_type);
   nir_ssa_def *q = nir_type_convert(b, denom, int_type, float_type);

   /* Take 1/q but offset mantissa by 1 to correct for rounding. This is
    * needed for correct results and has been checked exhaustively for
    * all pairs of 16-bit integers */
   nir_ssa_def *rcp = nir_iadd_imm(b, nir_frcp(b, q), 1);

   /* Divide by multiplying by adjusted reciprocal */
   nir_ssa_def *res = nir_fmul(b, p, rcp);

   /* Convert back to integer space with rounding inferred by type */
   res = nir_type_convert(b, res, float_type, int_type);

   /* Get remainder given the quotient */
   if (op == nir_op_umod || op == nir_op_imod || op == nir_op_irem)
      res = nir_isub(b, numer, nir_imul(b, denom, res));

   /* Adjust for sign, see constant folding definition */
   if (op == nir_op_imod) {
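      /* e.g. -7 imod 2: the truncated remainder is -1; the signs differ
       * and the remainder is nonzero, so adding denom gives 1, matching
       * the sign of the denominator. */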
      nir_ssa_def *zero = nir_imm_zero(b, 1, sz);
      nir_ssa_def *diff_sign =
               nir_ine(b, nir_ige(b, numer, zero), nir_ige(b, denom, zero));

      nir_ssa_def *adjust = nir_iand(b, diff_sign, nir_ine(b, res, zero));
      res = nir_iadd(b, res, nir_bcsel(b, adjust, denom, zero));
   }

   return res;
}

static nir_ssa_def *
lower_idiv(nir_builder *b, nir_instr *instr, void *_data)
{
   const nir_lower_idiv_options *options = _data;
   nir_alu_instr *alu = nir_instr_as_alu(instr);

   nir_ssa_def *numer = nir_ssa_for_alu_src(b, alu, 0);
   nir_ssa_def *denom = nir_ssa_for_alu_src(b, alu, 1);

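   /* Mark everything we emit as exact, presumably so later fast-math style
    * optimizations cannot perturb the carefully biased float arithmetic
    * these lowerings depend on. */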
   b->exact = true;

   if (numer->bit_size < 32)
      return convert_instr_small(b, alu->op, numer, denom, options);
   else if (options->imprecise_32bit_lowering)
      return convert_instr(b, alu->op, numer, denom);
   else
      return convert_instr_precise(b, alu->op, numer, denom);
}

static bool
inst_is_idiv(const nir_instr *instr, UNUSED const void *_state)
{
   if (instr->type != nir_instr_type_alu)
      return false;

   nir_alu_instr *alu = nir_instr_as_alu(instr);

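   /* Only divisions of at most 32 bits are lowered here; 64-bit ones are
    * expected to be handled elsewhere (e.g. by nir_lower_int64). */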
   if (alu->dest.dest.ssa.bit_size > 32)
      return false;

   switch (alu->op) {
   case nir_op_idiv:
   case nir_op_udiv:
   case nir_op_imod:
   case nir_op_umod:
   case nir_op_irem:
      return true;
   default:
      return false;
   }
}

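/* A minimal usage sketch (caller-side names hypothetical): a backend without
 * native integer division would typically run this as a late lowering pass:
 *
 *    const nir_lower_idiv_options idiv_options = {
 *       .imprecise_32bit_lowering = false,
 *       .allow_fp16 = true,
 *    };
 *    NIR_PASS(progress, shader, nir_lower_idiv, &idiv_options);
 */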
bool
nir_lower_idiv(nir_shader *shader, const nir_lower_idiv_options *options)
{
   return nir_shader_lower_instructions(shader,
         inst_is_idiv,
         lower_idiv,
         (void *)options);
}