1 /*
2  * Copyright © 2018 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "nir_deref.h"
27 #include "nir_vla.h"
28 
29 #include "util/set.h"
30 #include "util/u_math.h"
31 
32 static struct set *
33 get_complex_used_vars(nir_shader *shader, void *mem_ctx)
34 {
35    struct set *complex_vars = _mesa_pointer_set_create(mem_ctx);
36 
37    nir_foreach_function(function, shader) {
38       if (!function->impl)
39          continue;
40 
41       nir_foreach_block(block, function->impl) {
42          nir_foreach_instr(instr, block) {
43             if (instr->type != nir_instr_type_deref)
44                continue;
45 
46             nir_deref_instr *deref = nir_instr_as_deref(instr);
47 
48             /* We only need to consider var derefs because
49              * nir_deref_instr_has_complex_use is recursive.
50              */
51             if (deref->deref_type == nir_deref_type_var &&
52                 nir_deref_instr_has_complex_use(deref))
53                _mesa_set_add(complex_vars, deref->var);
54          }
55       }
56    }
57 
58    return complex_vars;
59 }
60 
61 struct split_var_state {
62    void *mem_ctx;
63 
64    nir_shader *shader;
65    nir_function_impl *impl;
66 
67    nir_variable *base_var;
68 };
69 
70 struct field {
71    struct field *parent;
72 
73    const struct glsl_type *type;
74 
75    unsigned num_fields;
76    struct field *fields;
77 
78    nir_variable *var;
79 };
80 
81 static const struct glsl_type *
82 wrap_type_in_array(const struct glsl_type *type,
83                    const struct glsl_type *array_type)
84 {
85    if (!glsl_type_is_array(array_type))
86       return type;
87 
88    const struct glsl_type *elem_type =
89       wrap_type_in_array(type, glsl_get_array_element(array_type));
90    assert(glsl_get_explicit_stride(array_type) == 0);
91    return glsl_array_type(elem_type, glsl_get_length(array_type), 0);
92 }
93 
94 static int
95 num_array_levels_in_array_of_vector_type(const struct glsl_type *type)
96 {
97    int num_levels = 0;
98    while (true) {
99       if (glsl_type_is_array_or_matrix(type)) {
100          num_levels++;
101          type = glsl_get_array_element(type);
102       } else if (glsl_type_is_vector_or_scalar(type)) {
103          return num_levels;
104       } else {
105          /* Not an array of vectors */
106          return -1;
107       }
108    }
109 }
110 
111 static void
112 init_field_for_type(struct field *field, struct field *parent,
113                     const struct glsl_type *type,
114                     const char *name,
115                     struct split_var_state *state)
116 {
117    *field = (struct field) {
118       .parent = parent,
119       .type = type,
120    };
121 
122    const struct glsl_type *struct_type = glsl_without_array(type);
123    if (glsl_type_is_struct_or_ifc(struct_type)) {
124       field->num_fields = glsl_get_length(struct_type);
125       field->fields = ralloc_array(state->mem_ctx, struct field,
126                                    field->num_fields);
127       for (unsigned i = 0; i < field->num_fields; i++) {
128          char *field_name = NULL;
129          if (name) {
130             field_name = ralloc_asprintf(state->mem_ctx, "%s_%s", name,
131                                          glsl_get_struct_elem_name(struct_type, i));
132          } else {
133             field_name = ralloc_asprintf(state->mem_ctx, "{unnamed %s}_%s",
134                                          glsl_get_type_name(struct_type),
135                                          glsl_get_struct_elem_name(struct_type, i));
136          }
137          init_field_for_type(&field->fields[i], field,
138                              glsl_get_struct_field(struct_type, i),
139                              field_name, state);
140       }
141    } else {
142       const struct glsl_type *var_type = type;
143       for (struct field *f = field->parent; f; f = f->parent)
144          var_type = wrap_type_in_array(var_type, f->type);
145 
146       nir_variable_mode mode = state->base_var->data.mode;
147       if (mode == nir_var_function_temp) {
148          field->var = nir_local_variable_create(state->impl, var_type, name);
149       } else {
150          field->var = nir_variable_create(state->shader, mode, var_type, name);
151       }
152    }
153 }
154 
155 static bool
156 split_var_list_structs(nir_shader *shader,
157                        nir_function_impl *impl,
158                        struct exec_list *vars,
159                        nir_variable_mode mode,
160                        struct hash_table *var_field_map,
161                        struct set **complex_vars,
162                        void *mem_ctx)
163 {
164    struct split_var_state state = {
165       .mem_ctx = mem_ctx,
166       .shader = shader,
167       .impl = impl,
168    };
169 
170    struct exec_list split_vars;
171    exec_list_make_empty(&split_vars);
172 
173    /* To avoid list confusion (we'll be adding things as we split variables),
174     * pull all of the variables we plan to split off of the list
175     */
176    nir_foreach_variable_in_list_safe(var, vars) {
177       if (var->data.mode != mode)
178          continue;
179 
180       if (!glsl_type_is_struct_or_ifc(glsl_without_array(var->type)))
181          continue;
182 
183       if (*complex_vars == NULL)
184          *complex_vars = get_complex_used_vars(shader, mem_ctx);
185 
186       /* We can't split a variable that's referenced by a deref that has any
187        * sort of complex usage.
188        */
189       if (_mesa_set_search(*complex_vars, var))
190          continue;
191 
192       exec_node_remove(&var->node);
193       exec_list_push_tail(&split_vars, &var->node);
194    }
195 
196    nir_foreach_variable_in_list(var, &split_vars) {
197       state.base_var = var;
198 
199       struct field *root_field = ralloc(mem_ctx, struct field);
200       init_field_for_type(root_field, NULL, var->type, var->name, &state);
201       _mesa_hash_table_insert(var_field_map, var, root_field);
202    }
203 
204    return !exec_list_is_empty(&split_vars);
205 }
206 
207 static void
208 split_struct_derefs_impl(nir_function_impl *impl,
209                          struct hash_table *var_field_map,
210                          nir_variable_mode modes,
211                          void *mem_ctx)
212 {
213    nir_builder b;
214    nir_builder_init(&b, impl);
215 
216    nir_foreach_block(block, impl) {
217       nir_foreach_instr_safe(instr, block) {
218          if (instr->type != nir_instr_type_deref)
219             continue;
220 
221          nir_deref_instr *deref = nir_instr_as_deref(instr);
222          if (!(deref->mode & modes))
223             continue;
224 
225          /* Clean up any dead derefs we find lying around.  They may refer to
226           * variables we're planning to split.
227           */
228          if (nir_deref_instr_remove_if_unused(deref))
229             continue;
230 
231          if (!glsl_type_is_vector_or_scalar(deref->type))
232             continue;
233 
234          nir_variable *base_var = nir_deref_instr_get_variable(deref);
235          /* If we can't chase back to the variable, then we're a complex use.
236           * This should have been detected by get_complex_used_vars() and the
237           * variable should not have been split.  However, we have no way of
238           * knowing that here, so we just have to trust it.
239           */
240          if (base_var == NULL)
241             continue;
242 
243          struct hash_entry *entry =
244             _mesa_hash_table_search(var_field_map, base_var);
245          if (!entry)
246             continue;
247 
248          struct field *root_field = entry->data;
249 
250          nir_deref_path path;
251          nir_deref_path_init(&path, deref, mem_ctx);
252 
253          struct field *tail_field = root_field;
254          for (unsigned i = 0; path.path[i]; i++) {
255             if (path.path[i]->deref_type != nir_deref_type_struct)
256                continue;
257 
258             assert(i > 0);
259             assert(glsl_type_is_struct_or_ifc(path.path[i - 1]->type));
260             assert(path.path[i - 1]->type ==
261                    glsl_without_array(tail_field->type));
262 
263             tail_field = &tail_field->fields[path.path[i]->strct.index];
264          }
265          nir_variable *split_var = tail_field->var;
266 
267          nir_deref_instr *new_deref = NULL;
268          for (unsigned i = 0; path.path[i]; i++) {
269             nir_deref_instr *p = path.path[i];
270             b.cursor = nir_after_instr(&p->instr);
271 
272             switch (p->deref_type) {
273             case nir_deref_type_var:
274                assert(new_deref == NULL);
275                new_deref = nir_build_deref_var(&b, split_var);
276                break;
277 
278             case nir_deref_type_array:
279             case nir_deref_type_array_wildcard:
280                new_deref = nir_build_deref_follower(&b, new_deref, p);
281                break;
282 
283             case nir_deref_type_struct:
284                /* Nothing to do; we're splitting structs */
285                break;
286 
287             default:
288                unreachable("Invalid deref type in path");
289             }
290          }
291 
292          assert(new_deref->type == deref->type);
293          nir_ssa_def_rewrite_uses(&deref->dest.ssa,
294                                   nir_src_for_ssa(&new_deref->dest.ssa));
295          nir_deref_instr_remove_if_unused(deref);
296       }
297    }
298 }
299 
300 /** A pass for splitting structs into multiple variables
301  *
302  * This pass splits arrays of structs into multiple variables, one for each
303  * (possibly nested) structure member.  After this pass completes, no
304  * variables of the given mode will contain a struct type.
305  */
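/* For illustration only (hypothetical GLSL-level source, not part of this
 * file): a variable such as
 *
 *    struct S { vec4 color; float weight; } data[4];
 *
 * is split into one variable per leaf member, with any enclosing array
 * levels re-wrapped around each leaf type:
 *
 *    vec4  data_color[4];
 *    float data_weight[4];
 *
 * and each struct deref chain on "data" is rewritten to point at the
 * matching split variable.
 */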
306 bool
307 nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
308 {
309    void *mem_ctx = ralloc_context(NULL);
310    struct hash_table *var_field_map =
311       _mesa_pointer_hash_table_create(mem_ctx);
312    struct set *complex_vars = NULL;
313 
314    assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);
315 
316    bool has_global_splits = false;
317    if (modes & nir_var_shader_temp) {
318       has_global_splits = split_var_list_structs(shader, NULL,
319                                                  &shader->variables,
320                                                  nir_var_shader_temp,
321                                                  var_field_map,
322                                                  &complex_vars,
323                                                  mem_ctx);
324    }
325 
326    bool progress = false;
327    nir_foreach_function(function, shader) {
328       if (!function->impl)
329          continue;
330 
331       bool has_local_splits = false;
332       if (modes & nir_var_function_temp) {
333          has_local_splits = split_var_list_structs(shader, function->impl,
334                                                    &function->impl->locals,
335                                                    nir_var_function_temp,
336                                                    var_field_map,
337                                                    &complex_vars,
338                                                    mem_ctx);
339       }
340 
341       if (has_global_splits || has_local_splits) {
342          split_struct_derefs_impl(function->impl, var_field_map,
343                                   modes, mem_ctx);
344 
345          nir_metadata_preserve(function->impl, nir_metadata_block_index |
346                                                nir_metadata_dominance);
347          progress = true;
348       } else {
349          nir_metadata_preserve(function->impl, nir_metadata_all);
350       }
351    }
352 
353    ralloc_free(mem_ctx);
354 
355    return progress;
356 }
357 
358 struct array_level_info {
359    unsigned array_len;
360    bool split;
361 };
362 
363 struct array_split {
364    /* Only set if this is the tail end of the splitting */
365    nir_variable *var;
366 
367    unsigned num_splits;
368    struct array_split *splits;
369 };
370 
371 struct array_var_info {
372    nir_variable *base_var;
373 
374    const struct glsl_type *split_var_type;
375 
376    bool split_var;
377    struct array_split root_split;
378 
379    unsigned num_levels;
380    struct array_level_info levels[0];
381 };
382 
383 static bool
384 init_var_list_array_infos(nir_shader *shader,
385                           struct exec_list *vars,
386                           nir_variable_mode mode,
387                           struct hash_table *var_info_map,
388                           struct set **complex_vars,
389                           void *mem_ctx)
390 {
391    bool has_array = false;
392 
393    nir_foreach_variable_in_list(var, vars) {
394       if (var->data.mode != mode)
395          continue;
396 
397       int num_levels = num_array_levels_in_array_of_vector_type(var->type);
398       if (num_levels <= 0)
399          continue;
400 
401       if (*complex_vars == NULL)
402          *complex_vars = get_complex_used_vars(shader, mem_ctx);
403 
404       /* We can't split a variable that's referenced by a deref that has any
405        * sort of complex usage.
406        */
407       if (_mesa_set_search(*complex_vars, var))
408          continue;
409 
410       struct array_var_info *info =
411          rzalloc_size(mem_ctx, sizeof(*info) +
412                                num_levels * sizeof(info->levels[0]));
413 
414       info->base_var = var;
415       info->num_levels = num_levels;
416 
417       const struct glsl_type *type = var->type;
418       for (int i = 0; i < num_levels; i++) {
419          info->levels[i].array_len = glsl_get_length(type);
420          type = glsl_get_array_element(type);
421 
422          /* All levels start out initially as split */
423          info->levels[i].split = true;
424       }
425 
426       _mesa_hash_table_insert(var_info_map, var, info);
427       has_array = true;
428    }
429 
430    return has_array;
431 }
432 
433 static struct array_var_info *
434 get_array_var_info(nir_variable *var,
435                    struct hash_table *var_info_map)
436 {
437    struct hash_entry *entry =
438       _mesa_hash_table_search(var_info_map, var);
439    return entry ? entry->data : NULL;
440 }
441 
442 static struct array_var_info *
443 get_array_deref_info(nir_deref_instr *deref,
444                      struct hash_table *var_info_map,
445                      nir_variable_mode modes)
446 {
447    if (!(deref->mode & modes))
448       return NULL;
449 
450    nir_variable *var = nir_deref_instr_get_variable(deref);
451    if (var == NULL)
452       return NULL;
453 
454    return get_array_var_info(var, var_info_map);
455 }
456 
457 static void
458 mark_array_deref_used(nir_deref_instr *deref,
459                       struct hash_table *var_info_map,
460                       nir_variable_mode modes,
461                       void *mem_ctx)
462 {
463    struct array_var_info *info =
464       get_array_deref_info(deref, var_info_map, modes);
465    if (!info)
466       return;
467 
468    nir_deref_path path;
469    nir_deref_path_init(&path, deref, mem_ctx);
470 
471    /* Walk the path and look for indirects.  If we have an array deref with an
472     * indirect, mark the given level as not being split.
473     */
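   /* For example (hypothetical deref chain): for var[i][2] with a
    * non-constant index i, level 0 gets marked as not split here; level 1
    * only sees the constant index 2, so this deref leaves it split.
    */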
474    for (unsigned i = 0; i < info->num_levels; i++) {
475       nir_deref_instr *p = path.path[i + 1];
476       if (p->deref_type == nir_deref_type_array &&
477           !nir_src_is_const(p->arr.index))
478          info->levels[i].split = false;
479    }
480 }
481 
482 static void
483 mark_array_usage_impl(nir_function_impl *impl,
484                       struct hash_table *var_info_map,
485                       nir_variable_mode modes,
486                       void *mem_ctx)
487 {
488    nir_foreach_block(block, impl) {
489       nir_foreach_instr(instr, block) {
490          if (instr->type != nir_instr_type_intrinsic)
491             continue;
492 
493          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
494          switch (intrin->intrinsic) {
495          case nir_intrinsic_copy_deref:
496             mark_array_deref_used(nir_src_as_deref(intrin->src[1]),
497                                   var_info_map, modes, mem_ctx);
498             /* Fall Through */
499 
500          case nir_intrinsic_load_deref:
501          case nir_intrinsic_store_deref:
502             mark_array_deref_used(nir_src_as_deref(intrin->src[0]),
503                                   var_info_map, modes, mem_ctx);
504             break;
505 
506          default:
507             break;
508          }
509       }
510    }
511 }
512 
513 static void
514 create_split_array_vars(struct array_var_info *var_info,
515                         unsigned level,
516                         struct array_split *split,
517                         const char *name,
518                         nir_shader *shader,
519                         nir_function_impl *impl,
520                         void *mem_ctx)
521 {
522    while (level < var_info->num_levels && !var_info->levels[level].split) {
523       name = ralloc_asprintf(mem_ctx, "%s[*]", name);
524       level++;
525    }
526 
527    if (level == var_info->num_levels) {
528       /* We add parens to the variable name so it looks like "(foo[2][*])" so
529        * that further derefs will look like "(foo[2][*])[ssa_6]"
530        */
531       name = ralloc_asprintf(mem_ctx, "(%s)", name);
532 
533       nir_variable_mode mode = var_info->base_var->data.mode;
534       if (mode == nir_var_function_temp) {
535          split->var = nir_local_variable_create(impl,
536                                                 var_info->split_var_type, name);
537       } else {
538          split->var = nir_variable_create(shader, mode,
539                                           var_info->split_var_type, name);
540       }
541    } else {
542       assert(var_info->levels[level].split);
543       split->num_splits = var_info->levels[level].array_len;
544       split->splits = rzalloc_array(mem_ctx, struct array_split,
545                                     split->num_splits);
546       for (unsigned i = 0; i < split->num_splits; i++) {
547          create_split_array_vars(var_info, level + 1, &split->splits[i],
548                                  ralloc_asprintf(mem_ctx, "%s[%d]", name, i),
549                                  shader, impl, mem_ctx);
550       }
551    }
552 }
553 
554 static bool
555 split_var_list_arrays(nir_shader *shader,
556                       nir_function_impl *impl,
557                       struct exec_list *vars,
558                       nir_variable_mode mode,
559                       struct hash_table *var_info_map,
560                       void *mem_ctx)
561 {
562    struct exec_list split_vars;
563    exec_list_make_empty(&split_vars);
564 
565    nir_foreach_variable_in_list_safe(var, vars) {
566       if (var->data.mode != mode)
567          continue;
568 
569       struct array_var_info *info = get_array_var_info(var, var_info_map);
570       if (!info)
571          continue;
572 
573       bool has_split = false;
574       const struct glsl_type *split_type =
575          glsl_without_array_or_matrix(var->type);
576       for (int i = info->num_levels - 1; i >= 0; i--) {
577          if (info->levels[i].split) {
578             has_split = true;
579             continue;
580          }
581 
582          /* If the original type was a matrix type, we'd like to keep that so
583           * we don't convert matrices into arrays.
584           */
585          if (i == info->num_levels - 1 &&
586              glsl_type_is_matrix(glsl_without_array(var->type))) {
587             split_type = glsl_matrix_type(glsl_get_base_type(split_type),
588                                           glsl_get_components(split_type),
589                                           info->levels[i].array_len);
590          } else {
591             split_type = glsl_array_type(split_type, info->levels[i].array_len, 0);
592          }
593       }
594 
595       if (has_split) {
596          info->split_var_type = split_type;
597          /* To avoid list confusion (we'll be adding things as we split
598           * variables), pull all of the variables we plan to split off of the
599           * main variable list.
600           */
601          exec_node_remove(&var->node);
602          exec_list_push_tail(&split_vars, &var->node);
603       } else {
604          assert(split_type == glsl_get_bare_type(var->type));
605          /* If we're not modifying this variable, delete the info so we skip
606           * it faster in later passes.
607           */
608          _mesa_hash_table_remove_key(var_info_map, var);
609       }
610    }
611 
612    nir_foreach_variable_in_list(var, &split_vars) {
613       struct array_var_info *info = get_array_var_info(var, var_info_map);
614       create_split_array_vars(info, 0, &info->root_split, var->name,
615                               shader, impl, mem_ctx);
616    }
617 
618    return !exec_list_is_empty(&split_vars);
619 }
620 
621 static bool
622 deref_has_split_wildcard(nir_deref_path *path,
623                          struct array_var_info *info)
624 {
625    if (info == NULL)
626       return false;
627 
628    assert(path->path[0]->var == info->base_var);
629    for (unsigned i = 0; i < info->num_levels; i++) {
630       if (path->path[i + 1]->deref_type == nir_deref_type_array_wildcard &&
631           info->levels[i].split)
632          return true;
633    }
634 
635    return false;
636 }
637 
638 static bool
639 array_path_is_out_of_bounds(nir_deref_path *path,
640                             struct array_var_info *info)
641 {
642    if (info == NULL)
643       return false;
644 
645    assert(path->path[0]->var == info->base_var);
646    for (unsigned i = 0; i < info->num_levels; i++) {
647       nir_deref_instr *p = path->path[i + 1];
648       if (p->deref_type == nir_deref_type_array_wildcard)
649          continue;
650 
651       if (nir_src_is_const(p->arr.index) &&
652           nir_src_as_uint(p->arr.index) >= info->levels[i].array_len)
653          return true;
654    }
655 
656    return false;
657 }
658 
659 static void
660 emit_split_copies(nir_builder *b,
661                   struct array_var_info *dst_info, nir_deref_path *dst_path,
662                   unsigned dst_level, nir_deref_instr *dst,
663                   struct array_var_info *src_info, nir_deref_path *src_path,
664                   unsigned src_level, nir_deref_instr *src)
665 {
666    nir_deref_instr *dst_p, *src_p;
667 
668    while ((dst_p = dst_path->path[dst_level + 1])) {
669       if (dst_p->deref_type == nir_deref_type_array_wildcard)
670          break;
671 
672       dst = nir_build_deref_follower(b, dst, dst_p);
673       dst_level++;
674    }
675 
676    while ((src_p = src_path->path[src_level + 1])) {
677       if (src_p->deref_type == nir_deref_type_array_wildcard)
678          break;
679 
680       src = nir_build_deref_follower(b, src, src_p);
681       src_level++;
682    }
683 
684    if (src_p == NULL || dst_p == NULL) {
685       assert(src_p == NULL && dst_p == NULL);
686       nir_copy_deref(b, dst, src);
687    } else {
688       assert(dst_p->deref_type == nir_deref_type_array_wildcard &&
689              src_p->deref_type == nir_deref_type_array_wildcard);
690 
691       if ((dst_info && dst_info->levels[dst_level].split) ||
692           (src_info && src_info->levels[src_level].split)) {
693          /* At least one of the source or the destination has no indirects at
694           * this level, so we have to lower the copy element-by-element here.
695           */
696          assert(glsl_get_length(dst_path->path[dst_level]->type) ==
697                 glsl_get_length(src_path->path[src_level]->type));
698          unsigned len = glsl_get_length(dst_path->path[dst_level]->type);
699          for (unsigned i = 0; i < len; i++) {
700             emit_split_copies(b, dst_info, dst_path, dst_level + 1,
701                               nir_build_deref_array_imm(b, dst, i),
702                               src_info, src_path, src_level + 1,
703                               nir_build_deref_array_imm(b, src, i));
704          }
705       } else {
706          /* Neither side is being split so we just keep going */
707          emit_split_copies(b, dst_info, dst_path, dst_level + 1,
708                            nir_build_deref_array_wildcard(b, dst),
709                            src_info, src_path, src_level + 1,
710                            nir_build_deref_array_wildcard(b, src));
711       }
712    }
713 }
714 
715 static void
716 split_array_copies_impl(nir_function_impl *impl,
717                         struct hash_table *var_info_map,
718                         nir_variable_mode modes,
719                         void *mem_ctx)
720 {
721    nir_builder b;
722    nir_builder_init(&b, impl);
723 
724    nir_foreach_block(block, impl) {
725       nir_foreach_instr_safe(instr, block) {
726          if (instr->type != nir_instr_type_intrinsic)
727             continue;
728 
729          nir_intrinsic_instr *copy = nir_instr_as_intrinsic(instr);
730          if (copy->intrinsic != nir_intrinsic_copy_deref)
731             continue;
732 
733          nir_deref_instr *dst_deref = nir_src_as_deref(copy->src[0]);
734          nir_deref_instr *src_deref = nir_src_as_deref(copy->src[1]);
735 
736          struct array_var_info *dst_info =
737             get_array_deref_info(dst_deref, var_info_map, modes);
738          struct array_var_info *src_info =
739             get_array_deref_info(src_deref, var_info_map, modes);
740 
741          if (!src_info && !dst_info)
742             continue;
743 
744          nir_deref_path dst_path, src_path;
745          nir_deref_path_init(&dst_path, dst_deref, mem_ctx);
746          nir_deref_path_init(&src_path, src_deref, mem_ctx);
747 
748          if (!deref_has_split_wildcard(&dst_path, dst_info) &&
749              !deref_has_split_wildcard(&src_path, src_info))
750             continue;
751 
752          b.cursor = nir_instr_remove(&copy->instr);
753 
754          emit_split_copies(&b, dst_info, &dst_path, 0, dst_path.path[0],
755                                src_info, &src_path, 0, src_path.path[0]);
756       }
757    }
758 }
759 
760 static void
761 split_array_access_impl(nir_function_impl *impl,
762                         struct hash_table *var_info_map,
763                         nir_variable_mode modes,
764                         void *mem_ctx)
765 {
766    nir_builder b;
767    nir_builder_init(&b, impl);
768 
769    nir_foreach_block(block, impl) {
770       nir_foreach_instr_safe(instr, block) {
771          if (instr->type == nir_instr_type_deref) {
772             /* Clean up any dead derefs we find lying around.  They may refer
773              * to variables we're planning to split.
774              */
775             nir_deref_instr *deref = nir_instr_as_deref(instr);
776             if (deref->mode & modes)
777                nir_deref_instr_remove_if_unused(deref);
778             continue;
779          }
780 
781          if (instr->type != nir_instr_type_intrinsic)
782             continue;
783 
784          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
785          if (intrin->intrinsic != nir_intrinsic_load_deref &&
786              intrin->intrinsic != nir_intrinsic_store_deref &&
787              intrin->intrinsic != nir_intrinsic_copy_deref)
788             continue;
789 
790          const unsigned num_derefs =
791             intrin->intrinsic == nir_intrinsic_copy_deref ? 2 : 1;
792 
793          for (unsigned d = 0; d < num_derefs; d++) {
794             nir_deref_instr *deref = nir_src_as_deref(intrin->src[d]);
795 
796             struct array_var_info *info =
797                get_array_deref_info(deref, var_info_map, modes);
798             if (!info)
799                continue;
800 
801             nir_deref_path path;
802             nir_deref_path_init(&path, deref, mem_ctx);
803 
804             b.cursor = nir_before_instr(&intrin->instr);
805 
806             if (array_path_is_out_of_bounds(&path, info)) {
807                /* If one of the derefs is out-of-bounds, we just delete the
808                 * instruction.  If a destination is out of bounds, then it may
809                 * have been in-bounds prior to shrinking so we don't want to
810                 * accidentally stomp something.  However, we've already proven
811                 * that it will never be read so it's safe to delete.  If a
812                 * source is out of bounds then it is loading random garbage.
813                 * For loads, we replace their uses with an undef instruction
814                 * and for copies we just delete the copy since it was writing
815                 * undefined garbage anyway and we may as well leave the random
816                 * garbage in the destination alone.
817                 */
818                if (intrin->intrinsic == nir_intrinsic_load_deref) {
819                   nir_ssa_def *u =
820                      nir_ssa_undef(&b, intrin->dest.ssa.num_components,
821                                        intrin->dest.ssa.bit_size);
822                   nir_ssa_def_rewrite_uses(&intrin->dest.ssa,
823                                            nir_src_for_ssa(u));
824                }
825                nir_instr_remove(&intrin->instr);
826                for (unsigned i = 0; i < num_derefs; i++)
827                   nir_deref_instr_remove_if_unused(nir_src_as_deref(intrin->src[i]));
828                break;
829             }
830 
831             struct array_split *split = &info->root_split;
832             for (unsigned i = 0; i < info->num_levels; i++) {
833                if (info->levels[i].split) {
834                   nir_deref_instr *p = path.path[i + 1];
835                   unsigned index = nir_src_as_uint(p->arr.index);
836                   assert(index < info->levels[i].array_len);
837                   split = &split->splits[index];
838                }
839             }
840             assert(!split->splits && split->var);
841 
842             nir_deref_instr *new_deref = nir_build_deref_var(&b, split->var);
843             for (unsigned i = 0; i < info->num_levels; i++) {
844                if (!info->levels[i].split) {
845                   new_deref = nir_build_deref_follower(&b, new_deref,
846                                                        path.path[i + 1]);
847                }
848             }
849             assert(new_deref->type == deref->type);
850 
851             /* Rewrite the deref source to point to the split one */
852             nir_instr_rewrite_src(&intrin->instr, &intrin->src[d],
853                                   nir_src_for_ssa(&new_deref->dest.ssa));
854             nir_deref_instr_remove_if_unused(deref);
855          }
856       }
857    }
858 }
859 
860 /** A pass for splitting arrays of vectors into multiple variables
861  *
862  * This pass looks at arrays (possibly multiple levels) of vectors (not
863  * structures or other types) and tries to split them into piles of variables,
864  * one for each array element.  The heuristic used is simple: If a given array
865  * level is never used with an indirect, that array level will get split.
866  *
867  * This pass could probably handle structures easily enough, but making a
868  * pass that could see through an array of structures of arrays would be
869  * difficult, so it's best to just run nir_split_struct_vars first.
870  */
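/* A sketch of the heuristic (hypothetical GLSL-level example, not from this
 * file): given
 *
 *    vec4 a[3][2];
 *
 * where the outer level is only ever indexed with constants and the inner
 * level is indexed through a loop counter, only the outer level is split.
 * That yields three variables of type vec4[2] (named roughly "(a[0])",
 * "(a[1])" and "(a[2])"), while the indirectly-indexed inner level is kept
 * as an array.
 */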
871 bool
872 nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
873 {
874    void *mem_ctx = ralloc_context(NULL);
875    struct hash_table *var_info_map = _mesa_pointer_hash_table_create(mem_ctx);
876    struct set *complex_vars = NULL;
877 
878    assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);
879 
880    bool has_global_array = false;
881    if (modes & nir_var_shader_temp) {
882       has_global_array = init_var_list_array_infos(shader,
883                                                    &shader->variables,
884                                                    nir_var_shader_temp,
885                                                    var_info_map,
886                                                    &complex_vars,
887                                                    mem_ctx);
888    }
889 
890    bool has_any_array = false;
891    nir_foreach_function(function, shader) {
892       if (!function->impl)
893          continue;
894 
895       bool has_local_array = false;
896       if (modes & nir_var_function_temp) {
897          has_local_array = init_var_list_array_infos(shader,
898                                                      &function->impl->locals,
899                                                      nir_var_function_temp,
900                                                      var_info_map,
901                                                      &complex_vars,
902                                                      mem_ctx);
903       }
904 
905       if (has_global_array || has_local_array) {
906          has_any_array = true;
907          mark_array_usage_impl(function->impl, var_info_map, modes, mem_ctx);
908       }
909    }
910 
911    /* If we failed to find any arrays of vectors, bail early. */
912    if (!has_any_array) {
913       ralloc_free(mem_ctx);
914       nir_shader_preserve_all_metadata(shader);
915       return false;
916    }
917 
918    bool has_global_splits = false;
919    if (modes & nir_var_shader_temp) {
920       has_global_splits = split_var_list_arrays(shader, NULL,
921                                                 &shader->variables,
922                                                 nir_var_shader_temp,
923                                                 var_info_map, mem_ctx);
924    }
925 
926    bool progress = false;
927    nir_foreach_function(function, shader) {
928       if (!function->impl)
929          continue;
930 
931       bool has_local_splits = false;
932       if (modes & nir_var_function_temp) {
933          has_local_splits = split_var_list_arrays(shader, function->impl,
934                                                   &function->impl->locals,
935                                                   nir_var_function_temp,
936                                                   var_info_map, mem_ctx);
937       }
938 
939       if (has_global_splits || has_local_splits) {
940          split_array_copies_impl(function->impl, var_info_map, modes, mem_ctx);
941          split_array_access_impl(function->impl, var_info_map, modes, mem_ctx);
942 
943          nir_metadata_preserve(function->impl, nir_metadata_block_index |
944                                                nir_metadata_dominance);
945          progress = true;
946       } else {
947          nir_metadata_preserve(function->impl, nir_metadata_all);
948       }
949    }
950 
951    ralloc_free(mem_ctx);
952 
953    return progress;
954 }
955 
956 struct array_level_usage {
957    unsigned array_len;
958 
959    /* The value UINT_MAX will be used to indicate an indirect */
960    unsigned max_read;
961    unsigned max_written;
962 
963    /* True if there is a copy that isn't to/from a shrinkable array */
964    bool has_external_copy;
965    struct set *levels_copied;
966 };
967 
968 struct vec_var_usage {
969    /* Convenience set of all components this variable has */
970    nir_component_mask_t all_comps;
971 
972    nir_component_mask_t comps_read;
973    nir_component_mask_t comps_written;
974 
975    nir_component_mask_t comps_kept;
976 
977    /* True if there is a copy that isn't to/from a shrinkable vector */
978    bool has_external_copy;
979    bool has_complex_use;
980    struct set *vars_copied;
981 
982    unsigned num_levels;
983    struct array_level_usage levels[0];
984 };
985 
986 static struct vec_var_usage *
987 get_vec_var_usage(nir_variable *var,
988                   struct hash_table *var_usage_map,
989                   bool add_usage_entry, void *mem_ctx)
990 {
991    struct hash_entry *entry = _mesa_hash_table_search(var_usage_map, var);
992    if (entry)
993       return entry->data;
994 
995    if (!add_usage_entry)
996       return NULL;
997 
998    /* Check to make sure that we are working with an array of vectors.  We
999     * don't bother to shrink single vectors because we figure that we can
1000     * clean it up better with SSA than by inserting piles of vecN instructions
1001     * to compact results.
1002     */
1003    int num_levels = num_array_levels_in_array_of_vector_type(var->type);
1004    if (num_levels < 1)
1005       return NULL; /* Not an array of vectors */
1006 
1007    struct vec_var_usage *usage =
1008       rzalloc_size(mem_ctx, sizeof(*usage) +
1009                             num_levels * sizeof(usage->levels[0]));
1010 
1011    usage->num_levels = num_levels;
1012    const struct glsl_type *type = var->type;
1013    for (unsigned i = 0; i < num_levels; i++) {
1014       usage->levels[i].array_len = glsl_get_length(type);
1015       type = glsl_get_array_element(type);
1016    }
1017    assert(glsl_type_is_vector_or_scalar(type));
1018 
1019    usage->all_comps = (1 << glsl_get_components(type)) - 1;
1020 
1021    _mesa_hash_table_insert(var_usage_map, var, usage);
1022 
1023    return usage;
1024 }
1025 
1026 static struct vec_var_usage *
1027 get_vec_deref_usage(nir_deref_instr *deref,
1028                     struct hash_table *var_usage_map,
1029                     nir_variable_mode modes,
1030                     bool add_usage_entry, void *mem_ctx)
1031 {
1032    if (!(deref->mode & modes))
1033       return NULL;
1034 
1035    return get_vec_var_usage(nir_deref_instr_get_variable(deref),
1036                             var_usage_map, add_usage_entry, mem_ctx);
1037 }
1038 
1039 static void
1040 mark_deref_if_complex(nir_deref_instr *deref,
1041                       struct hash_table *var_usage_map,
1042                       nir_variable_mode modes,
1043                       void *mem_ctx)
1044 {
1045    if (!(deref->mode & modes))
1046       return;
1047 
1048    /* Only bother with var derefs because nir_deref_instr_has_complex_use is
1049     * recursive.
1050     */
1051    if (deref->deref_type != nir_deref_type_var)
1052       return;
1053 
1054    if (!nir_deref_instr_has_complex_use(deref))
1055       return;
1056 
1057    struct vec_var_usage *usage =
1058       get_vec_var_usage(deref->var, var_usage_map, true, mem_ctx);
1059    if (!usage)
1060       return;
1061 
1062    usage->has_complex_use = true;
1063 }
1064 
1065 static void
1066 mark_deref_used(nir_deref_instr *deref,
1067                 nir_component_mask_t comps_read,
1068                 nir_component_mask_t comps_written,
1069                 nir_deref_instr *copy_deref,
1070                 struct hash_table *var_usage_map,
1071                 nir_variable_mode modes,
1072                 void *mem_ctx)
1073 {
1074    if (!(deref->mode & modes))
1075       return;
1076 
1077    nir_variable *var = nir_deref_instr_get_variable(deref);
1078    if (var == NULL)
1079       return;
1080 
1081    struct vec_var_usage *usage =
1082       get_vec_var_usage(var, var_usage_map, true, mem_ctx);
1083    if (!usage)
1084       return;
1085 
1086    usage->comps_read |= comps_read & usage->all_comps;
1087    usage->comps_written |= comps_written & usage->all_comps;
1088 
1089    struct vec_var_usage *copy_usage = NULL;
1090    if (copy_deref) {
1091       copy_usage = get_vec_deref_usage(copy_deref, var_usage_map, modes,
1092                                        true, mem_ctx);
1093       if (copy_usage) {
1094          if (usage->vars_copied == NULL) {
1095             usage->vars_copied = _mesa_pointer_set_create(mem_ctx);
1096          }
1097          _mesa_set_add(usage->vars_copied, copy_usage);
1098       } else {
1099          usage->has_external_copy = true;
1100       }
1101    }
1102 
1103    nir_deref_path path;
1104    nir_deref_path_init(&path, deref, mem_ctx);
1105 
1106    nir_deref_path copy_path;
1107    if (copy_usage)
1108       nir_deref_path_init(&copy_path, copy_deref, mem_ctx);
1109 
1110    unsigned copy_i = 0;
1111    for (unsigned i = 0; i < usage->num_levels; i++) {
1112       struct array_level_usage *level = &usage->levels[i];
1113       nir_deref_instr *deref = path.path[i + 1];
1114       assert(deref->deref_type == nir_deref_type_array ||
1115              deref->deref_type == nir_deref_type_array_wildcard);
1116 
1117       unsigned max_used;
1118       if (deref->deref_type == nir_deref_type_array) {
1119          max_used = nir_src_is_const(deref->arr.index) ?
1120                     nir_src_as_uint(deref->arr.index) : UINT_MAX;
1121       } else {
1122          /* For wildcards, we read or wrote the whole thing. */
1123          assert(deref->deref_type == nir_deref_type_array_wildcard);
1124          max_used = level->array_len - 1;
1125 
1126          if (copy_usage) {
1127             /* Match each wildcard level with the level on copy_usage */
1128             for (; copy_path.path[copy_i + 1]; copy_i++) {
1129                if (copy_path.path[copy_i + 1]->deref_type ==
1130                    nir_deref_type_array_wildcard)
1131                   break;
1132             }
1133             struct array_level_usage *copy_level =
1134                &copy_usage->levels[copy_i++];
1135 
1136             if (level->levels_copied == NULL) {
1137                level->levels_copied = _mesa_pointer_set_create(mem_ctx);
1138             }
1139             _mesa_set_add(level->levels_copied, copy_level);
1140          } else {
1141             /* We have a wildcard and it comes from a variable we aren't
1142              * tracking; flag it and we'll know to not shorten this array.
1143              */
1144             level->has_external_copy = true;
1145          }
1146       }
1147 
1148       if (comps_written)
1149          level->max_written = MAX2(level->max_written, max_used);
1150       if (comps_read)
1151          level->max_read = MAX2(level->max_read, max_used);
1152    }
1153 }
1154 
1155 static bool
1156 src_is_load_deref(nir_src src, nir_src deref_src)
1157 {
1158    nir_intrinsic_instr *load = nir_src_as_intrinsic(src);
1159    if (load == NULL || load->intrinsic != nir_intrinsic_load_deref)
1160       return false;
1161 
1162    assert(load->src[0].is_ssa);
1163 
1164    return load->src[0].ssa == deref_src.ssa;
1165 }
1166 
1167 /* Returns all non-self-referential components of a store instruction.  A
1168  * component is self-referential if it comes from the same component of a load
1169  * instruction on the same deref.  If the only data in a particular component
1170  * of a variable came directly from that component then it's undefined.  The
1171  * only way to get defined data into a component of a variable is for it to
1172  * get written there by something outside or from a different component.
1173  *
1174  * This is a fairly common pattern in shaders that come from either GLSL IR or
1175  * GLSLang because both glsl_to_nir and GLSLang implement write-masking with
1176  * load-vec-store.
1177  */
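/* As a sketch of the pattern (hypothetical GLSL source): a masked write like
 *
 *    color.xz = v;
 *
 * is commonly emitted as "load color; build vec4(v.x, color.y, v.y, color.w);
 * store color" with a full write mask.  Components y and w only copy the old
 * value of the same component back, so they are self-referential and are not
 * counted as written by this helper.
 */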
1178 static nir_component_mask_t
1179 get_non_self_referential_store_comps(nir_intrinsic_instr *store)
1180 {
1181    nir_component_mask_t comps = nir_intrinsic_write_mask(store);
1182 
1183    assert(store->src[1].is_ssa);
1184    nir_instr *src_instr = store->src[1].ssa->parent_instr;
1185    if (src_instr->type != nir_instr_type_alu)
1186       return comps;
1187 
1188    nir_alu_instr *src_alu = nir_instr_as_alu(src_instr);
1189 
1190    if (src_alu->op == nir_op_mov) {
1191       /* If it's just a swizzle of a load from the same deref, discount any
1192        * channels that don't move in the swizzle.
1193        */
1194       if (src_is_load_deref(src_alu->src[0].src, store->src[0])) {
1195          for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
1196             if (src_alu->src[0].swizzle[i] == i)
1197                comps &= ~(1u << i);
1198          }
1199       }
1200    } else if (nir_op_is_vec(src_alu->op)) {
1201       /* If it's a vec, discount any channels that are just loads from the
1202        * same deref put in the same spot.
1203        */
1204       for (unsigned i = 0; i < nir_op_infos[src_alu->op].num_inputs; i++) {
1205          if (src_is_load_deref(src_alu->src[i].src, store->src[0]) &&
1206              src_alu->src[i].swizzle[0] == i)
1207             comps &= ~(1u << i);
1208       }
1209    }
1210 
1211    return comps;
1212 }
1213 
1214 static void
1215 find_used_components_impl(nir_function_impl *impl,
1216                           struct hash_table *var_usage_map,
1217                           nir_variable_mode modes,
1218                           void *mem_ctx)
1219 {
1220    nir_foreach_block(block, impl) {
1221       nir_foreach_instr(instr, block) {
1222          if (instr->type == nir_instr_type_deref) {
1223             mark_deref_if_complex(nir_instr_as_deref(instr),
1224                                   var_usage_map, modes, mem_ctx);
1225          }
1226 
1227          if (instr->type != nir_instr_type_intrinsic)
1228             continue;
1229 
1230          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1231          switch (intrin->intrinsic) {
1232          case nir_intrinsic_load_deref:
1233             mark_deref_used(nir_src_as_deref(intrin->src[0]),
1234                             nir_ssa_def_components_read(&intrin->dest.ssa), 0,
1235                             NULL, var_usage_map, modes, mem_ctx);
1236             break;
1237 
1238          case nir_intrinsic_store_deref:
1239             mark_deref_used(nir_src_as_deref(intrin->src[0]),
1240                             0, get_non_self_referential_store_comps(intrin),
1241                             NULL, var_usage_map, modes, mem_ctx);
1242             break;
1243 
1244          case nir_intrinsic_copy_deref: {
1245             /* Just mark everything used for copies. */
1246             nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
1247             nir_deref_instr *src = nir_src_as_deref(intrin->src[1]);
1248             mark_deref_used(dst, 0, ~0, src, var_usage_map, modes, mem_ctx);
1249             mark_deref_used(src, ~0, 0, dst, var_usage_map, modes, mem_ctx);
1250             break;
1251          }
1252 
1253          default:
1254             break;
1255          }
1256       }
1257    }
1258 }
1259 
1260 static bool
1261 shrink_vec_var_list(struct exec_list *vars,
1262                     nir_variable_mode mode,
1263                     struct hash_table *var_usage_map)
1264 {
1265    /* Initialize the components kept field of each variable.  This is the
1266     * AND of the components written and components read.  If a component is
1267     * written but never read, it's dead.  If it is read but never written,
1268     * then all values read are undefined garbage and we may as well not read
1269     * them.
1270     *
1271     * The same logic applies to the array length.  We make the array length
1272     * the minimum length required between read and write and plan to
1273     * discard any OOB access.  The one exception here is indirect writes
1274     * because we don't know where they will land and we can't shrink an array
1275     * with indirect writes because previously in-bounds writes may become
1276     * out-of-bounds and have undefined behavior.
1277     *
1278     * Also, if we have a copy to/from something we can't shrink, we need
1279     * to leave components and array_len of any wildcards alone.
1280     */
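   /* For instance (hypothetical usage masks): a vec4 variable with
    * comps_written == xyzw but comps_read == xy keeps only xy, since the zw
    * writes are dead.  Likewise, an array whose highest constant index read
    * and written is 3 is shrunk to a length of 4.
    */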
1281    nir_foreach_variable_in_list(var, vars) {
1282       if (var->data.mode != mode)
1283          continue;
1284 
1285       struct vec_var_usage *usage =
1286          get_vec_var_usage(var, var_usage_map, false, NULL);
1287       if (!usage)
1288          continue;
1289 
1290       assert(usage->comps_kept == 0);
1291       if (usage->has_external_copy || usage->has_complex_use)
1292          usage->comps_kept = usage->all_comps;
1293       else
1294          usage->comps_kept = usage->comps_read & usage->comps_written;
1295 
1296       for (unsigned i = 0; i < usage->num_levels; i++) {
1297          struct array_level_usage *level = &usage->levels[i];
1298          assert(level->array_len > 0);
1299 
1300          if (level->max_written == UINT_MAX || level->has_external_copy ||
1301              usage->has_complex_use)
1302             continue; /* Can't shrink */
1303 
1304          unsigned max_used = MIN2(level->max_read, level->max_written);
1305          level->array_len = MIN2(max_used, level->array_len - 1) + 1;
1306       }
1307    }
1308 
1309    /* In order for variable copies to work, we have to have the same data type
1310     * on the source and the destination.  In order to satisfy this, we run a
1311     * little fixed-point algorithm to transitively ensure that we get enough
1312     * components and array elements for this to hold for all copies.
1313     */
1314    bool fp_progress;
1315    do {
1316       fp_progress = false;
1317       nir_foreach_variable_in_list(var, vars) {
1318          if (var->data.mode != mode)
1319             continue;
1320 
1321          struct vec_var_usage *var_usage =
1322             get_vec_var_usage(var, var_usage_map, false, NULL);
1323          if (!var_usage || !var_usage->vars_copied)
1324             continue;
1325 
1326          set_foreach(var_usage->vars_copied, copy_entry) {
1327             struct vec_var_usage *copy_usage = (void *)copy_entry->key;
1328             if (copy_usage->comps_kept != var_usage->comps_kept) {
1329                nir_component_mask_t comps_kept =
1330                   (var_usage->comps_kept | copy_usage->comps_kept);
1331                var_usage->comps_kept = comps_kept;
1332                copy_usage->comps_kept = comps_kept;
1333                fp_progress = true;
1334             }
1335          }
1336 
1337          for (unsigned i = 0; i < var_usage->num_levels; i++) {
1338             struct array_level_usage *var_level = &var_usage->levels[i];
1339             if (!var_level->levels_copied)
1340                continue;
1341 
1342             set_foreach(var_level->levels_copied, copy_entry) {
1343                struct array_level_usage *copy_level = (void *)copy_entry->key;
1344                if (var_level->array_len != copy_level->array_len) {
1345                   unsigned array_len =
1346                      MAX2(var_level->array_len, copy_level->array_len);
1347                   var_level->array_len = array_len;
1348                   copy_level->array_len = array_len;
1349                   fp_progress = true;
1350                }
1351             }
1352          }
1353       }
1354    } while (fp_progress);
1355 
   bool vars_shrunk = false;
   nir_foreach_variable_in_list_safe(var, vars) {
      if (var->data.mode != mode)
         continue;

      struct vec_var_usage *usage =
         get_vec_var_usage(var, var_usage_map, false, NULL);
      if (!usage)
         continue;

      bool shrunk = false;
      const struct glsl_type *vec_type = var->type;
      for (unsigned i = 0; i < usage->num_levels; i++) {
         /* If we've reduced the array to zero elements at some level, just
          * set comps_kept to 0 and delete the variable.
          */
         if (usage->levels[i].array_len == 0) {
            usage->comps_kept = 0;
            break;
         }

         assert(usage->levels[i].array_len <= glsl_get_length(vec_type));
         if (usage->levels[i].array_len < glsl_get_length(vec_type))
            shrunk = true;
         vec_type = glsl_get_array_element(vec_type);
      }
      assert(glsl_type_is_vector_or_scalar(vec_type));

      assert(usage->comps_kept == (usage->comps_kept & usage->all_comps));
      if (usage->comps_kept != usage->all_comps)
         shrunk = true;

      if (usage->comps_kept == 0) {
         /* This variable is dead, remove it */
         vars_shrunk = true;
         exec_node_remove(&var->node);
         continue;
      }

      if (!shrunk) {
         /* This variable doesn't need to be shrunk.  Remove it from the
          * hash table so later steps will ignore it.
          */
         _mesa_hash_table_remove_key(var_usage_map, var);
         continue;
      }

      /* Build the new var type */
      unsigned new_num_comps = util_bitcount(usage->comps_kept);
      const struct glsl_type *new_type =
         glsl_vector_type(glsl_get_base_type(vec_type), new_num_comps);
      for (int i = usage->num_levels - 1; i >= 0; i--) {
         assert(usage->levels[i].array_len > 0);
         /* If the original type was a matrix type, we'd like to keep that so
          * we don't convert matrices into arrays.
          */
         if (i == usage->num_levels - 1 &&
             glsl_type_is_matrix(glsl_without_array(var->type)) &&
             new_num_comps > 1 && usage->levels[i].array_len > 1) {
            new_type = glsl_matrix_type(glsl_get_base_type(new_type),
                                        new_num_comps,
                                        usage->levels[i].array_len);
         } else {
            new_type = glsl_array_type(new_type, usage->levels[i].array_len, 0);
         }
      }
      var->type = new_type;

      vars_shrunk = true;
   }

   return vars_shrunk;
}

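/* Returns true if the deref indexes any array level at or beyond the
 * (possibly shrunk) length recorded for that level, i.e. the access is out
 * of bounds for the new type.
 */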
static bool
vec_deref_is_oob(nir_deref_instr *deref,
                 struct vec_var_usage *usage)
{
   nir_deref_path path;
   nir_deref_path_init(&path, deref, NULL);

   bool oob = false;
   for (unsigned i = 0; i < usage->num_levels; i++) {
      nir_deref_instr *p = path.path[i + 1];
      if (p->deref_type == nir_deref_type_array_wildcard)
         continue;

      if (nir_src_is_const(p->arr.index) &&
          nir_src_as_uint(p->arr.index) >= usage->levels[i].array_len) {
         oob = true;
         break;
      }
   }

   nir_deref_path_finish(&path);

   return oob;
}

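/* Returns true if the deref points at a variable whose components were all
 * eliminated, or at an array element that no longer exists after shrinking.
 */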
static bool
vec_deref_is_dead_or_oob(nir_deref_instr *deref,
                         struct hash_table *var_usage_map,
                         nir_variable_mode modes)
{
   struct vec_var_usage *usage =
      get_vec_deref_usage(deref, var_usage_map, modes, false, NULL);
   if (!usage)
      return false;

   return usage->comps_kept == 0 || vec_deref_is_oob(deref, usage);
}

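/* Rewrite every deref, load, store, and copy in the impl to match the shrunk
 * variable types: accesses to dead or out-of-bounds storage are removed and
 * vector loads/stores are compacted down to the kept components.
 */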
static void
shrink_vec_var_access_impl(nir_function_impl *impl,
                           struct hash_table *var_usage_map,
                           nir_variable_mode modes)
{
   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         switch (instr->type) {
         case nir_instr_type_deref: {
            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (!(deref->mode & modes))
               break;

            /* Clean up any dead derefs we find lying around.  They may refer
             * to variables we've deleted.
             */
            if (nir_deref_instr_remove_if_unused(deref))
               break;

            /* Update the type in the deref to keep the types consistent as
             * you walk down the chain.  We don't need to check if this is one
             * of the derefs we're shrinking because this is a no-op if it
             * isn't.  The worst that could happen is that we accidentally fix
             * an invalid deref.
             */
            if (deref->deref_type == nir_deref_type_var) {
               deref->type = deref->var->type;
            } else if (deref->deref_type == nir_deref_type_array ||
                       deref->deref_type == nir_deref_type_array_wildcard) {
               nir_deref_instr *parent = nir_deref_instr_parent(deref);
               assert(glsl_type_is_array(parent->type) ||
                      glsl_type_is_matrix(parent->type));
               deref->type = glsl_get_array_element(parent->type);
            }
            break;
         }

         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

            /* If we have a copy whose source or destination has been deleted
             * because we determined the variable was dead, then we just
             * delete the copy instruction.  If the source variable was dead
             * then it was writing undefined garbage anyway and if it's the
             * destination variable that's dead then the write isn't needed.
             */
            if (intrin->intrinsic == nir_intrinsic_copy_deref) {
               nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
               nir_deref_instr *src = nir_src_as_deref(intrin->src[1]);
               if (vec_deref_is_dead_or_oob(dst, var_usage_map, modes) ||
                   vec_deref_is_dead_or_oob(src, var_usage_map, modes)) {
                  nir_instr_remove(&intrin->instr);
                  nir_deref_instr_remove_if_unused(dst);
                  nir_deref_instr_remove_if_unused(src);
               }
               continue;
            }

            if (intrin->intrinsic != nir_intrinsic_load_deref &&
                intrin->intrinsic != nir_intrinsic_store_deref)
               continue;

            nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
            if (!(deref->mode & modes))
               continue;

            struct vec_var_usage *usage =
               get_vec_deref_usage(deref, var_usage_map, modes, false, NULL);
            if (!usage)
               continue;

            if (usage->comps_kept == 0 || vec_deref_is_oob(deref, usage)) {
               if (intrin->intrinsic == nir_intrinsic_load_deref) {
                  nir_ssa_def *u =
                     nir_ssa_undef(&b, intrin->dest.ssa.num_components,
                                       intrin->dest.ssa.bit_size);
                  nir_ssa_def_rewrite_uses(&intrin->dest.ssa,
                                           nir_src_for_ssa(u));
               }
               nir_instr_remove(&intrin->instr);
               nir_deref_instr_remove_if_unused(deref);
               continue;
            }

            /* If we're not dropping any components, there's no need to
             * compact vectors.
             */
            if (usage->comps_kept == usage->all_comps)
               continue;

            if (intrin->intrinsic == nir_intrinsic_load_deref) {
               b.cursor = nir_after_instr(&intrin->instr);

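               /* Reconstruct a full-width vector in the original component
                * positions: kept components come from the shrunk load and
                * dropped components are filled with an undef value.
                */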
               nir_ssa_def *undef =
                  nir_ssa_undef(&b, 1, intrin->dest.ssa.bit_size);
               nir_ssa_def *vec_srcs[NIR_MAX_VEC_COMPONENTS];
               unsigned c = 0;
               for (unsigned i = 0; i < intrin->num_components; i++) {
                  if (usage->comps_kept & (1u << i))
                     vec_srcs[i] = nir_channel(&b, &intrin->dest.ssa, c++);
                  else
                     vec_srcs[i] = undef;
               }
               nir_ssa_def *vec = nir_vec(&b, vec_srcs, intrin->num_components);

               nir_ssa_def_rewrite_uses_after(&intrin->dest.ssa,
                                              nir_src_for_ssa(vec),
                                              vec->parent_instr);

               /* The SSA def is now only used by the swizzle.  It's safe to
                * shrink the number of components.
                */
               assert(list_length(&intrin->dest.ssa.uses) == c);
               intrin->num_components = c;
               intrin->dest.ssa.num_components = c;
            } else {
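               /* For stores, swizzle away the components that are not kept
                * from the value being written and compact the write mask to
                * match the shrunk destination.
                */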
               nir_component_mask_t write_mask =
                  nir_intrinsic_write_mask(intrin);

               unsigned swizzle[NIR_MAX_VEC_COMPONENTS];
               nir_component_mask_t new_write_mask = 0;
               unsigned c = 0;
               for (unsigned i = 0; i < intrin->num_components; i++) {
                  if (usage->comps_kept & (1u << i)) {
                     swizzle[c] = i;
                     if (write_mask & (1u << i))
                        new_write_mask |= 1u << c;
                     c++;
                  }
               }

               b.cursor = nir_before_instr(&intrin->instr);

               nir_ssa_def *swizzled =
                  nir_swizzle(&b, intrin->src[1].ssa, swizzle, c);

               /* Rewrite to use the compacted source */
               nir_instr_rewrite_src(&intrin->instr, &intrin->src[1],
                                     nir_src_for_ssa(swizzled));
               nir_intrinsic_set_write_mask(intrin, new_write_mask);
               intrin->num_components = c;
            }
            break;
         }

         default:
            break;
         }
      }
   }
}

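/* Returns true if the shader declares any variable with one of the given
 * modes (for non-function_temp modes) or if the impl has any locals (for
 * nir_var_function_temp).
 */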
static bool
function_impl_has_vars_with_modes(nir_function_impl *impl,
                                  nir_variable_mode modes)
{
   nir_shader *shader = impl->function->shader;

   if (modes & ~nir_var_function_temp) {
      nir_foreach_variable_with_modes(var, shader,
                                      modes & ~nir_var_function_temp)
         return true;
   }

   if ((modes & nir_var_function_temp) && !exec_list_is_empty(&impl->locals))
      return true;

   return false;
}

/** Attempt to shrink arrays of vectors
 *
 * This pass looks at variables which contain a vector or an array (possibly
 * multiple dimensions) of vectors and attempts to lower to a smaller vector
 * or array.  If the pass can prove that a component of a vector (or array of
 * vectors) is never really used, then that component will be removed.
 * Similarly, the pass attempts to shorten arrays based on what elements it
 * can prove are never read or never contain valid data.
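 *
 * For example, an array of vec4 in which only the first two components are
 * ever read can be rewritten as an array of vec2, and trailing array
 * elements that are never both read and written can be dropped.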
 */
bool
nir_shrink_vec_array_vars(nir_shader *shader, nir_variable_mode modes)
{
   assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);

   void *mem_ctx = ralloc_context(NULL);

   struct hash_table *var_usage_map =
      _mesa_pointer_hash_table_create(mem_ctx);

   bool has_vars_to_shrink = false;
   nir_foreach_function(function, shader) {
      if (!function->impl)
         continue;

      /* Don't even bother crawling the IR if we don't have any variables.
       * Given that this pass deletes any unused variables, it's likely that
       * we will be in this scenario eventually.
       */
      if (function_impl_has_vars_with_modes(function->impl, modes)) {
         has_vars_to_shrink = true;
         find_used_components_impl(function->impl, var_usage_map,
                                   modes, mem_ctx);
      }
   }
   if (!has_vars_to_shrink) {
      ralloc_free(mem_ctx);
      nir_shader_preserve_all_metadata(shader);
      return false;
   }

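   /* Shrink the global (shader_temp) variable list once up front; function
    * locals are shrunk per-impl in the loop below.
    */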
   bool globals_shrunk = false;
   if (modes & nir_var_shader_temp) {
      globals_shrunk = shrink_vec_var_list(&shader->variables,
                                           nir_var_shader_temp,
                                           var_usage_map);
   }

   bool progress = false;
   nir_foreach_function(function, shader) {
      if (!function->impl)
         continue;

      bool locals_shrunk = false;
      if (modes & nir_var_function_temp) {
         locals_shrunk = shrink_vec_var_list(&function->impl->locals,
                                             nir_var_function_temp,
                                             var_usage_map);
      }

      if (globals_shrunk || locals_shrunk) {
         shrink_vec_var_access_impl(function->impl, var_usage_map, modes);

         nir_metadata_preserve(function->impl, nir_metadata_block_index |
                                               nir_metadata_dominance);
         progress = true;
      } else {
         nir_metadata_preserve(function->impl, nir_metadata_all);
      }
   }

   ralloc_free(mem_ctx);

   return progress;
}