1 /*
2  * Copyright © 2015 Connor Abbott
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  */
24 
25 #include "nir.h"
26 #include "nir_vla.h"
27 #include "nir_builder.h"
28 #include "util/u_dynarray.h"
29 
/* Folds the raw bytes of the lvalue `data` into the running hash `hash` by
 * handing its address and size to XXH32, with `hash` as the third argument.
 * Arguments are parenthesized so that any lvalue expression expands safely.
 */
#define HASH(hash, data) XXH32(&(data), sizeof(data), (hash))
31 
32 static uint32_t
hash_src(uint32_t hash,const nir_src * src)33 hash_src(uint32_t hash, const nir_src *src)
34 {
35    assert(src->is_ssa);
36    void *hash_data = nir_src_is_const(*src) ? NULL : src->ssa;
37 
38    return HASH(hash, hash_data);
39 }
40 
41 static uint32_t
hash_alu_src(uint32_t hash,const nir_alu_src * src,uint32_t num_components,uint32_t max_vec)42 hash_alu_src(uint32_t hash, const nir_alu_src *src,
43              uint32_t num_components, uint32_t max_vec)
44 {
45    assert(!src->abs && !src->negate);
46 
47    /* hash whether a swizzle accesses elements beyond the maximum
48     * vectorization factor:
49     * For example accesses to .x and .y are considered different variables
50     * compared to accesses to .z and .w for 16-bit vec2.
51     */
52    uint32_t swizzle = (src->swizzle[0] & ~(max_vec - 1));
53    hash = HASH(hash, swizzle);
54 
55    return hash_src(hash, &src->src);
56 }
57 
58 static uint32_t
hash_instr(const void * data)59 hash_instr(const void *data)
60 {
61    const nir_instr *instr = (nir_instr *) data;
62    assert(instr->type == nir_instr_type_alu);
63    nir_alu_instr *alu = nir_instr_as_alu(instr);
64 
65    uint32_t hash = HASH(0, alu->op);
66    hash = HASH(hash, alu->dest.dest.ssa.bit_size);
67 
68    for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
69       hash = hash_alu_src(hash, &alu->src[i],
70                           alu->dest.dest.ssa.num_components,
71                           instr->pass_flags);
72 
73    return hash;
74 }
75 
76 static bool
srcs_equal(const nir_src * src1,const nir_src * src2)77 srcs_equal(const nir_src *src1, const nir_src *src2)
78 {
79    assert(src1->is_ssa);
80    assert(src2->is_ssa);
81 
82    return src1->ssa == src2->ssa ||
83           (nir_src_is_const(*src1) && nir_src_is_const(*src2));
84 }
85 
86 static bool
alu_srcs_equal(const nir_alu_src * src1,const nir_alu_src * src2,uint32_t max_vec)87 alu_srcs_equal(const nir_alu_src *src1, const nir_alu_src *src2,
88                uint32_t max_vec)
89 {
90    assert(!src1->abs);
91    assert(!src1->negate);
92    assert(!src2->abs);
93    assert(!src2->negate);
94 
95    uint32_t mask = ~(max_vec - 1);
96    if ((src1->swizzle[0] & mask) != (src2->swizzle[0] & mask))
97       return false;
98 
99    return srcs_equal(&src1->src, &src2->src);
100 }
101 
102 static bool
instrs_equal(const void * data1,const void * data2)103 instrs_equal(const void *data1, const void *data2)
104 {
105    const nir_instr *instr1 = (nir_instr *) data1;
106    const nir_instr *instr2 = (nir_instr *) data2;
107    assert(instr1->type == nir_instr_type_alu);
108    assert(instr2->type == nir_instr_type_alu);
109 
110    nir_alu_instr *alu1 = nir_instr_as_alu(instr1);
111    nir_alu_instr *alu2 = nir_instr_as_alu(instr2);
112 
113    if (alu1->op != alu2->op)
114       return false;
115 
116    if (alu1->dest.dest.ssa.bit_size != alu2->dest.dest.ssa.bit_size)
117       return false;
118 
119    for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) {
120       if (!alu_srcs_equal(&alu1->src[i], &alu2->src[i], instr1->pass_flags))
121          return false;
122    }
123 
124    return true;
125 }
126 
127 static bool
instr_can_rewrite(nir_instr * instr,bool vectorize_16bit)128 instr_can_rewrite(nir_instr *instr, bool vectorize_16bit)
129 {
130    switch (instr->type) {
131    case nir_instr_type_alu: {
132       nir_alu_instr *alu = nir_instr_as_alu(instr);
133 
134       /* Don't try and vectorize mov's. Either they'll be handled by copy
135        * prop, or they're actually necessary and trying to vectorize them
136        * would result in fighting with copy prop.
137        */
138       if (alu->op == nir_op_mov)
139          return false;
140 
141       /* no need to hash instructions which are already vectorized */
142       if (alu->dest.dest.ssa.num_components >= 4)
143          return false;
144 
145       if (vectorize_16bit &&
146           (alu->dest.dest.ssa.num_components >= 2 ||
147            alu->dest.dest.ssa.bit_size != 16))
148          return false;
149 
150       if (nir_op_infos[alu->op].output_size != 0)
151          return false;
152 
153       for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
154          if (nir_op_infos[alu->op].input_sizes[i] != 0)
155             return false;
156 
157          /* don't hash instructions which are already swizzled
158           * outside of max_components: these should better be scalarized */
159          uint32_t mask = vectorize_16bit ? ~1 : ~3;
160          for (unsigned j = 0; j < alu->dest.dest.ssa.num_components; j++) {
161             if ((alu->src[i].swizzle[0] & mask) != (alu->src[i].swizzle[j] & mask))
162                return false;
163          }
164       }
165 
166       return true;
167    }
168 
169    /* TODO support phi nodes */
170    default:
171       break;
172    }
173 
174    return false;
175 }
176 
177 /*
178  * Tries to combine two instructions whose sources are different components of
179  * the same instructions into one vectorized instruction. Note that instr1
180  * should dominate instr2.
181  */
182 
183 static nir_instr *
instr_try_combine(struct nir_shader * nir,struct set * instr_set,nir_instr * instr1,nir_instr * instr2)184 instr_try_combine(struct nir_shader *nir, struct set *instr_set,
185                   nir_instr *instr1, nir_instr *instr2)
186 {
187    assert(instr1->type == nir_instr_type_alu);
188    assert(instr2->type == nir_instr_type_alu);
189    nir_alu_instr *alu1 = nir_instr_as_alu(instr1);
190    nir_alu_instr *alu2 = nir_instr_as_alu(instr2);
191 
192    assert(alu1->dest.dest.ssa.bit_size == alu2->dest.dest.ssa.bit_size);
193    unsigned alu1_components = alu1->dest.dest.ssa.num_components;
194    unsigned alu2_components = alu2->dest.dest.ssa.num_components;
195    unsigned total_components = alu1_components + alu2_components;
196 
197    if (total_components > 4)
198       return NULL;
199 
200    if (nir->options->vectorize_vec2_16bit) {
201       assert(total_components == 2);
202       assert(alu1->dest.dest.ssa.bit_size == 16);
203    }
204 
205    nir_builder b;
206    nir_builder_init(&b, nir_cf_node_get_function(&instr1->block->cf_node));
207    b.cursor = nir_after_instr(instr1);
208 
209    nir_alu_instr *new_alu = nir_alu_instr_create(b.shader, alu1->op);
210    nir_ssa_dest_init(&new_alu->instr, &new_alu->dest.dest,
211                      total_components, alu1->dest.dest.ssa.bit_size, NULL);
212    new_alu->dest.write_mask = (1 << total_components) - 1;
213    new_alu->instr.pass_flags = alu1->instr.pass_flags;
214 
215    /* If either channel is exact, we have to preserve it even if it's
216     * not optimal for other channels.
217     */
218    new_alu->exact = alu1->exact || alu2->exact;
219 
220    /* If all channels don't wrap, we can say that the whole vector doesn't
221     * wrap.
222     */
223    new_alu->no_signed_wrap = alu1->no_signed_wrap && alu2->no_signed_wrap;
224    new_alu->no_unsigned_wrap = alu1->no_unsigned_wrap && alu2->no_unsigned_wrap;
225 
226    for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) {
227       /* handle constant merging case */
228       if (alu1->src[i].src.ssa != alu2->src[i].src.ssa) {
229          nir_const_value *c1 = nir_src_as_const_value(alu1->src[i].src);
230          nir_const_value *c2 = nir_src_as_const_value(alu2->src[i].src);
231          assert(c1 && c2);
232          nir_const_value value[NIR_MAX_VEC_COMPONENTS];
233          unsigned bit_size = alu1->src[i].src.ssa->bit_size;
234 
235          for (unsigned j = 0; j < total_components; j++) {
236             value[j].u64 = j < alu1_components ?
237                               c1[alu1->src[i].swizzle[j]].u64 :
238                               c2[alu2->src[i].swizzle[j - alu1_components]].u64;
239          }
240          nir_ssa_def *def = nir_build_imm(&b, total_components, bit_size, value);
241 
242          new_alu->src[i].src = nir_src_for_ssa(def);
243          for (unsigned j = 0; j < total_components; j++)
244             new_alu->src[i].swizzle[j] = j;
245          continue;
246       }
247 
248       new_alu->src[i].src = alu1->src[i].src;
249 
250       for (unsigned j = 0; j < alu1_components; j++)
251          new_alu->src[i].swizzle[j] = alu1->src[i].swizzle[j];
252 
253       for (unsigned j = 0; j < alu2_components; j++) {
254          new_alu->src[i].swizzle[j + alu1_components] =
255             alu2->src[i].swizzle[j];
256       }
257    }
258 
259    nir_builder_instr_insert(&b, &new_alu->instr);
260 
261    unsigned swiz[NIR_MAX_VEC_COMPONENTS];
262    for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
263       swiz[i] = i;
264    nir_ssa_def *new_alu1 = nir_swizzle(&b, &new_alu->dest.dest.ssa, swiz,
265                                        alu1_components);
266 
267    for (unsigned i = 0; i < alu2_components; i++)
268       swiz[i] += alu1_components;
269    nir_ssa_def *new_alu2 = nir_swizzle(&b, &new_alu->dest.dest.ssa, swiz,
270                                        alu2_components);
271 
272    nir_foreach_use_safe(src, &alu1->dest.dest.ssa) {
273       nir_instr *user_instr = src->parent_instr;
274       if (user_instr->type == nir_instr_type_alu) {
275          /* Check if user is found in the hashset */
276          struct set_entry *entry = _mesa_set_search(instr_set, user_instr);
277 
278          /* For ALU instructions, rewrite the source directly to avoid a
279           * round-trip through copy propagation.
280           */
281          nir_instr_rewrite_src(user_instr, src,
282                                nir_src_for_ssa(&new_alu->dest.dest.ssa));
283 
284          /* Rehash user if it was found in the hashset */
285          if (entry && entry->key == user_instr) {
286             _mesa_set_remove(instr_set, entry);
287             _mesa_set_add(instr_set, src->parent_instr);
288          }
289       } else {
290          nir_instr_rewrite_src(user_instr, src, nir_src_for_ssa(new_alu1));
291       }
292    }
293 
294    nir_foreach_if_use_safe(src, &alu1->dest.dest.ssa) {
295       nir_if_rewrite_condition(src->parent_if, nir_src_for_ssa(new_alu1));
296    }
297 
298    assert(nir_ssa_def_is_unused(&alu1->dest.dest.ssa));
299 
300    nir_foreach_use_safe(src, &alu2->dest.dest.ssa) {
301       if (src->parent_instr->type == nir_instr_type_alu) {
302          /* For ALU instructions, rewrite the source directly to avoid a
303           * round-trip through copy propagation.
304           */
305 
306          nir_alu_instr *use = nir_instr_as_alu(src->parent_instr);
307 
308          unsigned src_index = 5;
309          for (unsigned i = 0; i < nir_op_infos[use->op].num_inputs; i++) {
310             if (&use->src[i].src == src) {
311                src_index = i;
312                break;
313             }
314          }
315          assert(src_index != 5);
316 
317          nir_instr_rewrite_src(src->parent_instr, src,
318                                nir_src_for_ssa(&new_alu->dest.dest.ssa));
319 
320          for (unsigned i = 0;
321               i < nir_ssa_alu_instr_src_components(use, src_index); i++) {
322             use->src[src_index].swizzle[i] += alu1_components;
323          }
324       } else {
325          nir_instr_rewrite_src(src->parent_instr, src,
326                                nir_src_for_ssa(new_alu2));
327       }
328    }
329 
330    nir_foreach_if_use_safe(src, &alu2->dest.dest.ssa) {
331       nir_if_rewrite_condition(src->parent_if, nir_src_for_ssa(new_alu2));
332    }
333 
334    assert(nir_ssa_def_is_unused(&alu2->dest.dest.ssa));
335 
336    nir_instr_remove(instr1);
337    nir_instr_remove(instr2);
338 
339    return &new_alu->instr;
340 }
341 
342 static struct set *
vec_instr_set_create(void)343 vec_instr_set_create(void)
344 {
345    return _mesa_set_create(NULL, hash_instr, instrs_equal);
346 }
347 
/*
 * Destroys a set obtained from vec_instr_set_create(). NULL is passed as the
 * second argument of _mesa_set_destroy(); presumably that means no per-key
 * callback, so the instructions themselves are left alone (confirm against
 * the set API).
 */
static void
vec_instr_set_destroy(struct set *instr_set)
{
   _mesa_set_destroy(instr_set, NULL);
}
353 
354 static bool
vec_instr_set_add_or_rewrite(struct nir_shader * nir,struct set * instr_set,nir_instr * instr,nir_opt_vectorize_cb filter,void * data)355 vec_instr_set_add_or_rewrite(struct nir_shader *nir, struct set *instr_set,
356                              nir_instr *instr,
357                              nir_opt_vectorize_cb filter, void *data)
358 {
359    if (!instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit))
360       return false;
361 
362    if (filter && !filter(instr, data))
363       return false;
364 
365    /* set max vector to instr pass flags: this is used to hash swizzles */
366    instr->pass_flags = nir->options->vectorize_vec2_16bit ? 2 : 4;
367 
368    struct set_entry *entry = _mesa_set_search(instr_set, instr);
369    if (entry) {
370       nir_instr *old_instr = (nir_instr *) entry->key;
371       _mesa_set_remove(instr_set, entry);
372       nir_instr *new_instr = instr_try_combine(nir, instr_set,
373                                                old_instr, instr);
374       if (new_instr) {
375          if (instr_can_rewrite(new_instr, nir->options->vectorize_vec2_16bit) &&
376              (!filter || filter(new_instr, data)))
377             _mesa_set_add(instr_set, new_instr);
378          return true;
379       }
380    }
381 
382    _mesa_set_add(instr_set, instr);
383    return false;
384 }
385 
386 static bool
vectorize_block(struct nir_shader * nir,nir_block * block,struct set * instr_set,nir_opt_vectorize_cb filter,void * data)387 vectorize_block(struct nir_shader *nir, nir_block *block,
388                 struct set *instr_set,
389                 nir_opt_vectorize_cb filter, void *data)
390 {
391    bool progress = false;
392 
393    nir_foreach_instr_safe(instr, block) {
394       if (vec_instr_set_add_or_rewrite(nir, instr_set, instr, filter, data))
395          progress = true;
396    }
397 
398    for (unsigned i = 0; i < block->num_dom_children; i++) {
399       nir_block *child = block->dom_children[i];
400       progress |= vectorize_block(nir, child, instr_set, filter, data);
401    }
402 
403    nir_foreach_instr_reverse(instr, block) {
404       if (instr_can_rewrite(instr, nir->options->vectorize_vec2_16bit) &&
405           (!filter || filter(instr, data)))
406          _mesa_set_remove_key(instr_set, instr);
407    }
408 
409    return progress;
410 }
411 
412 static bool
nir_opt_vectorize_impl(struct nir_shader * nir,nir_function_impl * impl,nir_opt_vectorize_cb filter,void * data)413 nir_opt_vectorize_impl(struct nir_shader *nir, nir_function_impl *impl,
414                        nir_opt_vectorize_cb filter, void *data)
415 {
416    struct set *instr_set = vec_instr_set_create();
417 
418    nir_metadata_require(impl, nir_metadata_dominance);
419 
420    bool progress = vectorize_block(nir, nir_start_block(impl), instr_set,
421                                    filter, data);
422 
423    if (progress) {
424       nir_metadata_preserve(impl, nir_metadata_block_index |
425                                   nir_metadata_dominance);
426    } else {
427       nir_metadata_preserve(impl, nir_metadata_all);
428    }
429 
430    vec_instr_set_destroy(instr_set);
431    return progress;
432 }
433 
434 bool
nir_opt_vectorize(nir_shader * shader,nir_opt_vectorize_cb filter,void * data)435 nir_opt_vectorize(nir_shader *shader, nir_opt_vectorize_cb filter,
436                   void *data)
437 {
438    bool progress = false;
439 
440    nir_foreach_function(function, shader) {
441       if (function->impl)
442          progress |= nir_opt_vectorize_impl(shader, function->impl, filter, data);
443    }
444 
445    return progress;
446 }
447