1 /*
2  * Copyright © 2018 Valve Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  */
24 
25 #include "aco_builder.h"
26 #include "aco_ir.h"
27 
28 #include "util/half_float.h"
29 #include "util/memstream.h"
30 
31 #include <algorithm>
32 #include <array>
33 #include <vector>
34 
35 namespace aco {
36 
37 #ifndef NDEBUG
38 void
perfwarn(Program * program,bool cond,const char * msg,Instruction * instr)39 perfwarn(Program* program, bool cond, const char* msg, Instruction* instr)
40 {
41    if (cond) {
42       char* out;
43       size_t outsize;
44       struct u_memstream mem;
45       u_memstream_open(&mem, &out, &outsize);
46       FILE* const memf = u_memstream_get(&mem);
47 
48       fprintf(memf, "%s: ", msg);
49       aco_print_instr(instr, memf);
50       u_memstream_close(&mem);
51 
52       aco_perfwarn(program, out);
53       free(out);
54 
55       if (debug_flags & DEBUG_PERFWARN)
56          exit(1);
57    }
58 }
59 #endif
60 
61 /**
62  * The optimizer works in 4 phases:
63  * (1) The first pass collects information for each ssa-def,
64  *     propagates reg->reg operands of the same type, inline constants
65  *     and neg/abs input modifiers.
66  * (2) The second pass combines instructions like mad, omod, clamp and
67  *     propagates sgpr's on VALU instructions.
68  *     This pass depends on information collected in the first pass.
69  * (3) The third pass goes backwards, and selects instructions,
70  *     i.e. decides if a mad instruction is profitable and eliminates dead code.
71  * (4) The fourth pass cleans up the sequence: literals get applied and dead
72  *     instructions are removed from the sequence.
73  */
74 
75 struct mad_info {
76    aco_ptr<Instruction> add_instr;
77    uint32_t mul_temp_id;
78    uint16_t literal_idx;
79    bool check_literal;
80 
mad_infoaco::mad_info81    mad_info(aco_ptr<Instruction> instr, uint32_t id)
82        : add_instr(std::move(instr)), mul_temp_id(id), literal_idx(0), check_literal(false)
83    {}
84 };
85 
/* Bit flags stored in ssa_info::label describing what is known about an SSA
 * value. Enumerators are grouped by which ssa_info union member (if any) holds
 * their payload; every value is an explicit single bit, so declaration order is
 * irrelevant. Bits 11 and 13 are unused.
 *
 * label_{abs,neg,mul,omod2,omod4,omod5,clamp} are used for both 16 and
 * 32-bit operations but this shouldn't cause any issues because we don't
 * look through any conversions. */
enum Label {
   /* payload: ssa_info::instr (the defining instruction) */
   label_vec = 1 << 0,
   label_mul = 1 << 4,
   label_mad = 1 << 7,
   label_add_sub = 1 << 17,
   label_bitwise = 1 << 18,
   label_minmax = 1 << 19,
   label_vopc = 1 << 20,
   label_uniform_bitwise = 1 << 23,
   label_usedef = 1 << 30,   /* generic label */
   label_vop3p = 1ull << 31, /* 1ull to prevent sign extension */
   label_extract = 1ull << 33,
   label_dpp = 1ull << 35,

   /* payload: ssa_info::instr (a modifier-like user instruction) */
   label_omod2 = 1 << 8,
   label_omod4 = 1 << 9,
   label_omod5 = 1 << 10,
   label_clamp = 1 << 12,
   label_insert = 1ull << 34,

   /* payload: ssa_info::temp */
   label_abs = 1 << 2,
   label_neg = 1 << 3,
   label_temp = 1 << 5,
   label_vcc = 1 << 15,
   label_b2f = 1 << 16,
   label_uniform_bool = 1 << 21,
   label_scc_invert = 1 << 24,
   label_b2i = 1 << 27,
   label_fcanonicalize = 1 << 28,

   /* payload: ssa_info::val */
   label_constant_32bit = 1 << 1,
   label_literal = 1 << 6,
   label_constant_64bit = 1 << 22,
   label_constant_16bit = 1 << 29,

   /* no payload */
   label_undefined = 1 << 14,
   label_vcc_hint = 1 << 25,
   label_scc_needed = 1 << 26,
   label_canonicalized = 1ull << 32,
};
125 
126 static constexpr uint64_t instr_usedef_labels =
127    label_vec | label_mul | label_mad | label_add_sub | label_vop3p | label_bitwise |
128    label_uniform_bitwise | label_minmax | label_vopc | label_usedef | label_extract | label_dpp;
129 static constexpr uint64_t instr_mod_labels =
130    label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert;
131 
132 static constexpr uint64_t instr_labels = instr_usedef_labels | instr_mod_labels;
133 static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f |
134                                         label_uniform_bool | label_scc_invert | label_b2i |
135                                         label_fcanonicalize;
136 static constexpr uint32_t val_labels =
137    label_constant_32bit | label_constant_64bit | label_constant_16bit | label_literal;
138 
139 static_assert((instr_labels & temp_labels) == 0, "labels cannot intersect");
140 static_assert((instr_labels & val_labels) == 0, "labels cannot intersect");
141 static_assert((temp_labels & val_labels) == 0, "labels cannot intersect");
142 
143 struct ssa_info {
144    uint64_t label;
145    union {
146       uint32_t val;
147       Temp temp;
148       Instruction* instr;
149    };
150 
ssa_infoaco::ssa_info151    ssa_info() : label(0) {}
152 
add_labelaco::ssa_info153    void add_label(Label new_label)
154    {
155       /* Since all the instr_usedef_labels use instr for the same thing
156        * (indicating the defining instruction), there is usually no need to
157        * clear any other instr labels. */
158       if (new_label & instr_usedef_labels)
159          label &= ~(instr_mod_labels | temp_labels | val_labels); /* instr, temp and val alias */
160 
161       if (new_label & instr_mod_labels) {
162          label &= ~instr_labels;
163          label &= ~(temp_labels | val_labels); /* instr, temp and val alias */
164       }
165 
166       if (new_label & temp_labels) {
167          label &= ~temp_labels;
168          label &= ~(instr_labels | val_labels); /* instr, temp and val alias */
169       }
170 
171       uint32_t const_labels =
172          label_literal | label_constant_32bit | label_constant_64bit | label_constant_16bit;
173       if (new_label & const_labels) {
174          label &= ~val_labels | const_labels;
175          label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
176       } else if (new_label & val_labels) {
177          label &= ~val_labels;
178          label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
179       }
180 
181       label |= new_label;
182    }
183 
set_vecaco::ssa_info184    void set_vec(Instruction* vec)
185    {
186       add_label(label_vec);
187       instr = vec;
188    }
189 
is_vecaco::ssa_info190    bool is_vec() { return label & label_vec; }
191 
set_constantaco::ssa_info192    void set_constant(chip_class chip, uint64_t constant)
193    {
194       Operand op16 = Operand::c16(constant);
195       Operand op32 = Operand::get_const(chip, constant, 4);
196       add_label(label_literal);
197       val = constant;
198 
199       /* check that no upper bits are lost in case of packed 16bit constants */
200       if (chip >= GFX8 && !op16.isLiteral() && op16.constantValue64() == constant)
201          add_label(label_constant_16bit);
202 
203       if (!op32.isLiteral())
204          add_label(label_constant_32bit);
205 
206       if (Operand::is_constant_representable(constant, 8))
207          add_label(label_constant_64bit);
208 
209       if (label & label_constant_64bit) {
210          val = Operand::c64(constant).constantValue();
211          if (val != constant)
212             label &= ~(label_literal | label_constant_16bit | label_constant_32bit);
213       }
214    }
215 
is_constantaco::ssa_info216    bool is_constant(unsigned bits)
217    {
218       switch (bits) {
219       case 8: return label & label_literal;
220       case 16: return label & label_constant_16bit;
221       case 32: return label & label_constant_32bit;
222       case 64: return label & label_constant_64bit;
223       }
224       return false;
225    }
226 
is_literalaco::ssa_info227    bool is_literal(unsigned bits)
228    {
229       bool is_lit = label & label_literal;
230       switch (bits) {
231       case 8: return false;
232       case 16: return is_lit && ~(label & label_constant_16bit);
233       case 32: return is_lit && ~(label & label_constant_32bit);
234       case 64: return false;
235       }
236       return false;
237    }
238 
is_constant_or_literalaco::ssa_info239    bool is_constant_or_literal(unsigned bits)
240    {
241       if (bits == 64)
242          return label & label_constant_64bit;
243       else
244          return label & label_literal;
245    }
246 
set_absaco::ssa_info247    void set_abs(Temp abs_temp)
248    {
249       add_label(label_abs);
250       temp = abs_temp;
251    }
252 
is_absaco::ssa_info253    bool is_abs() { return label & label_abs; }
254 
set_negaco::ssa_info255    void set_neg(Temp neg_temp)
256    {
257       add_label(label_neg);
258       temp = neg_temp;
259    }
260 
is_negaco::ssa_info261    bool is_neg() { return label & label_neg; }
262 
set_neg_absaco::ssa_info263    void set_neg_abs(Temp neg_abs_temp)
264    {
265       add_label((Label)((uint32_t)label_abs | (uint32_t)label_neg));
266       temp = neg_abs_temp;
267    }
268 
set_mulaco::ssa_info269    void set_mul(Instruction* mul)
270    {
271       add_label(label_mul);
272       instr = mul;
273    }
274 
is_mulaco::ssa_info275    bool is_mul() { return label & label_mul; }
276 
set_tempaco::ssa_info277    void set_temp(Temp tmp)
278    {
279       add_label(label_temp);
280       temp = tmp;
281    }
282 
is_tempaco::ssa_info283    bool is_temp() { return label & label_temp; }
284 
set_madaco::ssa_info285    void set_mad(Instruction* mad, uint32_t mad_info_idx)
286    {
287       add_label(label_mad);
288       mad->pass_flags = mad_info_idx;
289       instr = mad;
290    }
291 
is_madaco::ssa_info292    bool is_mad() { return label & label_mad; }
293 
set_omod2aco::ssa_info294    void set_omod2(Instruction* mul)
295    {
296       add_label(label_omod2);
297       instr = mul;
298    }
299 
is_omod2aco::ssa_info300    bool is_omod2() { return label & label_omod2; }
301 
set_omod4aco::ssa_info302    void set_omod4(Instruction* mul)
303    {
304       add_label(label_omod4);
305       instr = mul;
306    }
307 
is_omod4aco::ssa_info308    bool is_omod4() { return label & label_omod4; }
309 
set_omod5aco::ssa_info310    void set_omod5(Instruction* mul)
311    {
312       add_label(label_omod5);
313       instr = mul;
314    }
315 
is_omod5aco::ssa_info316    bool is_omod5() { return label & label_omod5; }
317 
set_clampaco::ssa_info318    void set_clamp(Instruction* med3)
319    {
320       add_label(label_clamp);
321       instr = med3;
322    }
323 
is_clampaco::ssa_info324    bool is_clamp() { return label & label_clamp; }
325 
set_undefinedaco::ssa_info326    void set_undefined() { add_label(label_undefined); }
327 
is_undefinedaco::ssa_info328    bool is_undefined() { return label & label_undefined; }
329 
set_vccaco::ssa_info330    void set_vcc(Temp vcc_val)
331    {
332       add_label(label_vcc);
333       temp = vcc_val;
334    }
335 
is_vccaco::ssa_info336    bool is_vcc() { return label & label_vcc; }
337 
set_b2faco::ssa_info338    void set_b2f(Temp b2f_val)
339    {
340       add_label(label_b2f);
341       temp = b2f_val;
342    }
343 
is_b2faco::ssa_info344    bool is_b2f() { return label & label_b2f; }
345 
set_add_subaco::ssa_info346    void set_add_sub(Instruction* add_sub_instr)
347    {
348       add_label(label_add_sub);
349       instr = add_sub_instr;
350    }
351 
is_add_subaco::ssa_info352    bool is_add_sub() { return label & label_add_sub; }
353 
set_bitwiseaco::ssa_info354    void set_bitwise(Instruction* bitwise_instr)
355    {
356       add_label(label_bitwise);
357       instr = bitwise_instr;
358    }
359 
is_bitwiseaco::ssa_info360    bool is_bitwise() { return label & label_bitwise; }
361 
set_uniform_bitwiseaco::ssa_info362    void set_uniform_bitwise() { add_label(label_uniform_bitwise); }
363 
is_uniform_bitwiseaco::ssa_info364    bool is_uniform_bitwise() { return label & label_uniform_bitwise; }
365 
set_minmaxaco::ssa_info366    void set_minmax(Instruction* minmax_instr)
367    {
368       add_label(label_minmax);
369       instr = minmax_instr;
370    }
371 
is_minmaxaco::ssa_info372    bool is_minmax() { return label & label_minmax; }
373 
set_vopcaco::ssa_info374    void set_vopc(Instruction* vopc_instr)
375    {
376       add_label(label_vopc);
377       instr = vopc_instr;
378    }
379 
is_vopcaco::ssa_info380    bool is_vopc() { return label & label_vopc; }
381 
set_scc_neededaco::ssa_info382    void set_scc_needed() { add_label(label_scc_needed); }
383 
is_scc_neededaco::ssa_info384    bool is_scc_needed() { return label & label_scc_needed; }
385 
set_scc_invertaco::ssa_info386    void set_scc_invert(Temp scc_inv)
387    {
388       add_label(label_scc_invert);
389       temp = scc_inv;
390    }
391 
is_scc_invertaco::ssa_info392    bool is_scc_invert() { return label & label_scc_invert; }
393 
set_uniform_boolaco::ssa_info394    void set_uniform_bool(Temp uniform_bool)
395    {
396       add_label(label_uniform_bool);
397       temp = uniform_bool;
398    }
399 
is_uniform_boolaco::ssa_info400    bool is_uniform_bool() { return label & label_uniform_bool; }
401 
set_vcc_hintaco::ssa_info402    void set_vcc_hint() { add_label(label_vcc_hint); }
403 
is_vcc_hintaco::ssa_info404    bool is_vcc_hint() { return label & label_vcc_hint; }
405 
set_b2iaco::ssa_info406    void set_b2i(Temp b2i_val)
407    {
408       add_label(label_b2i);
409       temp = b2i_val;
410    }
411 
is_b2iaco::ssa_info412    bool is_b2i() { return label & label_b2i; }
413 
set_usedefaco::ssa_info414    void set_usedef(Instruction* label_instr)
415    {
416       add_label(label_usedef);
417       instr = label_instr;
418    }
419 
is_usedefaco::ssa_info420    bool is_usedef() { return label & label_usedef; }
421 
set_vop3paco::ssa_info422    void set_vop3p(Instruction* vop3p_instr)
423    {
424       add_label(label_vop3p);
425       instr = vop3p_instr;
426    }
427 
is_vop3paco::ssa_info428    bool is_vop3p() { return label & label_vop3p; }
429 
set_fcanonicalizeaco::ssa_info430    void set_fcanonicalize(Temp tmp)
431    {
432       add_label(label_fcanonicalize);
433       temp = tmp;
434    }
435 
is_fcanonicalizeaco::ssa_info436    bool is_fcanonicalize() { return label & label_fcanonicalize; }
437 
set_canonicalizedaco::ssa_info438    void set_canonicalized() { add_label(label_canonicalized); }
439 
is_canonicalizedaco::ssa_info440    bool is_canonicalized() { return label & label_canonicalized; }
441 
set_extractaco::ssa_info442    void set_extract(Instruction* extract)
443    {
444       add_label(label_extract);
445       instr = extract;
446    }
447 
is_extractaco::ssa_info448    bool is_extract() { return label & label_extract; }
449 
set_insertaco::ssa_info450    void set_insert(Instruction* insert)
451    {
452       add_label(label_insert);
453       instr = insert;
454    }
455 
is_insertaco::ssa_info456    bool is_insert() { return label & label_insert; }
457 
set_dppaco::ssa_info458    void set_dpp(Instruction* mov)
459    {
460       add_label(label_dpp);
461       instr = mov;
462    }
463 
is_dppaco::ssa_info464    bool is_dpp() { return label & label_dpp; }
465 };
466 
467 struct opt_ctx {
468    Program* program;
469    float_mode fp_mode;
470    std::vector<aco_ptr<Instruction>> instructions;
471    ssa_info* info;
472    std::pair<uint32_t, Temp> last_literal;
473    std::vector<mad_info> mad_infos;
474    std::vector<uint16_t> uses;
475 };
476 
477 bool
can_use_VOP3(opt_ctx & ctx,const aco_ptr<Instruction> & instr)478 can_use_VOP3(opt_ctx& ctx, const aco_ptr<Instruction>& instr)
479 {
480    if (instr->isVOP3())
481       return true;
482 
483    if (instr->isVOP3P())
484       return false;
485 
486    if (instr->operands.size() && instr->operands[0].isLiteral() && ctx.program->chip_class < GFX10)
487       return false;
488 
489    if (instr->isDPP() || instr->isSDWA())
490       return false;
491 
492    return instr->opcode != aco_opcode::v_madmk_f32 && instr->opcode != aco_opcode::v_madak_f32 &&
493           instr->opcode != aco_opcode::v_madmk_f16 && instr->opcode != aco_opcode::v_madak_f16 &&
494           instr->opcode != aco_opcode::v_fmamk_f32 && instr->opcode != aco_opcode::v_fmaak_f32 &&
495           instr->opcode != aco_opcode::v_fmamk_f16 && instr->opcode != aco_opcode::v_fmaak_f16 &&
496           instr->opcode != aco_opcode::v_readlane_b32 &&
497           instr->opcode != aco_opcode::v_writelane_b32 &&
498           instr->opcode != aco_opcode::v_readfirstlane_b32;
499 }
500 
501 bool
pseudo_propagate_temp(opt_ctx & ctx,aco_ptr<Instruction> & instr,Temp temp,unsigned index)502 pseudo_propagate_temp(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp temp, unsigned index)
503 {
504    if (instr->definitions.empty())
505       return false;
506 
507    const bool vgpr =
508       instr->opcode == aco_opcode::p_as_uniform ||
509       std::all_of(instr->definitions.begin(), instr->definitions.end(),
510                   [](const Definition& def) { return def.regClass().type() == RegType::vgpr; });
511 
512    /* don't propagate VGPRs into SGPR instructions */
513    if (temp.type() == RegType::vgpr && !vgpr)
514       return false;
515 
516    bool can_accept_sgpr =
517       ctx.program->chip_class >= GFX9 ||
518       std::none_of(instr->definitions.begin(), instr->definitions.end(),
519                    [](const Definition& def) { return def.regClass().is_subdword(); });
520 
521    switch (instr->opcode) {
522    case aco_opcode::p_phi:
523    case aco_opcode::p_linear_phi:
524    case aco_opcode::p_parallelcopy:
525    case aco_opcode::p_create_vector:
526       if (temp.bytes() != instr->operands[index].bytes())
527          return false;
528       break;
529    case aco_opcode::p_extract_vector:
530       if (temp.type() == RegType::sgpr && !can_accept_sgpr)
531          return false;
532       break;
533    case aco_opcode::p_split_vector: {
534       if (temp.type() == RegType::sgpr && !can_accept_sgpr)
535          return false;
536       /* don't increase the vector size */
537       if (temp.bytes() > instr->operands[index].bytes())
538          return false;
539       /* We can decrease the vector size as smaller temporaries are only
540        * propagated by p_as_uniform instructions.
541        * If this propagation leads to invalid IR or hits the assertion below,
542        * it means that some undefined bytes within a dword are begin accessed
543        * and a bug in instruction_selection is likely. */
544       int decrease = instr->operands[index].bytes() - temp.bytes();
545       while (decrease > 0) {
546          decrease -= instr->definitions.back().bytes();
547          instr->definitions.pop_back();
548       }
549       assert(decrease == 0);
550       break;
551    }
552    case aco_opcode::p_as_uniform:
553       if (temp.regClass() == instr->definitions[0].regClass())
554          instr->opcode = aco_opcode::p_parallelcopy;
555       break;
556    default: return false;
557    }
558 
559    instr->operands[index].setTemp(temp);
560    return true;
561 }
562 
563 /* This expects the DPP modifier to be removed. */
564 bool
can_apply_sgprs(opt_ctx & ctx,aco_ptr<Instruction> & instr)565 can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
566 {
567    if (instr->isSDWA() && ctx.program->chip_class < GFX9)
568       return false;
569    return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
570           instr->opcode != aco_opcode::v_readlane_b32 &&
571           instr->opcode != aco_opcode::v_readlane_b32_e64 &&
572           instr->opcode != aco_opcode::v_writelane_b32 &&
573           instr->opcode != aco_opcode::v_writelane_b32_e64 &&
574           instr->opcode != aco_opcode::v_permlane16_b32 &&
575           instr->opcode != aco_opcode::v_permlanex16_b32;
576 }
577 
578 void
to_VOP3(opt_ctx & ctx,aco_ptr<Instruction> & instr)579 to_VOP3(opt_ctx& ctx, aco_ptr<Instruction>& instr)
580 {
581    if (instr->isVOP3())
582       return;
583 
584    aco_ptr<Instruction> tmp = std::move(instr);
585    Format format = asVOP3(tmp->format);
586    instr.reset(create_instruction<VOP3_instruction>(tmp->opcode, format, tmp->operands.size(),
587                                                     tmp->definitions.size()));
588    std::copy(tmp->operands.cbegin(), tmp->operands.cend(), instr->operands.begin());
589    for (unsigned i = 0; i < instr->definitions.size(); i++) {
590       instr->definitions[i] = tmp->definitions[i];
591       if (instr->definitions[i].isTemp()) {
592          ssa_info& info = ctx.info[instr->definitions[i].tempId()];
593          if (info.label & instr_usedef_labels && info.instr == tmp.get())
594             info.instr = instr.get();
595       }
596    }
597    /* we don't need to update any instr_mod_labels because they either haven't
598     * been applied yet or this instruction isn't dead and so they've been ignored */
599 }
600 
601 bool
is_operand_vgpr(Operand op)602 is_operand_vgpr(Operand op)
603 {
604    return op.isTemp() && op.getTemp().type() == RegType::vgpr;
605 }
606 
607 void
to_SDWA(opt_ctx & ctx,aco_ptr<Instruction> & instr)608 to_SDWA(opt_ctx& ctx, aco_ptr<Instruction>& instr)
609 {
610    aco_ptr<Instruction> tmp = convert_to_SDWA(ctx.program->chip_class, instr);
611    if (!tmp)
612       return;
613 
614    for (unsigned i = 0; i < instr->definitions.size(); i++) {
615       ssa_info& info = ctx.info[instr->definitions[i].tempId()];
616       if (info.label & instr_labels && info.instr == tmp.get())
617          info.instr = instr.get();
618    }
619 }
620 
621 /* only covers special cases */
622 bool
alu_can_accept_constant(aco_opcode opcode,unsigned operand)623 alu_can_accept_constant(aco_opcode opcode, unsigned operand)
624 {
625    switch (opcode) {
626    case aco_opcode::v_interp_p2_f32:
627    case aco_opcode::v_mac_f32:
628    case aco_opcode::v_writelane_b32:
629    case aco_opcode::v_writelane_b32_e64:
630    case aco_opcode::v_cndmask_b32: return operand != 2;
631    case aco_opcode::s_addk_i32:
632    case aco_opcode::s_mulk_i32:
633    case aco_opcode::p_wqm:
634    case aco_opcode::p_extract_vector:
635    case aco_opcode::p_split_vector:
636    case aco_opcode::v_readlane_b32:
637    case aco_opcode::v_readlane_b32_e64:
638    case aco_opcode::v_readfirstlane_b32:
639    case aco_opcode::p_extract:
640    case aco_opcode::p_insert: return operand != 0;
641    default: return true;
642    }
643 }
644 
645 bool
valu_can_accept_vgpr(aco_ptr<Instruction> & instr,unsigned operand)646 valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
647 {
648    if (instr->opcode == aco_opcode::v_readlane_b32 ||
649        instr->opcode == aco_opcode::v_readlane_b32_e64 ||
650        instr->opcode == aco_opcode::v_writelane_b32 ||
651        instr->opcode == aco_opcode::v_writelane_b32_e64)
652       return operand != 1;
653    if (instr->opcode == aco_opcode::v_permlane16_b32 ||
654        instr->opcode == aco_opcode::v_permlanex16_b32)
655       return operand == 0;
656    return true;
657 }
658 
659 /* check constant bus and literal limitations */
660 bool
check_vop3_operands(opt_ctx & ctx,unsigned num_operands,Operand * operands)661 check_vop3_operands(opt_ctx& ctx, unsigned num_operands, Operand* operands)
662 {
663    int limit = ctx.program->chip_class >= GFX10 ? 2 : 1;
664    Operand literal32(s1);
665    Operand literal64(s2);
666    unsigned num_sgprs = 0;
667    unsigned sgpr[] = {0, 0};
668 
669    for (unsigned i = 0; i < num_operands; i++) {
670       Operand op = operands[i];
671 
672       if (op.hasRegClass() && op.regClass().type() == RegType::sgpr) {
673          /* two reads of the same SGPR count as 1 to the limit */
674          if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
675             if (num_sgprs < 2)
676                sgpr[num_sgprs++] = op.tempId();
677             limit--;
678             if (limit < 0)
679                return false;
680          }
681       } else if (op.isLiteral()) {
682          if (ctx.program->chip_class < GFX10)
683             return false;
684 
685          if (!literal32.isUndefined() && literal32.constantValue() != op.constantValue())
686             return false;
687          if (!literal64.isUndefined() && literal64.constantValue() != op.constantValue())
688             return false;
689 
690          /* Any number of 32-bit literals counts as only 1 to the limit. Same
691           * (but separately) for 64-bit literals. */
692          if (op.size() == 1 && literal32.isUndefined()) {
693             limit--;
694             literal32 = op;
695          } else if (op.size() == 2 && literal64.isUndefined()) {
696             limit--;
697             literal64 = op;
698          }
699 
700          if (limit < 0)
701             return false;
702       }
703    }
704 
705    return true;
706 }
707 
708 bool
parse_base_offset(opt_ctx & ctx,Instruction * instr,unsigned op_index,Temp * base,uint32_t * offset,bool prevent_overflow)709 parse_base_offset(opt_ctx& ctx, Instruction* instr, unsigned op_index, Temp* base, uint32_t* offset,
710                   bool prevent_overflow)
711 {
712    Operand op = instr->operands[op_index];
713 
714    if (!op.isTemp())
715       return false;
716    Temp tmp = op.getTemp();
717    if (!ctx.info[tmp.id()].is_add_sub())
718       return false;
719 
720    Instruction* add_instr = ctx.info[tmp.id()].instr;
721 
722    switch (add_instr->opcode) {
723    case aco_opcode::v_add_u32:
724    case aco_opcode::v_add_co_u32:
725    case aco_opcode::v_add_co_u32_e64:
726    case aco_opcode::s_add_i32:
727    case aco_opcode::s_add_u32: break;
728    default: return false;
729    }
730    if (prevent_overflow && !add_instr->definitions[0].isNUW())
731       return false;
732 
733    if (add_instr->usesModifiers())
734       return false;
735 
736    for (unsigned i = 0; i < 2; i++) {
737       if (add_instr->operands[i].isConstant()) {
738          *offset = add_instr->operands[i].constantValue();
739       } else if (add_instr->operands[i].isTemp() &&
740                  ctx.info[add_instr->operands[i].tempId()].is_constant_or_literal(32)) {
741          *offset = ctx.info[add_instr->operands[i].tempId()].val;
742       } else {
743          continue;
744       }
745       if (!add_instr->operands[!i].isTemp())
746          continue;
747 
748       uint32_t offset2 = 0;
749       if (parse_base_offset(ctx, add_instr, !i, base, &offset2, prevent_overflow)) {
750          *offset += offset2;
751       } else {
752          *base = add_instr->operands[!i].getTemp();
753       }
754       return true;
755    }
756 
757    return false;
758 }
759 
760 unsigned
get_operand_size(aco_ptr<Instruction> & instr,unsigned index)761 get_operand_size(aco_ptr<Instruction>& instr, unsigned index)
762 {
763    if (instr->isPseudo())
764       return instr->operands[index].bytes() * 8u;
765    else if (instr->opcode == aco_opcode::v_mad_u64_u32 ||
766             instr->opcode == aco_opcode::v_mad_i64_i32)
767       return index == 2 ? 64 : 32;
768    else if (instr->isVALU() || instr->isSALU())
769       return instr_info.operand_size[(int)instr->opcode];
770    else
771       return 0;
772 }
773 
774 Operand
get_constant_op(opt_ctx & ctx,ssa_info info,uint32_t bits)775 get_constant_op(opt_ctx& ctx, ssa_info info, uint32_t bits)
776 {
777    if (bits == 64)
778       return Operand::c32_or_c64(info.val, true);
779    return Operand::get_const(ctx.program->chip_class, info.val, bits / 8u);
780 }
781 
782 bool
fixed_to_exec(Operand op)783 fixed_to_exec(Operand op)
784 {
785    return op.isFixed() && op.physReg() == exec;
786 }
787 
788 SubdwordSel
parse_extract(Instruction * instr)789 parse_extract(Instruction* instr)
790 {
791    if (instr->opcode == aco_opcode::p_extract) {
792       unsigned size = instr->operands[2].constantValue() / 8;
793       unsigned offset = instr->operands[1].constantValue() * size;
794       bool sext = instr->operands[3].constantEquals(1);
795       return SubdwordSel(size, offset, sext);
796    } else if (instr->opcode == aco_opcode::p_insert && instr->operands[1].constantEquals(0)) {
797       return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
798    } else {
799       return SubdwordSel();
800    }
801 }
802 
803 SubdwordSel
parse_insert(Instruction * instr)804 parse_insert(Instruction* instr)
805 {
806    if (instr->opcode == aco_opcode::p_extract && instr->operands[3].constantEquals(0) &&
807        instr->operands[1].constantEquals(0)) {
808       return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
809    } else if (instr->opcode == aco_opcode::p_insert) {
810       unsigned size = instr->operands[2].constantValue() / 8;
811       unsigned offset = instr->operands[1].constantValue() * size;
812       return SubdwordSel(size, offset, false);
813    } else {
814       return SubdwordSel();
815    }
816 }
817 
818 bool
can_apply_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr,unsigned idx,ssa_info & info)819 can_apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
820 {
821    if (idx >= 2)
822       return false;
823 
824    Temp tmp = info.instr->operands[0].getTemp();
825    SubdwordSel sel = parse_extract(info.instr);
826 
827    if (!sel) {
828       return false;
829    } else if (sel.size() == 4) {
830       return true;
831    } else if (instr->opcode == aco_opcode::v_cvt_f32_u32 && sel.size() == 1 && !sel.sign_extend()) {
832       return true;
833    } else if (can_use_SDWA(ctx.program->chip_class, instr, true) &&
834               (tmp.type() == RegType::vgpr || ctx.program->chip_class >= GFX9)) {
835       if (instr->isSDWA() && instr->sdwa().sel[idx] != SubdwordSel::dword)
836          return false;
837       return true;
838    } else if (instr->isVOP3() && sel.size() == 2 &&
839               can_use_opsel(ctx.program->chip_class, instr->opcode, idx, sel.offset()) &&
840               !(instr->vop3().opsel & (1 << idx))) {
841       return true;
842    } else {
843       return false;
844    }
845 }
846 
847 /* Combine an p_extract (or p_insert, in some cases) instruction with instr.
848  * instr(p_extract(...)) -> instr()
849  */
850 void
apply_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr,unsigned idx,ssa_info & info)851 apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
852 {
853    Temp tmp = info.instr->operands[0].getTemp();
854    SubdwordSel sel = parse_extract(info.instr);
855    assert(sel);
856 
857    instr->operands[idx].set16bit(false);
858    instr->operands[idx].set24bit(false);
859 
860    ctx.info[tmp.id()].label &= ~label_insert;
861 
862    if (sel.size() == 4) {
863       /* full dword selection */
864    } else if (instr->opcode == aco_opcode::v_cvt_f32_u32 && sel.size() == 1 && !sel.sign_extend()) {
865       switch (sel.offset()) {
866       case 0: instr->opcode = aco_opcode::v_cvt_f32_ubyte0; break;
867       case 1: instr->opcode = aco_opcode::v_cvt_f32_ubyte1; break;
868       case 2: instr->opcode = aco_opcode::v_cvt_f32_ubyte2; break;
869       case 3: instr->opcode = aco_opcode::v_cvt_f32_ubyte3; break;
870       }
871    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
872               sel.offset() == 0 &&
873               ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
874                (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
875       /* The undesireable upper bits are already shifted out. */
876       return;
877    } else if (can_use_SDWA(ctx.program->chip_class, instr, true) &&
878               (tmp.type() == RegType::vgpr || ctx.program->chip_class >= GFX9)) {
879       to_SDWA(ctx, instr);
880       static_cast<SDWA_instruction*>(instr.get())->sel[idx] = sel;
881    } else if (instr->isVOP3()) {
882       if (sel.offset())
883          instr->vop3().opsel |= 1 << idx;
884    }
885 
886    /* label_vopc seems to be the only one worth keeping at the moment */
887    for (Definition& def : instr->definitions)
888       ctx.info[def.tempId()].label &= label_vopc;
889 }
890 
891 void
check_sdwa_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr)892 check_sdwa_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr)
893 {
894    for (unsigned i = 0; i < instr->operands.size(); i++) {
895       Operand op = instr->operands[i];
896       if (!op.isTemp())
897          continue;
898       ssa_info& info = ctx.info[op.tempId()];
899       if (info.is_extract() && (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
900                                 op.getTemp().type() == RegType::sgpr)) {
901          if (!can_apply_extract(ctx, instr, i, info))
902             info.label &= ~label_extract;
903       }
904    }
905 }
906 
907 bool
does_fp_op_flush_denorms(opt_ctx & ctx,aco_opcode op)908 does_fp_op_flush_denorms(opt_ctx& ctx, aco_opcode op)
909 {
910    if (ctx.program->chip_class <= GFX8) {
911       switch (op) {
912       case aco_opcode::v_min_f32:
913       case aco_opcode::v_max_f32:
914       case aco_opcode::v_med3_f32:
915       case aco_opcode::v_min3_f32:
916       case aco_opcode::v_max3_f32:
917       case aco_opcode::v_min_f16:
918       case aco_opcode::v_max_f16: return false;
919       default: break;
920       }
921    }
922    return op != aco_opcode::v_cndmask_b32;
923 }
924 
925 bool
can_eliminate_fcanonicalize(opt_ctx & ctx,aco_ptr<Instruction> & instr,Temp tmp)926 can_eliminate_fcanonicalize(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp tmp)
927 {
928    float_mode* fp = &ctx.fp_mode;
929    if (ctx.info[tmp.id()].is_canonicalized() ||
930        (tmp.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
931       return true;
932 
933    aco_opcode op = instr->opcode;
934    return instr_info.can_use_input_modifiers[(int)op] && does_fp_op_flush_denorms(ctx, op);
935 }
936 
937 bool
is_copy_label(opt_ctx & ctx,aco_ptr<Instruction> & instr,ssa_info & info)938 is_copy_label(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info)
939 {
940    return info.is_temp() ||
941           (info.is_fcanonicalize() && can_eliminate_fcanonicalize(ctx, instr, info.temp));
942 }
943 
944 bool
is_op_canonicalized(opt_ctx & ctx,Operand op)945 is_op_canonicalized(opt_ctx& ctx, Operand op)
946 {
947    float_mode* fp = &ctx.fp_mode;
948    if ((op.isTemp() && ctx.info[op.tempId()].is_canonicalized()) ||
949        (op.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
950       return true;
951 
952    if (op.isConstant() || (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(32))) {
953       uint32_t val = op.isTemp() ? ctx.info[op.tempId()].val : op.constantValue();
954       if (op.bytes() == 2)
955          return (val & 0x7fff) == 0 || (val & 0x7fff) > 0x3ff;
956       else if (op.bytes() == 4)
957          return (val & 0x7fffffff) == 0 || (val & 0x7fffffff) > 0x7fffff;
958    }
959    return false;
960 }
961 
962 void
label_instruction(opt_ctx & ctx,aco_ptr<Instruction> & instr)963 label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
964 {
965    if (instr->isSALU() || instr->isVALU() || instr->isPseudo()) {
966       ASSERTED bool all_const = false;
967       for (Operand& op : instr->operands)
968          all_const =
969             all_const && (!op.isTemp() || ctx.info[op.tempId()].is_constant_or_literal(32));
970       perfwarn(ctx.program, all_const, "All instruction operands are constant", instr.get());
971 
972       ASSERTED bool is_copy = instr->opcode == aco_opcode::s_mov_b32 ||
973                               instr->opcode == aco_opcode::s_mov_b64 ||
974                               instr->opcode == aco_opcode::v_mov_b32;
975       perfwarn(ctx.program, is_copy && !instr->usesModifiers(), "Use p_parallelcopy instead",
976                instr.get());
977    }
978 
979    for (unsigned i = 0; i < instr->operands.size(); i++) {
980       if (!instr->operands[i].isTemp())
981          continue;
982 
983       ssa_info info = ctx.info[instr->operands[i].tempId()];
984       /* propagate undef */
985       if (info.is_undefined() && is_phi(instr))
986          instr->operands[i] = Operand(instr->operands[i].regClass());
987       /* propagate reg->reg of same type */
988       while (info.is_temp() && info.temp.regClass() == instr->operands[i].getTemp().regClass()) {
989          instr->operands[i].setTemp(ctx.info[instr->operands[i].tempId()].temp);
990          info = ctx.info[info.temp.id()];
991       }
992 
993       /* PSEUDO: propagate temporaries */
994       if (instr->isPseudo()) {
995          while (info.is_temp()) {
996             pseudo_propagate_temp(ctx, instr, info.temp, i);
997             info = ctx.info[info.temp.id()];
998          }
999       }
1000 
1001       /* SALU / PSEUDO: propagate inline constants */
1002       if (instr->isSALU() || instr->isPseudo()) {
1003          unsigned bits = get_operand_size(instr, i);
1004          if ((info.is_constant(bits) || (info.is_literal(bits) && instr->isPseudo())) &&
1005              !instr->operands[i].isFixed() && alu_can_accept_constant(instr->opcode, i)) {
1006             instr->operands[i] = get_constant_op(ctx, info, bits);
1007             continue;
1008          }
1009       }
1010 
1011       /* VALU: propagate neg, abs & inline constants */
1012       else if (instr->isVALU()) {
1013          if (is_copy_label(ctx, instr, info) && info.temp.type() == RegType::vgpr &&
1014              valu_can_accept_vgpr(instr, i)) {
1015             instr->operands[i].setTemp(info.temp);
1016             info = ctx.info[info.temp.id()];
1017          }
1018          /* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */
1019          if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(ctx, instr) &&
1020              instr->operands.size() == 1) {
1021             instr->format = withoutDPP(instr->format);
1022             instr->operands[i].setTemp(info.temp);
1023             info = ctx.info[info.temp.id()];
1024          }
1025 
1026          /* for instructions other than v_cndmask_b32, the size of the instruction should match the
1027           * operand size */
1028          unsigned can_use_mod =
1029             instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4;
1030          can_use_mod = can_use_mod && instr_info.can_use_input_modifiers[(int)instr->opcode];
1031 
1032          if (instr->isSDWA())
1033             can_use_mod = can_use_mod && instr->sdwa().sel[i].size() == 4;
1034          else
1035             can_use_mod = can_use_mod && (instr->isDPP() || can_use_VOP3(ctx, instr));
1036 
1037          if (info.is_neg() && instr->opcode == aco_opcode::v_add_f32) {
1038             instr->opcode = i ? aco_opcode::v_sub_f32 : aco_opcode::v_subrev_f32;
1039             instr->operands[i].setTemp(info.temp);
1040          } else if (info.is_neg() && instr->opcode == aco_opcode::v_add_f16) {
1041             instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16;
1042             instr->operands[i].setTemp(info.temp);
1043          } else if (info.is_neg() && can_use_mod &&
1044                     can_eliminate_fcanonicalize(ctx, instr, info.temp)) {
1045             if (!instr->isDPP() && !instr->isSDWA())
1046                to_VOP3(ctx, instr);
1047             instr->operands[i].setTemp(info.temp);
1048             if (instr->isDPP() && !instr->dpp().abs[i])
1049                instr->dpp().neg[i] = true;
1050             else if (instr->isSDWA() && !instr->sdwa().abs[i])
1051                instr->sdwa().neg[i] = true;
1052             else if (instr->isVOP3() && !instr->vop3().abs[i])
1053                instr->vop3().neg[i] = true;
1054          }
1055          if (info.is_abs() && can_use_mod && can_eliminate_fcanonicalize(ctx, instr, info.temp)) {
1056             if (!instr->isDPP() && !instr->isSDWA())
1057                to_VOP3(ctx, instr);
1058             instr->operands[i] = Operand(info.temp);
1059             if (instr->isDPP())
1060                instr->dpp().abs[i] = true;
1061             else if (instr->isSDWA())
1062                instr->sdwa().abs[i] = true;
1063             else
1064                instr->vop3().abs[i] = true;
1065             continue;
1066          }
1067 
1068          unsigned bits = get_operand_size(instr, i);
1069          if (info.is_constant(bits) && alu_can_accept_constant(instr->opcode, i) &&
1070              (!instr->isSDWA() || ctx.program->chip_class >= GFX9)) {
1071             Operand op = get_constant_op(ctx, info, bits);
1072             perfwarn(ctx.program, instr->opcode == aco_opcode::v_cndmask_b32 && i == 2,
1073                      "v_cndmask_b32 with a constant selector", instr.get());
1074             if (i == 0 || instr->isSDWA() || instr->isVOP3P() ||
1075                 instr->opcode == aco_opcode::v_readlane_b32 ||
1076                 instr->opcode == aco_opcode::v_writelane_b32) {
1077                instr->format = withoutDPP(instr->format);
1078                instr->operands[i] = op;
1079                continue;
1080             } else if (!instr->isVOP3() && can_swap_operands(instr, &instr->opcode)) {
1081                instr->operands[i] = instr->operands[0];
1082                instr->operands[0] = op;
1083                continue;
1084             } else if (can_use_VOP3(ctx, instr)) {
1085                to_VOP3(ctx, instr);
1086                instr->operands[i] = op;
1087                continue;
1088             }
1089          }
1090       }
1091 
1092       /* MUBUF: propagate constants and combine additions */
1093       else if (instr->isMUBUF()) {
1094          MUBUF_instruction& mubuf = instr->mubuf();
1095          Temp base;
1096          uint32_t offset;
1097          while (info.is_temp())
1098             info = ctx.info[info.temp.id()];
1099 
1100          /* According to AMDGPUDAGToDAGISel::SelectMUBUFScratchOffen(), vaddr
1101           * overflow for scratch accesses works only on GFX9+ and saddr overflow
1102           * never works. Since swizzling is the only thing that separates
1103           * scratch accesses and other accesses and swizzling changing how
1104           * addressing works significantly, this probably applies to swizzled
1105           * MUBUF accesses. */
1106          bool vaddr_prevent_overflow = mubuf.swizzled && ctx.program->chip_class < GFX9;
1107          bool saddr_prevent_overflow = mubuf.swizzled;
1108 
1109          if (mubuf.offen && i == 1 && info.is_constant_or_literal(32) &&
1110              mubuf.offset + info.val < 4096) {
1111             assert(!mubuf.idxen);
1112             instr->operands[1] = Operand(v1);
1113             mubuf.offset += info.val;
1114             mubuf.offen = false;
1115             continue;
1116          } else if (i == 2 && info.is_constant_or_literal(32) && mubuf.offset + info.val < 4096) {
1117             instr->operands[2] = Operand::c32(0);
1118             mubuf.offset += info.val;
1119             continue;
1120          } else if (mubuf.offen && i == 1 &&
1121                     parse_base_offset(ctx, instr.get(), i, &base, &offset,
1122                                       vaddr_prevent_overflow) &&
1123                     base.regClass() == v1 && mubuf.offset + offset < 4096) {
1124             assert(!mubuf.idxen);
1125             instr->operands[1].setTemp(base);
1126             mubuf.offset += offset;
1127             continue;
1128          } else if (i == 2 &&
1129                     parse_base_offset(ctx, instr.get(), i, &base, &offset,
1130                                       saddr_prevent_overflow) &&
1131                     base.regClass() == s1 && mubuf.offset + offset < 4096) {
1132             instr->operands[i].setTemp(base);
1133             mubuf.offset += offset;
1134             continue;
1135          }
1136       }
1137 
1138       /* DS: combine additions */
1139       else if (instr->isDS()) {
1140 
1141          DS_instruction& ds = instr->ds();
1142          Temp base;
1143          uint32_t offset;
1144          bool has_usable_ds_offset = ctx.program->chip_class >= GFX7;
1145          if (has_usable_ds_offset && i == 0 &&
1146              parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1147              base.regClass() == instr->operands[i].regClass() &&
1148              instr->opcode != aco_opcode::ds_swizzle_b32) {
1149             if (instr->opcode == aco_opcode::ds_write2_b32 ||
1150                 instr->opcode == aco_opcode::ds_read2_b32 ||
1151                 instr->opcode == aco_opcode::ds_write2_b64 ||
1152                 instr->opcode == aco_opcode::ds_read2_b64) {
1153                unsigned mask = (instr->opcode == aco_opcode::ds_write2_b64 ||
1154                                 instr->opcode == aco_opcode::ds_read2_b64)
1155                                   ? 0x7
1156                                   : 0x3;
1157                unsigned shifts = (instr->opcode == aco_opcode::ds_write2_b64 ||
1158                                   instr->opcode == aco_opcode::ds_read2_b64)
1159                                     ? 3
1160                                     : 2;
1161 
1162                if ((offset & mask) == 0 && ds.offset0 + (offset >> shifts) <= 255 &&
1163                    ds.offset1 + (offset >> shifts) <= 255) {
1164                   instr->operands[i].setTemp(base);
1165                   ds.offset0 += offset >> shifts;
1166                   ds.offset1 += offset >> shifts;
1167                }
1168             } else {
1169                if (ds.offset0 + offset <= 65535) {
1170                   instr->operands[i].setTemp(base);
1171                   ds.offset0 += offset;
1172                }
1173             }
1174          }
1175       }
1176 
1177       /* SMEM: propagate constants and combine additions */
1178       else if (instr->isSMEM()) {
1179 
1180          SMEM_instruction& smem = instr->smem();
1181          Temp base;
1182          uint32_t offset;
1183          bool prevent_overflow = smem.operands[0].size() > 2 || smem.prevent_overflow;
1184          if (i == 1 && info.is_constant_or_literal(32) &&
1185              ((ctx.program->chip_class == GFX6 && info.val <= 0x3FF) ||
1186               (ctx.program->chip_class == GFX7 && info.val <= 0xFFFFFFFF) ||
1187               (ctx.program->chip_class >= GFX8 && info.val <= 0xFFFFF))) {
1188             instr->operands[i] = Operand::c32(info.val);
1189             continue;
1190          } else if (i == 1 &&
1191                     parse_base_offset(ctx, instr.get(), i, &base, &offset, prevent_overflow) &&
1192                     base.regClass() == s1 && offset <= 0xFFFFF && ctx.program->chip_class >= GFX9) {
1193             bool soe = smem.operands.size() >= (!smem.definitions.empty() ? 3 : 4);
1194             if (soe && (!ctx.info[smem.operands.back().tempId()].is_constant_or_literal(32) ||
1195                         ctx.info[smem.operands.back().tempId()].val != 0)) {
1196                continue;
1197             }
1198             if (soe) {
1199                smem.operands[1] = Operand::c32(offset);
1200                smem.operands.back() = Operand(base);
1201             } else {
1202                SMEM_instruction* new_instr = create_instruction<SMEM_instruction>(
1203                   smem.opcode, Format::SMEM, smem.operands.size() + 1, smem.definitions.size());
1204                new_instr->operands[0] = smem.operands[0];
1205                new_instr->operands[1] = Operand::c32(offset);
1206                if (smem.definitions.empty())
1207                   new_instr->operands[2] = smem.operands[2];
1208                new_instr->operands.back() = Operand(base);
1209                if (!smem.definitions.empty())
1210                   new_instr->definitions[0] = smem.definitions[0];
1211                new_instr->sync = smem.sync;
1212                new_instr->glc = smem.glc;
1213                new_instr->dlc = smem.dlc;
1214                new_instr->nv = smem.nv;
1215                new_instr->disable_wqm = smem.disable_wqm;
1216                instr.reset(new_instr);
1217             }
1218             continue;
1219          }
1220       }
1221 
1222       else if (instr->isBranch()) {
1223          if (ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
1224             /* Flip the branch instruction to get rid of the scc_invert instruction */
1225             instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz
1226                                                                      : aco_opcode::p_cbranch_z;
1227             instr->operands[0].setTemp(ctx.info[instr->operands[0].tempId()].temp);
1228          }
1229       }
1230    }
1231 
1232    /* if this instruction doesn't define anything, return */
1233    if (instr->definitions.empty()) {
1234       check_sdwa_extract(ctx, instr);
1235       return;
1236    }
1237 
1238    if (instr->isVALU() || instr->isVINTRP()) {
1239       if (instr_info.can_use_output_modifiers[(int)instr->opcode] || instr->isVINTRP() ||
1240           instr->opcode == aco_opcode::v_cndmask_b32) {
1241          bool canonicalized = true;
1242          if (!does_fp_op_flush_denorms(ctx, instr->opcode)) {
1243             unsigned ops = instr->opcode == aco_opcode::v_cndmask_b32 ? 2 : instr->operands.size();
1244             for (unsigned i = 0; canonicalized && (i < ops); i++)
1245                canonicalized = is_op_canonicalized(ctx, instr->operands[i]);
1246          }
1247          if (canonicalized)
1248             ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1249       }
1250 
1251       if (instr->isVOPC()) {
1252          ctx.info[instr->definitions[0].tempId()].set_vopc(instr.get());
1253          check_sdwa_extract(ctx, instr);
1254          return;
1255       }
1256       if (instr->isVOP3P()) {
1257          ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
1258          return;
1259       }
1260    }
1261 
1262    switch (instr->opcode) {
1263    case aco_opcode::p_create_vector: {
1264       bool copy_prop = instr->operands.size() == 1 && instr->operands[0].isTemp() &&
1265                        instr->operands[0].regClass() == instr->definitions[0].regClass();
1266       if (copy_prop) {
1267          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1268          break;
1269       }
1270 
1271       /* expand vector operands */
1272       std::vector<Operand> ops;
1273       unsigned offset = 0;
1274       for (const Operand& op : instr->operands) {
1275          /* ensure that any expanded operands are properly aligned */
1276          bool aligned = offset % 4 == 0 || op.bytes() < 4;
1277          offset += op.bytes();
1278          if (aligned && op.isTemp() && ctx.info[op.tempId()].is_vec()) {
1279             Instruction* vec = ctx.info[op.tempId()].instr;
1280             for (const Operand& vec_op : vec->operands)
1281                ops.emplace_back(vec_op);
1282          } else {
1283             ops.emplace_back(op);
1284          }
1285       }
1286 
1287       /* combine expanded operands to new vector */
1288       if (ops.size() != instr->operands.size()) {
1289          assert(ops.size() > instr->operands.size());
1290          Definition def = instr->definitions[0];
1291          instr.reset(create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector,
1292                                                             Format::PSEUDO, ops.size(), 1));
1293          for (unsigned i = 0; i < ops.size(); i++) {
1294             if (ops[i].isTemp() && ctx.info[ops[i].tempId()].is_temp() &&
1295                 ops[i].regClass() == ctx.info[ops[i].tempId()].temp.regClass())
1296                ops[i].setTemp(ctx.info[ops[i].tempId()].temp);
1297             instr->operands[i] = ops[i];
1298          }
1299          instr->definitions[0] = def;
1300       } else {
1301          for (unsigned i = 0; i < ops.size(); i++) {
1302             assert(instr->operands[i] == ops[i]);
1303          }
1304       }
1305       ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1306       break;
1307    }
1308    case aco_opcode::p_split_vector: {
1309       ssa_info& info = ctx.info[instr->operands[0].tempId()];
1310 
1311       if (info.is_constant_or_literal(32)) {
1312          uint32_t val = info.val;
1313          for (Definition def : instr->definitions) {
1314             uint32_t mask = u_bit_consecutive(0, def.bytes() * 8u);
1315             ctx.info[def.tempId()].set_constant(ctx.program->chip_class, val & mask);
1316             val >>= def.bytes() * 8u;
1317          }
1318          break;
1319       } else if (!info.is_vec()) {
1320          break;
1321       }
1322 
1323       Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1324       unsigned split_offset = 0;
1325       unsigned vec_offset = 0;
1326       unsigned vec_index = 0;
1327       for (unsigned i = 0; i < instr->definitions.size();
1328            split_offset += instr->definitions[i++].bytes()) {
1329          while (vec_offset < split_offset && vec_index < vec->operands.size())
1330             vec_offset += vec->operands[vec_index++].bytes();
1331 
1332          if (vec_offset != split_offset ||
1333              vec->operands[vec_index].bytes() != instr->definitions[i].bytes())
1334             continue;
1335 
1336          Operand vec_op = vec->operands[vec_index];
1337          if (vec_op.isConstant()) {
1338             ctx.info[instr->definitions[i].tempId()].set_constant(ctx.program->chip_class,
1339                                                                   vec_op.constantValue64());
1340          } else if (vec_op.isUndefined()) {
1341             ctx.info[instr->definitions[i].tempId()].set_undefined();
1342          } else {
1343             assert(vec_op.isTemp());
1344             ctx.info[instr->definitions[i].tempId()].set_temp(vec_op.getTemp());
1345          }
1346       }
1347       break;
1348    }
1349    case aco_opcode::p_extract_vector: { /* mov */
1350       ssa_info& info = ctx.info[instr->operands[0].tempId()];
1351       const unsigned index = instr->operands[1].constantValue();
1352       const unsigned dst_offset = index * instr->definitions[0].bytes();
1353 
1354       if (info.is_vec()) {
1355          /* check if we index directly into a vector element */
1356          Instruction* vec = info.instr;
1357          unsigned offset = 0;
1358 
1359          for (const Operand& op : vec->operands) {
1360             if (offset < dst_offset) {
1361                offset += op.bytes();
1362                continue;
1363             } else if (offset != dst_offset || op.bytes() != instr->definitions[0].bytes()) {
1364                break;
1365             }
1366             instr->operands[0] = op;
1367             break;
1368          }
1369       } else if (info.is_constant_or_literal(32)) {
1370          /* propagate constants */
1371          uint32_t mask = u_bit_consecutive(0, instr->definitions[0].bytes() * 8u);
1372          uint32_t val = (info.val >> (dst_offset * 8u)) & mask;
1373          instr->operands[0] =
1374             Operand::get_const(ctx.program->chip_class, val, instr->definitions[0].bytes());
1375          ;
1376       } else if (index == 0 && instr->operands[0].size() == instr->definitions[0].size()) {
1377          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1378       }
1379 
1380       if (instr->operands[0].bytes() != instr->definitions[0].bytes())
1381          break;
1382 
1383       /* convert this extract into a copy instruction */
1384       instr->opcode = aco_opcode::p_parallelcopy;
1385       instr->operands.pop_back();
1386       FALLTHROUGH;
1387    }
1388    case aco_opcode::p_parallelcopy: /* propagate */
1389       if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_vec() &&
1390           instr->operands[0].regClass() != instr->definitions[0].regClass()) {
1391          /* We might not be able to copy-propagate if it's a SGPR->VGPR copy, so
1392           * duplicate the vector instead.
1393           */
1394          Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1395          aco_ptr<Instruction> old_copy = std::move(instr);
1396 
1397          instr.reset(create_instruction<Pseudo_instruction>(
1398             aco_opcode::p_create_vector, Format::PSEUDO, vec->operands.size(), 1));
1399          instr->definitions[0] = old_copy->definitions[0];
1400          std::copy(vec->operands.begin(), vec->operands.end(), instr->operands.begin());
1401          for (unsigned i = 0; i < vec->operands.size(); i++) {
1402             Operand& op = instr->operands[i];
1403             if (op.isTemp() && ctx.info[op.tempId()].is_temp() &&
1404                 ctx.info[op.tempId()].temp.type() == instr->definitions[0].regClass().type())
1405                op.setTemp(ctx.info[op.tempId()].temp);
1406          }
1407          ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1408          break;
1409       }
1410       FALLTHROUGH;
1411    case aco_opcode::p_as_uniform:
1412       if (instr->definitions[0].isFixed()) {
1413          /* don't copy-propagate copies into fixed registers */
1414       } else if (instr->usesModifiers()) {
1415          // TODO
1416       } else if (instr->operands[0].isConstant()) {
1417          ctx.info[instr->definitions[0].tempId()].set_constant(
1418             ctx.program->chip_class, instr->operands[0].constantValue64());
1419       } else if (instr->operands[0].isTemp()) {
1420          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1421          if (ctx.info[instr->operands[0].tempId()].is_canonicalized())
1422             ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1423       } else {
1424          assert(instr->operands[0].isFixed());
1425       }
1426       break;
1427    case aco_opcode::v_mov_b32:
1428       if (instr->isDPP()) {
1429          /* anything else doesn't make sense in SSA */
1430          assert(instr->dpp().row_mask == 0xf && instr->dpp().bank_mask == 0xf);
1431          ctx.info[instr->definitions[0].tempId()].set_dpp(instr.get());
1432       }
1433       break;
1434    case aco_opcode::p_is_helper:
1435       if (!ctx.program->needs_wqm)
1436          ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->chip_class, 0u);
1437       break;
1438    case aco_opcode::v_mul_f64: ctx.info[instr->definitions[0].tempId()].set_mul(instr.get()); break;
1439    case aco_opcode::v_mul_f16:
1440    case aco_opcode::v_mul_f32: { /* omod */
1441       ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
1442 
1443       /* TODO: try to move the negate/abs modifier to the consumer instead */
1444       bool uses_mods = instr->usesModifiers();
1445       bool fp16 = instr->opcode == aco_opcode::v_mul_f16;
1446 
1447       for (unsigned i = 0; i < 2; i++) {
1448          if (instr->operands[!i].isConstant() && instr->operands[i].isTemp()) {
1449             if (!instr->isDPP() && !instr->isSDWA() &&
1450                 (instr->operands[!i].constantEquals(fp16 ? 0x3c00 : 0x3f800000) ||   /* 1.0 */
1451                  instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u))) { /* -1.0 */
1452                bool neg1 = instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u);
1453 
1454                VOP3_instruction* vop3 = instr->isVOP3() ? &instr->vop3() : NULL;
1455                if (vop3 && (vop3->abs[!i] || vop3->neg[!i] || vop3->clamp || vop3->omod))
1456                   continue;
1457 
1458                bool abs = vop3 && vop3->abs[i];
1459                bool neg = neg1 ^ (vop3 && vop3->neg[i]);
1460 
1461                Temp other = instr->operands[i].getTemp();
1462                if (abs && neg && other.type() == RegType::vgpr)
1463                   ctx.info[instr->definitions[0].tempId()].set_neg_abs(other);
1464                else if (abs && !neg && other.type() == RegType::vgpr)
1465                   ctx.info[instr->definitions[0].tempId()].set_abs(other);
1466                else if (!abs && neg && other.type() == RegType::vgpr)
1467                   ctx.info[instr->definitions[0].tempId()].set_neg(other);
1468                else if (!abs && !neg)
1469                   ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other);
1470             } else if (uses_mods) {
1471                continue;
1472             } else if (instr->operands[!i].constantValue() ==
1473                        (fp16 ? 0x4000 : 0x40000000)) { /* 2.0 */
1474                ctx.info[instr->operands[i].tempId()].set_omod2(instr.get());
1475             } else if (instr->operands[!i].constantValue() ==
1476                        (fp16 ? 0x4400 : 0x40800000)) { /* 4.0 */
1477                ctx.info[instr->operands[i].tempId()].set_omod4(instr.get());
1478             } else if (instr->operands[!i].constantValue() ==
1479                        (fp16 ? 0x3800 : 0x3f000000)) { /* 0.5 */
1480                ctx.info[instr->operands[i].tempId()].set_omod5(instr.get());
1481             } else if (instr->operands[!i].constantValue() == 0u &&
1482                        !(fp16 ? ctx.fp_mode.preserve_signed_zero_inf_nan16_64
1483                               : ctx.fp_mode.preserve_signed_zero_inf_nan32)) { /* 0.0 */
1484                ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->chip_class, 0u);
1485             } else {
1486                continue;
1487             }
1488             break;
1489          }
1490       }
1491       break;
1492    }
1493    case aco_opcode::v_mul_lo_u16:
1494    case aco_opcode::v_mul_lo_u16_e64:
1495    case aco_opcode::v_mul_u32_u24:
1496       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1497       break;
1498    case aco_opcode::v_med3_f16:
1499    case aco_opcode::v_med3_f32: { /* clamp */
1500       VOP3_instruction& vop3 = instr->vop3();
1501       if (vop3.abs[0] || vop3.abs[1] || vop3.abs[2] || vop3.neg[0] || vop3.neg[1] || vop3.neg[2] ||
1502           vop3.omod != 0 || vop3.opsel != 0)
1503          break;
1504 
1505       unsigned idx = 0;
1506       bool found_zero = false, found_one = false;
1507       bool is_fp16 = instr->opcode == aco_opcode::v_med3_f16;
1508       for (unsigned i = 0; i < 3; i++) {
1509          if (instr->operands[i].constantEquals(0))
1510             found_zero = true;
1511          else if (instr->operands[i].constantEquals(is_fp16 ? 0x3c00 : 0x3f800000)) /* 1.0 */
1512             found_one = true;
1513          else
1514             idx = i;
1515       }
1516       if (found_zero && found_one && instr->operands[idx].isTemp())
1517          ctx.info[instr->operands[idx].tempId()].set_clamp(instr.get());
1518       break;
1519    }
1520    case aco_opcode::v_cndmask_b32:
1521       if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(0xFFFFFFFF))
1522          ctx.info[instr->definitions[0].tempId()].set_vcc(instr->operands[2].getTemp());
1523       else if (instr->operands[0].constantEquals(0) &&
1524                instr->operands[1].constantEquals(0x3f800000u))
1525          ctx.info[instr->definitions[0].tempId()].set_b2f(instr->operands[2].getTemp());
1526       else if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(1))
1527          ctx.info[instr->definitions[0].tempId()].set_b2i(instr->operands[2].getTemp());
1528 
1529       ctx.info[instr->operands[2].tempId()].set_vcc_hint();
1530       break;
1531    case aco_opcode::v_cmp_lg_u32:
1532       if (instr->format == Format::VOPC && /* don't optimize VOP3 / SDWA / DPP */
1533           instr->operands[0].constantEquals(0) && instr->operands[1].isTemp() &&
1534           ctx.info[instr->operands[1].tempId()].is_vcc())
1535          ctx.info[instr->definitions[0].tempId()].set_temp(
1536             ctx.info[instr->operands[1].tempId()].temp);
1537       break;
1538    case aco_opcode::p_linear_phi: {
1539       /* lower_bool_phis() can create phis like this */
1540       bool all_same_temp = instr->operands[0].isTemp();
1541       /* this check is needed when moving uniform loop counters out of a divergent loop */
1542       if (all_same_temp)
1543          all_same_temp = instr->definitions[0].regClass() == instr->operands[0].regClass();
1544       for (unsigned i = 1; all_same_temp && (i < instr->operands.size()); i++) {
1545          if (!instr->operands[i].isTemp() ||
1546              instr->operands[i].tempId() != instr->operands[0].tempId())
1547             all_same_temp = false;
1548       }
1549       if (all_same_temp) {
1550          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1551       } else {
1552          bool all_undef = instr->operands[0].isUndefined();
1553          for (unsigned i = 1; all_undef && (i < instr->operands.size()); i++) {
1554             if (!instr->operands[i].isUndefined())
1555                all_undef = false;
1556          }
1557          if (all_undef)
1558             ctx.info[instr->definitions[0].tempId()].set_undefined();
1559       }
1560       break;
1561    }
1562    case aco_opcode::v_add_u32:
1563    case aco_opcode::v_add_co_u32:
1564    case aco_opcode::v_add_co_u32_e64:
1565    case aco_opcode::s_add_i32:
1566    case aco_opcode::s_add_u32:
1567    case aco_opcode::v_subbrev_co_u32:
1568       ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
1569       break;
1570    case aco_opcode::s_not_b32:
1571    case aco_opcode::s_not_b64:
1572       if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
1573          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1574          ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1575             ctx.info[instr->operands[0].tempId()].temp);
1576       } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
1577          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1578          ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1579             ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1580       }
1581       ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1582       break;
1583    case aco_opcode::s_and_b32:
1584    case aco_opcode::s_and_b64:
1585       if (fixed_to_exec(instr->operands[1]) && instr->operands[0].isTemp()) {
1586          if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
1587             /* Try to get rid of the superfluous s_cselect + s_and_b64 that comes from turning a
1588              * uniform bool into divergent */
1589             ctx.info[instr->definitions[1].tempId()].set_temp(
1590                ctx.info[instr->operands[0].tempId()].temp);
1591             ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
1592                ctx.info[instr->operands[0].tempId()].temp);
1593             break;
1594          } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
1595             /* Try to get rid of the superfluous s_and_b64, since the uniform bitwise instruction
1596              * already produces the same SCC */
1597             ctx.info[instr->definitions[1].tempId()].set_temp(
1598                ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1599             ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
1600                ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1601             break;
1602          } else if ((ctx.program->stage.num_sw_stages() > 1 ||
1603                      ctx.program->stage.hw == HWStage::NGG) &&
1604                     instr->pass_flags == 1) {
1605             /* In case of merged shaders, pass_flags=1 means that all lanes are active (exec=-1), so
1606              * s_and is unnecessary. */
1607             ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1608             break;
1609          } else if (ctx.info[instr->operands[0].tempId()].is_vopc()) {
1610             Instruction* vopc_instr = ctx.info[instr->operands[0].tempId()].instr;
1611             /* Remove superfluous s_and when the VOPC instruction uses the same exec and thus
1612              * already produces the same result */
1613             if (vopc_instr->pass_flags == instr->pass_flags) {
1614                assert(instr->pass_flags > 0);
1615                ctx.info[instr->definitions[0].tempId()].set_temp(
1616                   vopc_instr->definitions[0].getTemp());
1617                break;
1618             }
1619          }
1620       }
1621       FALLTHROUGH;
1622    case aco_opcode::s_or_b32:
1623    case aco_opcode::s_or_b64:
1624    case aco_opcode::s_xor_b32:
1625    case aco_opcode::s_xor_b64:
1626       if (std::all_of(instr->operands.begin(), instr->operands.end(),
1627                       [&ctx](const Operand& op)
1628                       {
1629                          return op.isTemp() && (ctx.info[op.tempId()].is_uniform_bool() ||
1630                                                 ctx.info[op.tempId()].is_uniform_bitwise());
1631                       })) {
1632          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1633       }
1634       FALLTHROUGH;
1635    case aco_opcode::s_lshl_b32:
1636    case aco_opcode::v_or_b32:
1637    case aco_opcode::v_lshlrev_b32:
1638    case aco_opcode::v_bcnt_u32_b32:
1639    case aco_opcode::v_and_b32:
1640    case aco_opcode::v_xor_b32:
1641       ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1642       break;
1643    case aco_opcode::v_min_f32:
1644    case aco_opcode::v_min_f16:
1645    case aco_opcode::v_min_u32:
1646    case aco_opcode::v_min_i32:
1647    case aco_opcode::v_min_u16:
1648    case aco_opcode::v_min_i16:
1649    case aco_opcode::v_max_f32:
1650    case aco_opcode::v_max_f16:
1651    case aco_opcode::v_max_u32:
1652    case aco_opcode::v_max_i32:
1653    case aco_opcode::v_max_u16:
1654    case aco_opcode::v_max_i16:
1655       ctx.info[instr->definitions[0].tempId()].set_minmax(instr.get());
1656       break;
1657    case aco_opcode::s_cselect_b64:
1658    case aco_opcode::s_cselect_b32:
1659       if (instr->operands[0].constantEquals((unsigned)-1) && instr->operands[1].constantEquals(0)) {
1660          /* Found a cselect that operates on a uniform bool that comes from eg. s_cmp */
1661          ctx.info[instr->definitions[0].tempId()].set_uniform_bool(instr->operands[2].getTemp());
1662       }
1663       if (instr->operands[2].isTemp() && ctx.info[instr->operands[2].tempId()].is_scc_invert()) {
1664          /* Flip the operands to get rid of the scc_invert instruction */
1665          std::swap(instr->operands[0], instr->operands[1]);
1666          instr->operands[2].setTemp(ctx.info[instr->operands[2].tempId()].temp);
1667       }
1668       break;
1669    case aco_opcode::p_wqm:
1670       if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
1671          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1672       }
1673       break;
1674    case aco_opcode::s_mul_i32:
1675       /* Testing every uint32_t shows that 0x3f800000*n is never a denormal.
1676        * This pattern is created from a uniform nir_op_b2f. */
1677       if (instr->operands[0].constantEquals(0x3f800000u))
1678          ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1679       break;
1680    case aco_opcode::p_extract: {
1681       if (instr->definitions[0].bytes() == 4) {
1682          ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1683          if (instr->operands[0].regClass() == v1 && parse_insert(instr.get()))
1684             ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
1685       }
1686       break;
1687    }
1688    case aco_opcode::p_insert: {
1689       if (instr->operands[0].bytes() == 4) {
1690          if (instr->operands[0].regClass() == v1)
1691             ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
1692          if (parse_extract(instr.get()))
1693             ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1694          ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1695       }
1696       break;
1697    }
1698    case aco_opcode::ds_read_u8:
1699    case aco_opcode::ds_read_u8_d16:
1700    case aco_opcode::ds_read_u16:
1701    case aco_opcode::ds_read_u16_d16: {
1702       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1703       break;
1704    }
1705    default: break;
1706    }
1707 
1708    /* Don't remove label_extract if we can't apply the extract to
1709     * neg/abs instructions because we'll likely combine it into another valu. */
1710    if (!(ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)))
1711       check_sdwa_extract(ctx, instr);
1712 }
1713 
1714 unsigned
original_temp_id(opt_ctx & ctx,Temp tmp)1715 original_temp_id(opt_ctx& ctx, Temp tmp)
1716 {
1717    if (ctx.info[tmp.id()].is_temp())
1718       return ctx.info[tmp.id()].temp.id();
1719    else
1720       return tmp.id();
1721 }
1722 
1723 void
decrease_uses(opt_ctx & ctx,Instruction * instr)1724 decrease_uses(opt_ctx& ctx, Instruction* instr)
1725 {
1726    if (!--ctx.uses[instr->definitions[0].tempId()]) {
1727       for (const Operand& op : instr->operands) {
1728          if (op.isTemp())
1729             ctx.uses[op.tempId()]--;
1730       }
1731    }
1732 }
1733 
1734 Instruction*
follow_operand(opt_ctx & ctx,Operand op,bool ignore_uses=false)1735 follow_operand(opt_ctx& ctx, Operand op, bool ignore_uses = false)
1736 {
1737    if (!op.isTemp() || !(ctx.info[op.tempId()].label & instr_usedef_labels))
1738       return nullptr;
1739    if (!ignore_uses && ctx.uses[op.tempId()] > 1)
1740       return nullptr;
1741 
1742    Instruction* instr = ctx.info[op.tempId()].instr;
1743 
1744    if (instr->definitions.size() == 2) {
1745       assert(instr->definitions[0].isTemp() && instr->definitions[0].tempId() == op.tempId());
1746       if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
1747          return nullptr;
1748    }
1749 
1750    return instr;
1751 }
1752 
1753 /* s_or_b64(neq(a, a), neq(b, b)) -> v_cmp_u_f32(a, b)
1754  * s_and_b64(eq(a, a), eq(b, b)) -> v_cmp_o_f32(a, b) */
1755 bool
combine_ordering_test(opt_ctx & ctx,aco_ptr<Instruction> & instr)1756 combine_ordering_test(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1757 {
1758    if (instr->definitions[0].regClass() != ctx.program->lane_mask)
1759       return false;
1760    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
1761       return false;
1762 
1763    bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
1764 
1765    bool neg[2] = {false, false};
1766    bool abs[2] = {false, false};
1767    uint8_t opsel = 0;
1768    Instruction* op_instr[2];
1769    Temp op[2];
1770 
1771    unsigned bitsize = 0;
1772    for (unsigned i = 0; i < 2; i++) {
1773       op_instr[i] = follow_operand(ctx, instr->operands[i], true);
1774       if (!op_instr[i])
1775          return false;
1776 
1777       aco_opcode expected_cmp = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
1778       unsigned op_bitsize = get_cmp_bitsize(op_instr[i]->opcode);
1779 
1780       if (get_f32_cmp(op_instr[i]->opcode) != expected_cmp)
1781          return false;
1782       if (bitsize && op_bitsize != bitsize)
1783          return false;
1784       if (!op_instr[i]->operands[0].isTemp() || !op_instr[i]->operands[1].isTemp())
1785          return false;
1786 
1787       if (op_instr[i]->isVOP3()) {
1788          VOP3_instruction& vop3 = op_instr[i]->vop3();
1789          if (vop3.neg[0] != vop3.neg[1] || vop3.abs[0] != vop3.abs[1] || vop3.opsel == 1 ||
1790              vop3.opsel == 2)
1791             return false;
1792          neg[i] = vop3.neg[0];
1793          abs[i] = vop3.abs[0];
1794          opsel |= (vop3.opsel & 1) << i;
1795       } else if (op_instr[i]->isSDWA()) {
1796          return false;
1797       }
1798 
1799       Temp op0 = op_instr[i]->operands[0].getTemp();
1800       Temp op1 = op_instr[i]->operands[1].getTemp();
1801       if (original_temp_id(ctx, op0) != original_temp_id(ctx, op1))
1802          return false;
1803 
1804       op[i] = op1;
1805       bitsize = op_bitsize;
1806    }
1807 
1808    if (op[1].type() == RegType::sgpr)
1809       std::swap(op[0], op[1]);
1810    unsigned num_sgprs = (op[0].type() == RegType::sgpr) + (op[1].type() == RegType::sgpr);
1811    if (num_sgprs > (ctx.program->chip_class >= GFX10 ? 2 : 1))
1812       return false;
1813 
1814    ctx.uses[op[0].id()]++;
1815    ctx.uses[op[1].id()]++;
1816    decrease_uses(ctx, op_instr[0]);
1817    decrease_uses(ctx, op_instr[1]);
1818 
1819    aco_opcode new_op = aco_opcode::num_opcodes;
1820    switch (bitsize) {
1821    case 16: new_op = is_or ? aco_opcode::v_cmp_u_f16 : aco_opcode::v_cmp_o_f16; break;
1822    case 32: new_op = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32; break;
1823    case 64: new_op = is_or ? aco_opcode::v_cmp_u_f64 : aco_opcode::v_cmp_o_f64; break;
1824    }
1825    Instruction* new_instr;
1826    if (neg[0] || neg[1] || abs[0] || abs[1] || opsel || num_sgprs > 1) {
1827       VOP3_instruction* vop3 =
1828          create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
1829       for (unsigned i = 0; i < 2; i++) {
1830          vop3->neg[i] = neg[i];
1831          vop3->abs[i] = abs[i];
1832       }
1833       vop3->opsel = opsel;
1834       new_instr = static_cast<Instruction*>(vop3);
1835    } else {
1836       new_instr = create_instruction<VOPC_instruction>(new_op, Format::VOPC, 2, 1);
1837       instr->definitions[0].setHint(vcc);
1838    }
1839    new_instr->operands[0] = Operand(op[0]);
1840    new_instr->operands[1] = Operand(op[1]);
1841    new_instr->definitions[0] = instr->definitions[0];
1842 
1843    ctx.info[instr->definitions[0].tempId()].label = 0;
1844    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
1845 
1846    instr.reset(new_instr);
1847 
1848    return true;
1849 }
1850 
1851 /* s_or_b64(v_cmp_u_f32(a, b), cmp(a, b)) -> get_unordered(cmp)(a, b)
1852  * s_and_b64(v_cmp_o_f32(a, b), cmp(a, b)) -> get_ordered(cmp)(a, b) */
1853 bool
combine_comparison_ordering(opt_ctx & ctx,aco_ptr<Instruction> & instr)1854 combine_comparison_ordering(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1855 {
1856    if (instr->definitions[0].regClass() != ctx.program->lane_mask)
1857       return false;
1858    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
1859       return false;
1860 
1861    bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
1862    aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32;
1863 
1864    Instruction* nan_test = follow_operand(ctx, instr->operands[0], true);
1865    Instruction* cmp = follow_operand(ctx, instr->operands[1], true);
1866    if (!nan_test || !cmp)
1867       return false;
1868    if (nan_test->isSDWA() || cmp->isSDWA())
1869       return false;
1870 
1871    if (get_f32_cmp(cmp->opcode) == expected_nan_test)
1872       std::swap(nan_test, cmp);
1873    else if (get_f32_cmp(nan_test->opcode) != expected_nan_test)
1874       return false;
1875 
1876    if (!is_cmp(cmp->opcode) || get_cmp_bitsize(cmp->opcode) != get_cmp_bitsize(nan_test->opcode))
1877       return false;
1878 
1879    if (!nan_test->operands[0].isTemp() || !nan_test->operands[1].isTemp())
1880       return false;
1881    if (!cmp->operands[0].isTemp() || !cmp->operands[1].isTemp())
1882       return false;
1883 
1884    unsigned prop_cmp0 = original_temp_id(ctx, cmp->operands[0].getTemp());
1885    unsigned prop_cmp1 = original_temp_id(ctx, cmp->operands[1].getTemp());
1886    unsigned prop_nan0 = original_temp_id(ctx, nan_test->operands[0].getTemp());
1887    unsigned prop_nan1 = original_temp_id(ctx, nan_test->operands[1].getTemp());
1888    if (prop_cmp0 != prop_nan0 && prop_cmp0 != prop_nan1)
1889       return false;
1890    if (prop_cmp1 != prop_nan0 && prop_cmp1 != prop_nan1)
1891       return false;
1892 
1893    ctx.uses[cmp->operands[0].tempId()]++;
1894    ctx.uses[cmp->operands[1].tempId()]++;
1895    decrease_uses(ctx, nan_test);
1896    decrease_uses(ctx, cmp);
1897 
1898    aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
1899    Instruction* new_instr;
1900    if (cmp->isVOP3()) {
1901       VOP3_instruction* new_vop3 =
1902          create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
1903       VOP3_instruction& cmp_vop3 = cmp->vop3();
1904       memcpy(new_vop3->abs, cmp_vop3.abs, sizeof(new_vop3->abs));
1905       memcpy(new_vop3->neg, cmp_vop3.neg, sizeof(new_vop3->neg));
1906       new_vop3->clamp = cmp_vop3.clamp;
1907       new_vop3->omod = cmp_vop3.omod;
1908       new_vop3->opsel = cmp_vop3.opsel;
1909       new_instr = new_vop3;
1910    } else {
1911       new_instr = create_instruction<VOPC_instruction>(new_op, Format::VOPC, 2, 1);
1912       instr->definitions[0].setHint(vcc);
1913    }
1914    new_instr->operands[0] = cmp->operands[0];
1915    new_instr->operands[1] = cmp->operands[1];
1916    new_instr->definitions[0] = instr->definitions[0];
1917 
1918    ctx.info[instr->definitions[0].tempId()].label = 0;
1919    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
1920 
1921    instr.reset(new_instr);
1922 
1923    return true;
1924 }
1925 
1926 bool
is_operand_constant(opt_ctx & ctx,Operand op,unsigned bit_size,uint64_t * value)1927 is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value)
1928 {
1929    if (op.isConstant()) {
1930       *value = op.constantValue64();
1931       return true;
1932    } else if (op.isTemp()) {
1933       unsigned id = original_temp_id(ctx, op.getTemp());
1934       if (!ctx.info[id].is_constant_or_literal(bit_size))
1935          return false;
1936       *value = get_constant_op(ctx, ctx.info[id], bit_size).constantValue64();
1937       return true;
1938    }
1939    return false;
1940 }
1941 
/* Returns whether value, read as an IEEE-754 binary float of bit_size bits,
 * encodes a NaN: exponent field all ones and a non-zero mantissa. 16 and 32
 * select half/single precision; any other size is treated as 64-bit. Bits
 * above the selected width are ignored. */
bool
is_constant_nan(uint64_t value, unsigned bit_size)
{
   unsigned mantissa_bits;
   uint64_t exponent_ones;
   switch (bit_size) {
   case 16:
      mantissa_bits = 10;
      exponent_ones = 0x1f;
      break;
   case 32:
      mantissa_bits = 23;
      exponent_ones = 0xff;
      break;
   default:
      mantissa_bits = 52;
      exponent_ones = 0x7ff;
      break;
   }

   const uint64_t mantissa_mask = (uint64_t(1) << mantissa_bits) - 1;
   const bool exponent_saturated = ((value >> mantissa_bits) & exponent_ones) == exponent_ones;
   return exponent_saturated && (value & mantissa_mask) != 0;
}
1952 
1953 /* s_or_b64(v_cmp_neq_f32(a, a), cmp(a, #b)) and b is not NaN -> get_unordered(cmp)(a, b)
1954  * s_and_b64(v_cmp_eq_f32(a, a), cmp(a, #b)) and b is not NaN -> get_ordered(cmp)(a, b) */
1955 bool
combine_constant_comparison_ordering(opt_ctx & ctx,aco_ptr<Instruction> & instr)1956 combine_constant_comparison_ordering(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1957 {
1958    if (instr->definitions[0].regClass() != ctx.program->lane_mask)
1959       return false;
1960    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
1961       return false;
1962 
1963    bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
1964 
1965    Instruction* nan_test = follow_operand(ctx, instr->operands[0], true);
1966    Instruction* cmp = follow_operand(ctx, instr->operands[1], true);
1967 
1968    if (!nan_test || !cmp || nan_test->isSDWA() || cmp->isSDWA())
1969       return false;
1970    if (nan_test->isSDWA() || cmp->isSDWA())
1971       return false;
1972 
1973    aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
1974    if (get_f32_cmp(cmp->opcode) == expected_nan_test)
1975       std::swap(nan_test, cmp);
1976    else if (get_f32_cmp(nan_test->opcode) != expected_nan_test)
1977       return false;
1978 
1979    unsigned bit_size = get_cmp_bitsize(cmp->opcode);
1980    if (!is_cmp(cmp->opcode) || get_cmp_bitsize(nan_test->opcode) != bit_size)
1981       return false;
1982 
1983    if (!nan_test->operands[0].isTemp() || !nan_test->operands[1].isTemp())
1984       return false;
1985    if (!cmp->operands[0].isTemp() && !cmp->operands[1].isTemp())
1986       return false;
1987 
1988    unsigned prop_nan0 = original_temp_id(ctx, nan_test->operands[0].getTemp());
1989    unsigned prop_nan1 = original_temp_id(ctx, nan_test->operands[1].getTemp());
1990    if (prop_nan0 != prop_nan1)
1991       return false;
1992 
1993    if (nan_test->isVOP3()) {
1994       VOP3_instruction& vop3 = nan_test->vop3();
1995       if (vop3.neg[0] != vop3.neg[1] || vop3.abs[0] != vop3.abs[1] || vop3.opsel == 1 ||
1996           vop3.opsel == 2)
1997          return false;
1998    }
1999 
2000    int constant_operand = -1;
2001    for (unsigned i = 0; i < 2; i++) {
2002       if (cmp->operands[i].isTemp() &&
2003           original_temp_id(ctx, cmp->operands[i].getTemp()) == prop_nan0) {
2004          constant_operand = !i;
2005          break;
2006       }
2007    }
2008    if (constant_operand == -1)
2009       return false;
2010 
2011    uint64_t constant_value;
2012    if (!is_operand_constant(ctx, cmp->operands[constant_operand], bit_size, &constant_value))
2013       return false;
2014    if (is_constant_nan(constant_value, bit_size))
2015       return false;
2016 
2017    if (cmp->operands[0].isTemp())
2018       ctx.uses[cmp->operands[0].tempId()]++;
2019    if (cmp->operands[1].isTemp())
2020       ctx.uses[cmp->operands[1].tempId()]++;
2021    decrease_uses(ctx, nan_test);
2022    decrease_uses(ctx, cmp);
2023 
2024    aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
2025    Instruction* new_instr;
2026    if (cmp->isVOP3()) {
2027       VOP3_instruction* new_vop3 =
2028          create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
2029       VOP3_instruction& cmp_vop3 = cmp->vop3();
2030       memcpy(new_vop3->abs, cmp_vop3.abs, sizeof(new_vop3->abs));
2031       memcpy(new_vop3->neg, cmp_vop3.neg, sizeof(new_vop3->neg));
2032       new_vop3->clamp = cmp_vop3.clamp;
2033       new_vop3->omod = cmp_vop3.omod;
2034       new_vop3->opsel = cmp_vop3.opsel;
2035       new_instr = new_vop3;
2036    } else {
2037       new_instr = create_instruction<VOPC_instruction>(new_op, Format::VOPC, 2, 1);
2038       instr->definitions[0].setHint(vcc);
2039    }
2040    new_instr->operands[0] = cmp->operands[0];
2041    new_instr->operands[1] = cmp->operands[1];
2042    new_instr->definitions[0] = instr->definitions[0];
2043 
2044    ctx.info[instr->definitions[0].tempId()].label = 0;
2045    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2046 
2047    instr.reset(new_instr);
2048 
2049    return true;
2050 }
2051 
2052 /* s_andn2(exec, cmp(a, b)) -> get_inverse(cmp)(a, b) */
2053 bool
combine_inverse_comparison(opt_ctx & ctx,aco_ptr<Instruction> & instr)2054 combine_inverse_comparison(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2055 {
2056    if (!instr->operands[0].isFixed() || instr->operands[0].physReg() != exec)
2057       return false;
2058    if (ctx.uses[instr->definitions[1].tempId()])
2059       return false;
2060 
2061    Instruction* cmp = follow_operand(ctx, instr->operands[1]);
2062    if (!cmp)
2063       return false;
2064 
2065    aco_opcode new_opcode = get_inverse(cmp->opcode);
2066    if (new_opcode == aco_opcode::num_opcodes)
2067       return false;
2068 
2069    if (cmp->operands[0].isTemp())
2070       ctx.uses[cmp->operands[0].tempId()]++;
2071    if (cmp->operands[1].isTemp())
2072       ctx.uses[cmp->operands[1].tempId()]++;
2073    decrease_uses(ctx, cmp);
2074 
2075    /* This creates a new instruction instead of modifying the existing
2076     * comparison so that the comparison is done with the correct exec mask. */
2077    Instruction* new_instr;
2078    if (cmp->isVOP3()) {
2079       VOP3_instruction* new_vop3 =
2080          create_instruction<VOP3_instruction>(new_opcode, asVOP3(Format::VOPC), 2, 1);
2081       VOP3_instruction& cmp_vop3 = cmp->vop3();
2082       memcpy(new_vop3->abs, cmp_vop3.abs, sizeof(new_vop3->abs));
2083       memcpy(new_vop3->neg, cmp_vop3.neg, sizeof(new_vop3->neg));
2084       new_vop3->clamp = cmp_vop3.clamp;
2085       new_vop3->omod = cmp_vop3.omod;
2086       new_vop3->opsel = cmp_vop3.opsel;
2087       new_instr = new_vop3;
2088    } else if (cmp->isSDWA()) {
2089       SDWA_instruction* new_sdwa = create_instruction<SDWA_instruction>(
2090          new_opcode, (Format)((uint16_t)Format::SDWA | (uint16_t)Format::VOPC), 2, 1);
2091       SDWA_instruction& cmp_sdwa = cmp->sdwa();
2092       memcpy(new_sdwa->abs, cmp_sdwa.abs, sizeof(new_sdwa->abs));
2093       memcpy(new_sdwa->sel, cmp_sdwa.sel, sizeof(new_sdwa->sel));
2094       memcpy(new_sdwa->neg, cmp_sdwa.neg, sizeof(new_sdwa->neg));
2095       new_sdwa->dst_sel = cmp_sdwa.dst_sel;
2096       new_sdwa->clamp = cmp_sdwa.clamp;
2097       new_sdwa->omod = cmp_sdwa.omod;
2098       new_instr = new_sdwa;
2099    } else if (cmp->isDPP()) {
2100       DPP_instruction* new_dpp = create_instruction<DPP_instruction>(
2101          new_opcode, (Format)((uint16_t)Format::DPP | (uint16_t)Format::VOPC), 2, 1);
2102       DPP_instruction& cmp_dpp = cmp->dpp();
2103       memcpy(new_dpp->abs, cmp_dpp.abs, sizeof(new_dpp->abs));
2104       memcpy(new_dpp->neg, cmp_dpp.neg, sizeof(new_dpp->neg));
2105       new_dpp->dpp_ctrl = cmp_dpp.dpp_ctrl;
2106       new_dpp->row_mask = cmp_dpp.row_mask;
2107       new_dpp->bank_mask = cmp_dpp.bank_mask;
2108       new_dpp->bound_ctrl = cmp_dpp.bound_ctrl;
2109       new_instr = new_dpp;
2110    } else {
2111       new_instr = create_instruction<VOPC_instruction>(new_opcode, Format::VOPC, 2, 1);
2112       instr->definitions[0].setHint(vcc);
2113    }
2114    new_instr->operands[0] = cmp->operands[0];
2115    new_instr->operands[1] = cmp->operands[1];
2116    new_instr->definitions[0] = instr->definitions[0];
2117 
2118    ctx.info[instr->definitions[0].tempId()].label = 0;
2119    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2120 
2121    instr.reset(new_instr);
2122 
2123    return true;
2124 }
2125 
2126 /* op1(op2(1, 2), 0) if swap = false
2127  * op1(0, op2(1, 2)) if swap = true */
2128 bool
match_op3_for_vop3(opt_ctx & ctx,aco_opcode op1,aco_opcode op2,Instruction * op1_instr,bool swap,const char * shuffle_str,Operand operands[3],bool neg[3],bool abs[3],uint8_t * opsel,bool * op1_clamp,uint8_t * op1_omod,bool * inbetween_neg,bool * inbetween_abs,bool * inbetween_opsel,bool * precise)2129 match_op3_for_vop3(opt_ctx& ctx, aco_opcode op1, aco_opcode op2, Instruction* op1_instr, bool swap,
2130                    const char* shuffle_str, Operand operands[3], bool neg[3], bool abs[3],
2131                    uint8_t* opsel, bool* op1_clamp, uint8_t* op1_omod, bool* inbetween_neg,
2132                    bool* inbetween_abs, bool* inbetween_opsel, bool* precise)
2133 {
2134    /* checks */
2135    if (op1_instr->opcode != op1)
2136       return false;
2137 
2138    Instruction* op2_instr = follow_operand(ctx, op1_instr->operands[swap]);
2139    if (!op2_instr || op2_instr->opcode != op2)
2140       return false;
2141    if (fixed_to_exec(op2_instr->operands[0]) || fixed_to_exec(op2_instr->operands[1]))
2142       return false;
2143 
2144    VOP3_instruction* op1_vop3 = op1_instr->isVOP3() ? &op1_instr->vop3() : NULL;
2145    VOP3_instruction* op2_vop3 = op2_instr->isVOP3() ? &op2_instr->vop3() : NULL;
2146 
2147    if (op1_instr->isSDWA() || op2_instr->isSDWA())
2148       return false;
2149    if (op1_instr->isDPP() || op2_instr->isDPP())
2150       return false;
2151 
2152    /* don't support inbetween clamp/omod */
2153    if (op2_vop3 && (op2_vop3->clamp || op2_vop3->omod))
2154       return false;
2155 
2156    /* get operands and modifiers and check inbetween modifiers */
2157    *op1_clamp = op1_vop3 ? op1_vop3->clamp : false;
2158    *op1_omod = op1_vop3 ? op1_vop3->omod : 0u;
2159 
2160    if (inbetween_neg)
2161       *inbetween_neg = op1_vop3 ? op1_vop3->neg[swap] : false;
2162    else if (op1_vop3 && op1_vop3->neg[swap])
2163       return false;
2164 
2165    if (inbetween_abs)
2166       *inbetween_abs = op1_vop3 ? op1_vop3->abs[swap] : false;
2167    else if (op1_vop3 && op1_vop3->abs[swap])
2168       return false;
2169 
2170    if (inbetween_opsel)
2171       *inbetween_opsel = op1_vop3 ? op1_vop3->opsel & (1 << (unsigned)swap) : false;
2172    else if (op1_vop3 && op1_vop3->opsel & (1 << (unsigned)swap))
2173       return false;
2174 
2175    *precise = op1_instr->definitions[0].isPrecise() || op2_instr->definitions[0].isPrecise();
2176 
2177    int shuffle[3];
2178    shuffle[shuffle_str[0] - '0'] = 0;
2179    shuffle[shuffle_str[1] - '0'] = 1;
2180    shuffle[shuffle_str[2] - '0'] = 2;
2181 
2182    operands[shuffle[0]] = op1_instr->operands[!swap];
2183    neg[shuffle[0]] = op1_vop3 ? op1_vop3->neg[!swap] : false;
2184    abs[shuffle[0]] = op1_vop3 ? op1_vop3->abs[!swap] : false;
2185    if (op1_vop3 && (op1_vop3->opsel & (1 << (unsigned)!swap)))
2186       *opsel |= 1 << shuffle[0];
2187 
2188    for (unsigned i = 0; i < 2; i++) {
2189       operands[shuffle[i + 1]] = op2_instr->operands[i];
2190       neg[shuffle[i + 1]] = op2_vop3 ? op2_vop3->neg[i] : false;
2191       abs[shuffle[i + 1]] = op2_vop3 ? op2_vop3->abs[i] : false;
2192       if (op2_vop3 && op2_vop3->opsel & (1 << i))
2193          *opsel |= 1 << shuffle[i + 1];
2194    }
2195 
2196    /* check operands */
2197    if (!check_vop3_operands(ctx, 3, operands))
2198       return false;
2199 
2200    return true;
2201 }
2202 
2203 void
create_vop3_for_op3(opt_ctx & ctx,aco_opcode opcode,aco_ptr<Instruction> & instr,Operand operands[3],bool neg[3],bool abs[3],uint8_t opsel,bool clamp,unsigned omod)2204 create_vop3_for_op3(opt_ctx& ctx, aco_opcode opcode, aco_ptr<Instruction>& instr,
2205                     Operand operands[3], bool neg[3], bool abs[3], uint8_t opsel, bool clamp,
2206                     unsigned omod)
2207 {
2208    VOP3_instruction* new_instr = create_instruction<VOP3_instruction>(opcode, Format::VOP3, 3, 1);
2209    memcpy(new_instr->abs, abs, sizeof(bool[3]));
2210    memcpy(new_instr->neg, neg, sizeof(bool[3]));
2211    new_instr->clamp = clamp;
2212    new_instr->omod = omod;
2213    new_instr->opsel = opsel;
2214    new_instr->operands[0] = operands[0];
2215    new_instr->operands[1] = operands[1];
2216    new_instr->operands[2] = operands[2];
2217    new_instr->definitions[0] = instr->definitions[0];
2218    ctx.info[instr->definitions[0].tempId()].label = 0;
2219 
2220    instr.reset(new_instr);
2221 }
2222 
2223 bool
combine_three_valu_op(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode op2,aco_opcode new_op,const char * shuffle,uint8_t ops)2224 combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode op2, aco_opcode new_op,
2225                       const char* shuffle, uint8_t ops)
2226 {
2227    for (unsigned swap = 0; swap < 2; swap++) {
2228       if (!((1 << swap) & ops))
2229          continue;
2230 
2231       Operand operands[3];
2232       bool neg[3], abs[3], clamp, precise;
2233       uint8_t opsel = 0, omod = 0;
2234       if (match_op3_for_vop3(ctx, instr->opcode, op2, instr.get(), swap, shuffle, operands, neg,
2235                              abs, &opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2236          ctx.uses[instr->operands[swap].tempId()]--;
2237          create_vop3_for_op3(ctx, new_op, instr, operands, neg, abs, opsel, clamp, omod);
2238          return true;
2239       }
2240    }
2241    return false;
2242 }
2243 
2244 /* creates v_lshl_add_u32, v_lshl_or_b32 or v_and_or_b32 */
2245 bool
combine_add_or_then_and_lshl(opt_ctx & ctx,aco_ptr<Instruction> & instr)2246 combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2247 {
2248    bool is_or = instr->opcode == aco_opcode::v_or_b32;
2249    aco_opcode new_op_lshl = is_or ? aco_opcode::v_lshl_or_b32 : aco_opcode::v_lshl_add_u32;
2250 
2251    if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32,
2252                                       "120", 1 | 2))
2253       return true;
2254    if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32,
2255                                       "120", 1 | 2))
2256       return true;
2257    if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, new_op_lshl, "120", 1 | 2))
2258       return true;
2259    if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, new_op_lshl, "210", 1 | 2))
2260       return true;
2261 
2262    if (instr->isSDWA() || instr->isDPP())
2263       return false;
2264 
2265    /* v_or_b32(p_extract(a, 0, 8/16, 0), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2266     * v_or_b32(p_insert(a, 0, 8/16), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2267     * v_or_b32(p_insert(a, 24/16, 8/16), b) -> v_lshl_or_b32(a, 24/16, b)
2268     * v_add_u32(p_insert(a, 24/16, 8/16), b) -> v_lshl_add_b32(a, 24/16, b)
2269     */
2270    for (unsigned i = 0; i < 2; i++) {
2271       Instruction* extins = follow_operand(ctx, instr->operands[i]);
2272       if (!extins)
2273          continue;
2274 
2275       aco_opcode op;
2276       Operand operands[3];
2277 
2278       if (extins->opcode == aco_opcode::p_insert &&
2279           (extins->operands[1].constantValue() + 1) * extins->operands[2].constantValue() == 32) {
2280          op = new_op_lshl;
2281          operands[1] =
2282             Operand::c32(extins->operands[1].constantValue() * extins->operands[2].constantValue());
2283       } else if (is_or &&
2284                  (extins->opcode == aco_opcode::p_insert ||
2285                   (extins->opcode == aco_opcode::p_extract &&
2286                    extins->operands[3].constantEquals(0))) &&
2287                  extins->operands[1].constantEquals(0)) {
2288          op = aco_opcode::v_and_or_b32;
2289          operands[1] = Operand::c32(extins->operands[2].constantEquals(8) ? 0xffu : 0xffffu);
2290       } else {
2291          continue;
2292       }
2293 
2294       operands[0] = extins->operands[0];
2295       operands[2] = instr->operands[!i];
2296 
2297       if (!check_vop3_operands(ctx, 3, operands))
2298          continue;
2299 
2300       bool neg[3] = {}, abs[3] = {};
2301       uint8_t opsel = 0, omod = 0;
2302       bool clamp = false;
2303       if (instr->isVOP3())
2304          clamp = instr->vop3().clamp;
2305 
2306       ctx.uses[instr->operands[i].tempId()]--;
2307       create_vop3_for_op3(ctx, op, instr, operands, neg, abs, opsel, clamp, omod);
2308       return true;
2309    }
2310 
2311    return false;
2312 }
2313 
2314 bool
combine_minmax(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode opposite,aco_opcode minmax3)2315 combine_minmax(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode opposite, aco_opcode minmax3)
2316 {
2317    /* TODO: this can handle SDWA min/max instructions by using opsel */
2318    if (combine_three_valu_op(ctx, instr, instr->opcode, minmax3, "012", 1 | 2))
2319       return true;
2320 
2321    /* min(-max(a, b), c) -> min3(c, -a, -b) *
2322     * max(-min(a, b), c) -> max3(c, -a, -b) */
2323    for (unsigned swap = 0; swap < 2; swap++) {
2324       Operand operands[3];
2325       bool neg[3], abs[3], clamp, precise;
2326       uint8_t opsel = 0, omod = 0;
2327       bool inbetween_neg;
2328       if (match_op3_for_vop3(ctx, instr->opcode, opposite, instr.get(), swap, "012", operands, neg,
2329                              abs, &opsel, &clamp, &omod, &inbetween_neg, NULL, NULL, &precise) &&
2330           inbetween_neg) {
2331          ctx.uses[instr->operands[swap].tempId()]--;
2332          neg[1] = !neg[1];
2333          neg[2] = !neg[2];
2334          create_vop3_for_op3(ctx, minmax3, instr, operands, neg, abs, opsel, clamp, omod);
2335          return true;
2336       }
2337    }
2338    return false;
2339 }
2340 
2341 /* s_not_b32(s_and_b32(a, b)) -> s_nand_b32(a, b)
2342  * s_not_b32(s_or_b32(a, b)) -> s_nor_b32(a, b)
2343  * s_not_b32(s_xor_b32(a, b)) -> s_xnor_b32(a, b)
2344  * s_not_b64(s_and_b64(a, b)) -> s_nand_b64(a, b)
2345  * s_not_b64(s_or_b64(a, b)) -> s_nor_b64(a, b)
2346  * s_not_b64(s_xor_b64(a, b)) -> s_xnor_b64(a, b) */
2347 bool
combine_salu_not_bitwise(opt_ctx & ctx,aco_ptr<Instruction> & instr)2348 combine_salu_not_bitwise(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2349 {
2350    /* checks */
2351    if (!instr->operands[0].isTemp())
2352       return false;
2353    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2354       return false;
2355 
2356    Instruction* op2_instr = follow_operand(ctx, instr->operands[0]);
2357    if (!op2_instr)
2358       return false;
2359    switch (op2_instr->opcode) {
2360    case aco_opcode::s_and_b32:
2361    case aco_opcode::s_or_b32:
2362    case aco_opcode::s_xor_b32:
2363    case aco_opcode::s_and_b64:
2364    case aco_opcode::s_or_b64:
2365    case aco_opcode::s_xor_b64: break;
2366    default: return false;
2367    }
2368 
2369    /* create instruction */
2370    std::swap(instr->definitions[0], op2_instr->definitions[0]);
2371    std::swap(instr->definitions[1], op2_instr->definitions[1]);
2372    ctx.uses[instr->operands[0].tempId()]--;
2373    ctx.info[op2_instr->definitions[0].tempId()].label = 0;
2374 
2375    switch (op2_instr->opcode) {
2376    case aco_opcode::s_and_b32: op2_instr->opcode = aco_opcode::s_nand_b32; break;
2377    case aco_opcode::s_or_b32: op2_instr->opcode = aco_opcode::s_nor_b32; break;
2378    case aco_opcode::s_xor_b32: op2_instr->opcode = aco_opcode::s_xnor_b32; break;
2379    case aco_opcode::s_and_b64: op2_instr->opcode = aco_opcode::s_nand_b64; break;
2380    case aco_opcode::s_or_b64: op2_instr->opcode = aco_opcode::s_nor_b64; break;
2381    case aco_opcode::s_xor_b64: op2_instr->opcode = aco_opcode::s_xnor_b64; break;
2382    default: break;
2383    }
2384 
2385    return true;
2386 }
2387 
2388 /* s_and_b32(a, s_not_b32(b)) -> s_andn2_b32(a, b)
2389  * s_or_b32(a, s_not_b32(b)) -> s_orn2_b32(a, b)
2390  * s_and_b64(a, s_not_b64(b)) -> s_andn2_b64(a, b)
2391  * s_or_b64(a, s_not_b64(b)) -> s_orn2_b64(a, b) */
2392 bool
combine_salu_n2(opt_ctx & ctx,aco_ptr<Instruction> & instr)2393 combine_salu_n2(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2394 {
2395    if (instr->definitions[0].isTemp() && ctx.info[instr->definitions[0].tempId()].is_uniform_bool())
2396       return false;
2397 
2398    for (unsigned i = 0; i < 2; i++) {
2399       Instruction* op2_instr = follow_operand(ctx, instr->operands[i]);
2400       if (!op2_instr || (op2_instr->opcode != aco_opcode::s_not_b32 &&
2401                          op2_instr->opcode != aco_opcode::s_not_b64))
2402          continue;
2403       if (ctx.uses[op2_instr->definitions[1].tempId()] || fixed_to_exec(op2_instr->operands[0]))
2404          continue;
2405 
2406       if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2407           instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2408          continue;
2409 
2410       ctx.uses[instr->operands[i].tempId()]--;
2411       instr->operands[0] = instr->operands[!i];
2412       instr->operands[1] = op2_instr->operands[0];
2413       ctx.info[instr->definitions[0].tempId()].label = 0;
2414 
2415       switch (instr->opcode) {
2416       case aco_opcode::s_and_b32: instr->opcode = aco_opcode::s_andn2_b32; break;
2417       case aco_opcode::s_or_b32: instr->opcode = aco_opcode::s_orn2_b32; break;
2418       case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_andn2_b64; break;
2419       case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_orn2_b64; break;
2420       default: break;
2421       }
2422 
2423       return true;
2424    }
2425    return false;
2426 }
2427 
2428 /* s_add_{i32,u32}(a, s_lshl_b32(b, <n>)) -> s_lshl<n>_add_u32(a, b) */
2429 bool
combine_salu_lshl_add(opt_ctx & ctx,aco_ptr<Instruction> & instr)2430 combine_salu_lshl_add(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2431 {
2432    if (instr->opcode == aco_opcode::s_add_i32 && ctx.uses[instr->definitions[1].tempId()])
2433       return false;
2434 
2435    for (unsigned i = 0; i < 2; i++) {
2436       Instruction* op2_instr = follow_operand(ctx, instr->operands[i], true);
2437       if (!op2_instr || op2_instr->opcode != aco_opcode::s_lshl_b32 ||
2438           ctx.uses[op2_instr->definitions[1].tempId()])
2439          continue;
2440       if (!op2_instr->operands[1].isConstant() || fixed_to_exec(op2_instr->operands[0]))
2441          continue;
2442 
2443       uint32_t shift = op2_instr->operands[1].constantValue();
2444       if (shift < 1 || shift > 4)
2445          continue;
2446 
2447       if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2448           instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2449          continue;
2450 
2451       ctx.uses[instr->operands[i].tempId()]--;
2452       instr->operands[1] = instr->operands[!i];
2453       instr->operands[0] = op2_instr->operands[0];
2454       ctx.info[instr->definitions[0].tempId()].label = 0;
2455 
2456       instr->opcode = std::array<aco_opcode, 4>{
2457          aco_opcode::s_lshl1_add_u32, aco_opcode::s_lshl2_add_u32, aco_opcode::s_lshl3_add_u32,
2458          aco_opcode::s_lshl4_add_u32}[shift - 1];
2459 
2460       return true;
2461    }
2462    return false;
2463 }
2464 
2465 bool
combine_add_sub_b2i(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode new_op,uint8_t ops)2466 combine_add_sub_b2i(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode new_op, uint8_t ops)
2467 {
2468    if (instr->usesModifiers())
2469       return false;
2470 
2471    for (unsigned i = 0; i < 2; i++) {
2472       if (!((1 << i) & ops))
2473          continue;
2474       if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2i() &&
2475           ctx.uses[instr->operands[i].tempId()] == 1) {
2476 
2477          aco_ptr<Instruction> new_instr;
2478          if (instr->operands[!i].isTemp() &&
2479              instr->operands[!i].getTemp().type() == RegType::vgpr) {
2480             new_instr.reset(create_instruction<VOP2_instruction>(new_op, Format::VOP2, 3, 2));
2481          } else if (ctx.program->chip_class >= GFX10 ||
2482                     (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
2483             new_instr.reset(
2484                create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOP2), 3, 2));
2485          } else {
2486             return false;
2487          }
2488          ctx.uses[instr->operands[i].tempId()]--;
2489          new_instr->definitions[0] = instr->definitions[0];
2490          if (instr->definitions.size() == 2) {
2491             new_instr->definitions[1] = instr->definitions[1];
2492          } else {
2493             new_instr->definitions[1] =
2494                Definition(ctx.program->allocateTmp(ctx.program->lane_mask));
2495             /* Make sure the uses vector is large enough and the number of
2496              * uses properly initialized to 0.
2497              */
2498             ctx.uses.push_back(0);
2499          }
2500          new_instr->definitions[1].setHint(vcc);
2501          new_instr->operands[0] = Operand::zero();
2502          new_instr->operands[1] = instr->operands[!i];
2503          new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
2504          instr = std::move(new_instr);
2505          ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
2506          return true;
2507       }
2508    }
2509 
2510    return false;
2511 }
2512 
2513 bool
combine_add_bcnt(opt_ctx & ctx,aco_ptr<Instruction> & instr)2514 combine_add_bcnt(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2515 {
2516    if (instr->usesModifiers())
2517       return false;
2518 
2519    for (unsigned i = 0; i < 2; i++) {
2520       Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
2521       if (op_instr && op_instr->opcode == aco_opcode::v_bcnt_u32_b32 &&
2522           !op_instr->usesModifiers() && op_instr->operands[0].isTemp() &&
2523           op_instr->operands[0].getTemp().type() == RegType::vgpr &&
2524           op_instr->operands[1].constantEquals(0)) {
2525          aco_ptr<Instruction> new_instr{
2526             create_instruction<VOP3_instruction>(aco_opcode::v_bcnt_u32_b32, Format::VOP3, 2, 1)};
2527          ctx.uses[instr->operands[i].tempId()]--;
2528          new_instr->operands[0] = op_instr->operands[0];
2529          new_instr->operands[1] = instr->operands[!i];
2530          new_instr->definitions[0] = instr->definitions[0];
2531          instr = std::move(new_instr);
2532          ctx.info[instr->definitions[0].tempId()].label = 0;
2533 
2534          return true;
2535       }
2536    }
2537 
2538    return false;
2539 }
2540 
2541 bool
get_minmax_info(aco_opcode op,aco_opcode * min,aco_opcode * max,aco_opcode * min3,aco_opcode * max3,aco_opcode * med3,bool * some_gfx9_only)2542 get_minmax_info(aco_opcode op, aco_opcode* min, aco_opcode* max, aco_opcode* min3, aco_opcode* max3,
2543                 aco_opcode* med3, bool* some_gfx9_only)
2544 {
2545    switch (op) {
2546 #define MINMAX(type, gfx9)                                                                         \
2547    case aco_opcode::v_min_##type:                                                                  \
2548    case aco_opcode::v_max_##type:                                                                  \
2549    case aco_opcode::v_med3_##type:                                                                 \
2550       *min = aco_opcode::v_min_##type;                                                             \
2551       *max = aco_opcode::v_max_##type;                                                             \
2552       *med3 = aco_opcode::v_med3_##type;                                                           \
2553       *min3 = aco_opcode::v_min3_##type;                                                           \
2554       *max3 = aco_opcode::v_max3_##type;                                                           \
2555       *some_gfx9_only = gfx9;                                                                      \
2556       return true;
2557       MINMAX(f32, false)
2558       MINMAX(u32, false)
2559       MINMAX(i32, false)
2560       MINMAX(f16, true)
2561       MINMAX(u16, true)
2562       MINMAX(i16, true)
2563 #undef MINMAX
2564    default: return false;
2565    }
2566 }
2567 
/* when ub > lb:
 * v_min_{f,u,i}{16,32}(v_max_{f,u,i}{16,32}(a, lb), ub) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
 * v_max_{f,u,i}{16,32}(v_min_{f,u,i}{16,32}(a, ub), lb) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
 *
 * min, max and med are the min/max/med3 opcodes of one type family. instr must be the min or
 * the max opcode, otherwise false is returned. Both bounds must be known constants: inline or
 * literal operands, or temporaries labelled as 32-bit constants/literals. This lets their order
 * be checked. On success instr is replaced by the med3 instruction and true is returned.
 */
bool
combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode min, aco_opcode max,
              aco_opcode med)
{
   /* TODO: GLSL's clamp(x, minVal, maxVal) and SPIR-V's
    * FClamp(x, minVal, maxVal)/NClamp(x, minVal, maxVal) are undefined if
    * minVal > maxVal, which means we can always select it to a v_med3_f32 */
   /* The nested instruction must be the opposite operation of instr. */
   aco_opcode other_op;
   if (instr->opcode == min)
      other_op = max;
   else if (instr->opcode == max)
      other_op = min;
   else
      return false;

   /* swap selects which operand of instr is expected to be the nested instruction. */
   for (unsigned swap = 0; swap < 2; swap++) {
      Operand operands[3];
      bool neg[3], abs[3], clamp, precise;
      uint8_t opsel = 0, omod = 0;
      if (match_op3_for_vop3(ctx, instr->opcode, other_op, instr.get(), swap, "012", operands, neg,
                             abs, &opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
         /* max(min(src, upper), lower) returns upper if src is NaN, but
          * med3(src, lower, upper) returns lower.
          */
         if (precise && instr->opcode != min)
            continue;

         /* Find the constant operands. The first becomes const0 and any later one const1
          * (a third constant would overwrite const1). */
         int const0_idx = -1, const1_idx = -1;
         uint32_t const0 = 0, const1 = 0;
         for (int i = 0; i < 3; i++) {
            uint32_t val;
            if (operands[i].isConstant()) {
               val = operands[i].constantValue();
            } else if (operands[i].isTemp() &&
                       ctx.info[operands[i].tempId()].is_constant_or_literal(32)) {
               val = ctx.info[operands[i].tempId()].val;
            } else {
               continue;
            }
            if (const0_idx >= 0) {
               const1_idx = i;
               const1 = val;
            } else {
               const0_idx = i;
               const0 = val;
            }
         }
         /* Two constants are needed to tell the bounds apart. */
         if (const0_idx < 0 || const1_idx < 0)
            continue;

         /* An operand whose opsel bit is set is read from its upper 16 bits. */
         if (opsel & (1 << const0_idx))
            const0 >>= 16;
         if (opsel & (1 << const1_idx))
            const1 >>= 16;

         /* Work out which constant is the lower bound, comparing in the family's type.
          * Opcodes not listed below keep const0 as the lower bound. */
         int lower_idx = const0_idx;
         switch (min) {
         case aco_opcode::v_min_f32:
         case aco_opcode::v_min_f16: {
            float const0_f, const1_f;
            if (min == aco_opcode::v_min_f32) {
               memcpy(&const0_f, &const0, 4);
               memcpy(&const1_f, &const1, 4);
            } else {
               const0_f = _mesa_half_to_float(const0);
               const1_f = _mesa_half_to_float(const1);
            }
            /* apply the source modifiers: abs first, then neg */
            if (abs[const0_idx])
               const0_f = fabsf(const0_f);
            if (abs[const1_idx])
               const1_f = fabsf(const1_f);
            if (neg[const0_idx])
               const0_f = -const0_f;
            if (neg[const1_idx])
               const1_f = -const1_f;
            lower_idx = const0_f < const1_f ? const0_idx : const1_idx;
            break;
         }
         case aco_opcode::v_min_u32: {
            lower_idx = const0 < const1 ? const0_idx : const1_idx;
            break;
         }
         case aco_opcode::v_min_u16: {
            lower_idx = (uint16_t)const0 < (uint16_t)const1 ? const0_idx : const1_idx;
            break;
         }
         case aco_opcode::v_min_i32: {
            /* reinterpret the raw bits as two's complement */
            int32_t const0_i =
               const0 & 0x80000000u ? -2147483648 + (int32_t)(const0 & 0x7fffffffu) : const0;
            int32_t const1_i =
               const1 & 0x80000000u ? -2147483648 + (int32_t)(const1 & 0x7fffffffu) : const1;
            lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
            break;
         }
         case aco_opcode::v_min_i16: {
            /* reinterpret the low 16 bits as two's complement */
            int16_t const0_i = const0 & 0x8000u ? -32768 + (int16_t)(const0 & 0x7fffu) : const0;
            int16_t const1_i = const1 & 0x8000u ? -32768 + (int16_t)(const1 & 0x7fffu) : const1;
            lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
            break;
         }
         default: break;
         }
         int upper_idx = lower_idx == const0_idx ? const1_idx : const0_idx;

         /* The bounds must be where the patterns above expect them. For a min, operands[0]
          * must be the upper bound and the lower bound must be elsewhere. For a max,
          * operands[0] must be the lower bound. Note that this returns false instead of trying
          * the other swap. */
         if (instr->opcode == min) {
            if (upper_idx != 0 || lower_idx == 0)
               return false;
         } else {
            if (upper_idx == 0 || lower_idx != 0)
               return false;
         }

         ctx.uses[instr->operands[swap].tempId()]--;
         create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod);

         return true;
      }
   }

   return false;
}
2693 
2694 void
apply_sgprs(opt_ctx & ctx,aco_ptr<Instruction> & instr)2695 apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2696 {
2697    bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
2698                      instr->opcode == aco_opcode::v_lshrrev_b64 ||
2699                      instr->opcode == aco_opcode::v_ashrrev_i64;
2700 
2701    /* find candidates and create the set of sgprs already read */
2702    unsigned sgpr_ids[2] = {0, 0};
2703    uint32_t operand_mask = 0;
2704    bool has_literal = false;
2705    for (unsigned i = 0; i < instr->operands.size(); i++) {
2706       if (instr->operands[i].isLiteral())
2707          has_literal = true;
2708       if (!instr->operands[i].isTemp())
2709          continue;
2710       if (instr->operands[i].getTemp().type() == RegType::sgpr) {
2711          if (instr->operands[i].tempId() != sgpr_ids[0])
2712             sgpr_ids[!!sgpr_ids[0]] = instr->operands[i].tempId();
2713       }
2714       ssa_info& info = ctx.info[instr->operands[i].tempId()];
2715       if (is_copy_label(ctx, instr, info) && info.temp.type() == RegType::sgpr)
2716          operand_mask |= 1u << i;
2717       if (info.is_extract() && info.instr->operands[0].getTemp().type() == RegType::sgpr)
2718          operand_mask |= 1u << i;
2719    }
2720    unsigned max_sgprs = 1;
2721    if (ctx.program->chip_class >= GFX10 && !is_shift64)
2722       max_sgprs = 2;
2723    if (has_literal)
2724       max_sgprs--;
2725 
2726    unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
2727 
2728    /* keep on applying sgprs until there is nothing left to be done */
2729    while (operand_mask) {
2730       uint32_t sgpr_idx = 0;
2731       uint32_t sgpr_info_id = 0;
2732       uint32_t mask = operand_mask;
2733       /* choose a sgpr */
2734       while (mask) {
2735          unsigned i = u_bit_scan(&mask);
2736          uint16_t uses = ctx.uses[instr->operands[i].tempId()];
2737          if (sgpr_info_id == 0 || uses < ctx.uses[sgpr_info_id]) {
2738             sgpr_idx = i;
2739             sgpr_info_id = instr->operands[i].tempId();
2740          }
2741       }
2742       operand_mask &= ~(1u << sgpr_idx);
2743 
2744       ssa_info& info = ctx.info[sgpr_info_id];
2745 
2746       /* Applying two sgprs require making it VOP3, so don't do it unless it's
2747        * definitively beneficial.
2748        * TODO: this is too conservative because later the use count could be reduced to 1 */
2749       if (!info.is_extract() && num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3() &&
2750           !instr->isSDWA() && instr->format != Format::VOP3P)
2751          break;
2752 
2753       Temp sgpr = info.is_extract() ? info.instr->operands[0].getTemp() : info.temp;
2754       bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1];
2755       if (new_sgpr && num_sgprs >= max_sgprs)
2756          continue;
2757 
2758       if (sgpr_idx == 0)
2759          instr->format = withoutDPP(instr->format);
2760 
2761       if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA() || instr->isVOP3P() ||
2762           info.is_extract()) {
2763          /* can_apply_extract() checks SGPR encoding restrictions */
2764          if (info.is_extract() && can_apply_extract(ctx, instr, sgpr_idx, info))
2765             apply_extract(ctx, instr, sgpr_idx, info);
2766          else if (info.is_extract())
2767             continue;
2768          instr->operands[sgpr_idx] = Operand(sgpr);
2769       } else if (can_swap_operands(instr, &instr->opcode)) {
2770          instr->operands[sgpr_idx] = instr->operands[0];
2771          instr->operands[0] = Operand(sgpr);
2772          /* swap bits using a 4-entry LUT */
2773          uint32_t swapped = (0x3120 >> (operand_mask & 0x3)) & 0xf;
2774          operand_mask = (operand_mask & ~0x3) | swapped;
2775       } else if (can_use_VOP3(ctx, instr) && !info.is_extract()) {
2776          to_VOP3(ctx, instr);
2777          instr->operands[sgpr_idx] = Operand(sgpr);
2778       } else {
2779          continue;
2780       }
2781 
2782       if (new_sgpr)
2783          sgpr_ids[num_sgprs++] = sgpr.id();
2784       ctx.uses[sgpr_info_id]--;
2785       ctx.uses[sgpr.id()]++;
2786 
2787       /* TODO: handle when it's a VGPR */
2788       if ((ctx.info[sgpr.id()].label & (label_extract | label_temp)) &&
2789           ctx.info[sgpr.id()].temp.type() == RegType::sgpr)
2790          operand_mask |= 1u << sgpr_idx;
2791    }
2792 }
2793 
2794 template <typename T>
2795 bool
apply_omod_clamp_helper(opt_ctx & ctx,T * instr,ssa_info & def_info)2796 apply_omod_clamp_helper(opt_ctx& ctx, T* instr, ssa_info& def_info)
2797 {
2798    if (!def_info.is_clamp() && (instr->clamp || instr->omod))
2799       return false;
2800 
2801    if (def_info.is_omod2())
2802       instr->omod = 1;
2803    else if (def_info.is_omod4())
2804       instr->omod = 2;
2805    else if (def_info.is_omod5())
2806       instr->omod = 3;
2807    else if (def_info.is_clamp())
2808       instr->clamp = true;
2809 
2810    return true;
2811 }
2812 
2813 /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
2814 bool
apply_omod_clamp(opt_ctx & ctx,aco_ptr<Instruction> & instr)2815 apply_omod_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2816 {
2817    if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1 ||
2818        !instr_info.can_use_output_modifiers[(int)instr->opcode])
2819       return false;
2820 
2821    bool can_vop3 = can_use_VOP3(ctx, instr);
2822    if (!instr->isSDWA() && !can_vop3)
2823       return false;
2824 
2825    /* omod flushes -0 to +0 and has no effect if denormals are enabled */
2826    bool can_use_omod = (can_vop3 || ctx.program->chip_class >= GFX9); /* SDWA omod is GFX9+ */
2827    if (instr->definitions[0].bytes() == 4)
2828       can_use_omod =
2829          can_use_omod && ctx.fp_mode.denorm32 == 0 && !ctx.fp_mode.preserve_signed_zero_inf_nan32;
2830    else
2831       can_use_omod = can_use_omod && ctx.fp_mode.denorm16_64 == 0 &&
2832                      !ctx.fp_mode.preserve_signed_zero_inf_nan16_64;
2833 
2834    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
2835 
2836    uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
2837    if (!def_info.is_clamp() && !(can_use_omod && (def_info.label & omod_labels)))
2838       return false;
2839    /* if the omod/clamp instruction is dead, then the single user of this
2840     * instruction is a different instruction */
2841    if (!ctx.uses[def_info.instr->definitions[0].tempId()])
2842       return false;
2843 
2844    /* MADs/FMAs are created later, so we don't have to update the original add */
2845    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
2846 
2847    if (instr->isSDWA()) {
2848       if (!apply_omod_clamp_helper(ctx, &instr->sdwa(), def_info))
2849          return false;
2850    } else {
2851       to_VOP3(ctx, instr);
2852       if (!apply_omod_clamp_helper(ctx, &instr->vop3(), def_info))
2853          return false;
2854    }
2855 
2856    instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
2857    ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_insert;
2858    ctx.uses[def_info.instr->definitions[0].tempId()]--;
2859 
2860    return true;
2861 }
2862 
2863 /* Combine an p_insert (or p_extract, in some cases) instruction with instr.
2864  * p_insert(instr(...)) -> instr_insert().
2865  */
2866 bool
apply_insert(opt_ctx & ctx,aco_ptr<Instruction> & instr)2867 apply_insert(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2868 {
2869    if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1)
2870       return false;
2871 
2872    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
2873    if (!def_info.is_insert())
2874       return false;
2875    /* if the insert instruction is dead, then the single user of this
2876     * instruction is a different instruction */
2877    if (!ctx.uses[def_info.instr->definitions[0].tempId()])
2878       return false;
2879 
2880    /* MADs/FMAs are created later, so we don't have to update the original add */
2881    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
2882 
2883    SubdwordSel sel = parse_insert(def_info.instr);
2884    assert(sel);
2885 
2886    if (instr->isVOP3() && sel.size() == 2 && !sel.sign_extend() &&
2887        can_use_opsel(ctx.program->chip_class, instr->opcode, 3, sel.offset())) {
2888       if (instr->vop3().opsel & (1 << 3))
2889          return false;
2890       if (sel.offset())
2891          instr->vop3().opsel |= 1 << 3;
2892    } else {
2893       if (!can_use_SDWA(ctx.program->chip_class, instr, true))
2894          return false;
2895 
2896       to_SDWA(ctx, instr);
2897       if (instr->sdwa().dst_sel.size() != 4)
2898          return false;
2899       static_cast<SDWA_instruction*>(instr.get())->dst_sel = sel;
2900    }
2901 
2902    instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
2903    ctx.info[instr->definitions[0].tempId()].label = 0;
2904    ctx.uses[def_info.instr->definitions[0].tempId()]--;
2905 
2906    return true;
2907 }
2908 
2909 /* Remove superfluous extract after ds_read like so:
2910  * p_extract(ds_read_uN(), 0, N, 0) -> ds_read_uN()
2911  */
2912 bool
apply_ds_extract(opt_ctx & ctx,aco_ptr<Instruction> & extract)2913 apply_ds_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract)
2914 {
2915    /* Check if p_extract has a usedef operand and is the only user. */
2916    if (!ctx.info[extract->operands[0].tempId()].is_usedef() ||
2917        ctx.uses[extract->operands[0].tempId()] > 1)
2918       return false;
2919 
2920    /* Check if the usedef is a DS instruction. */
2921    Instruction* ds = ctx.info[extract->operands[0].tempId()].instr;
2922    if (ds->format != Format::DS)
2923       return false;
2924 
2925    unsigned extract_idx = extract->operands[1].constantValue();
2926    unsigned bits_extracted = extract->operands[2].constantValue();
2927    unsigned sign_ext = extract->operands[3].constantValue();
2928    unsigned dst_bitsize = extract->definitions[0].bytes() * 8u;
2929 
2930    /* TODO: These are doable, but probably don't occour too often. */
2931    if (extract_idx || sign_ext || dst_bitsize != 32)
2932       return false;
2933 
2934    unsigned bits_loaded = 0;
2935    if (ds->opcode == aco_opcode::ds_read_u8 || ds->opcode == aco_opcode::ds_read_u8_d16)
2936       bits_loaded = 8;
2937    else if (ds->opcode == aco_opcode::ds_read_u16 || ds->opcode == aco_opcode::ds_read_u16_d16)
2938       bits_loaded = 16;
2939    else
2940       return false;
2941 
2942    /* Shrink the DS load if the extracted bit size is smaller. */
2943    bits_loaded = MIN2(bits_loaded, bits_extracted);
2944 
2945    /* Change the DS opcode so it writes the full register. */
2946    if (bits_loaded == 8)
2947       ds->opcode = aco_opcode::ds_read_u8;
2948    else if (bits_loaded == 16)
2949       ds->opcode = aco_opcode::ds_read_u16;
2950    else
2951       unreachable("Forgot to add DS opcode above.");
2952 
2953    /* The DS now produces the exact same thing as the extract, remove the extract. */
2954    std::swap(ds->definitions[0], extract->definitions[0]);
2955    ctx.uses[extract->definitions[0].tempId()] = 0;
2956    ctx.info[ds->definitions[0].tempId()].label = 0;
2957    return true;
2958 }
2959 
2960 /* v_and(a, v_subbrev_co(0, 0, vcc)) -> v_cndmask(0, a, vcc) */
2961 bool
combine_and_subbrev(opt_ctx & ctx,aco_ptr<Instruction> & instr)2962 combine_and_subbrev(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2963 {
2964    if (instr->usesModifiers())
2965       return false;
2966 
2967    for (unsigned i = 0; i < 2; i++) {
2968       Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
2969       if (op_instr && op_instr->opcode == aco_opcode::v_subbrev_co_u32 &&
2970           op_instr->operands[0].constantEquals(0) && op_instr->operands[1].constantEquals(0) &&
2971           !op_instr->usesModifiers()) {
2972 
2973          aco_ptr<Instruction> new_instr;
2974          if (instr->operands[!i].isTemp() &&
2975              instr->operands[!i].getTemp().type() == RegType::vgpr) {
2976             new_instr.reset(
2977                create_instruction<VOP2_instruction>(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1));
2978          } else if (ctx.program->chip_class >= GFX10 ||
2979                     (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
2980             new_instr.reset(create_instruction<VOP3_instruction>(aco_opcode::v_cndmask_b32,
2981                                                                  asVOP3(Format::VOP2), 3, 1));
2982          } else {
2983             return false;
2984          }
2985 
2986          ctx.uses[instr->operands[i].tempId()]--;
2987          if (ctx.uses[instr->operands[i].tempId()])
2988             ctx.uses[op_instr->operands[2].tempId()]++;
2989 
2990          new_instr->operands[0] = Operand::zero();
2991          new_instr->operands[1] = instr->operands[!i];
2992          new_instr->operands[2] = Operand(op_instr->operands[2]);
2993          new_instr->definitions[0] = instr->definitions[0];
2994          instr = std::move(new_instr);
2995          ctx.info[instr->definitions[0].tempId()].label = 0;
2996          return true;
2997       }
2998    }
2999 
3000    return false;
3001 }
3002 
3003 /* v_add_co(c, s_lshl(a, b)) -> v_mad_u32_u24(a, 1<<b, c)
3004  * v_add_co(c, v_lshlrev(a, b)) -> v_mad_u32_u24(b, 1<<a, c)
3005  * v_sub(c, s_lshl(a, b)) -> v_mad_i32_i24(a, -(1<<b), c)
3006  * v_sub(c, v_lshlrev(a, b)) -> v_mad_i32_i24(b, -(1<<a), c)
3007  */
3008 bool
combine_add_lshl(opt_ctx & ctx,aco_ptr<Instruction> & instr,bool is_sub)3009 combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool is_sub)
3010 {
3011    if (instr->usesModifiers())
3012       return false;
3013 
3014    /* Substractions: start at operand 1 to avoid mixup such as
3015     * turning v_sub(v_lshlrev(a, b), c) into v_mad_i32_i24(b, -(1<<a), c)
3016     */
3017    unsigned start_op_idx = is_sub ? 1 : 0;
3018 
3019    /* Don't allow 24-bit operands on subtraction because
3020     * v_mad_i32_i24 applies a sign extension.
3021     */
3022    bool allow_24bit = !is_sub;
3023 
3024    for (unsigned i = start_op_idx; i < 2; i++) {
3025       Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
3026       if (!op_instr)
3027          continue;
3028 
3029       if (op_instr->opcode != aco_opcode::s_lshl_b32 &&
3030           op_instr->opcode != aco_opcode::v_lshlrev_b32)
3031          continue;
3032 
3033       int shift_op_idx = op_instr->opcode == aco_opcode::s_lshl_b32 ? 1 : 0;
3034 
3035       if (op_instr->operands[shift_op_idx].isConstant() &&
3036           ((allow_24bit && op_instr->operands[!shift_op_idx].is24bit()) ||
3037            op_instr->operands[!shift_op_idx].is16bit())) {
3038          uint32_t multiplier = 1 << (op_instr->operands[shift_op_idx].constantValue() % 32u);
3039          if (is_sub)
3040             multiplier = -multiplier;
3041          if (is_sub ? (multiplier < 0xff800000) : (multiplier > 0xffffff))
3042             continue;
3043 
3044          Operand ops[3] = {
3045             op_instr->operands[!shift_op_idx],
3046             Operand::c32(multiplier),
3047             instr->operands[!i],
3048          };
3049          if (!check_vop3_operands(ctx, 3, ops))
3050             return false;
3051 
3052          ctx.uses[instr->operands[i].tempId()]--;
3053 
3054          aco_opcode mad_op = is_sub ? aco_opcode::v_mad_i32_i24 : aco_opcode::v_mad_u32_u24;
3055          aco_ptr<VOP3_instruction> new_instr{
3056             create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
3057          for (unsigned op_idx = 0; op_idx < 3; ++op_idx)
3058             new_instr->operands[op_idx] = ops[op_idx];
3059          new_instr->definitions[0] = instr->definitions[0];
3060          instr = std::move(new_instr);
3061          ctx.info[instr->definitions[0].tempId()].label = 0;
3062          return true;
3063       }
3064    }
3065 
3066    return false;
3067 }
3068 
3069 void
propagate_swizzles(VOP3P_instruction * instr,uint8_t opsel_lo,uint8_t opsel_hi)3070 propagate_swizzles(VOP3P_instruction* instr, uint8_t opsel_lo, uint8_t opsel_hi)
3071 {
3072    /* propagate swizzles which apply to a result down to the instruction's operands:
3073     * result = a.xy + b.xx -> result.yx = a.yx + b.xx */
3074    assert((opsel_lo & 1) == opsel_lo);
3075    assert((opsel_hi & 1) == opsel_hi);
3076    uint8_t tmp_lo = instr->opsel_lo;
3077    uint8_t tmp_hi = instr->opsel_hi;
3078    bool neg_lo[3] = {instr->neg_lo[0], instr->neg_lo[1], instr->neg_lo[2]};
3079    bool neg_hi[3] = {instr->neg_hi[0], instr->neg_hi[1], instr->neg_hi[2]};
3080    if (opsel_lo == 1) {
3081       instr->opsel_lo = tmp_hi;
3082       for (unsigned i = 0; i < 3; i++)
3083          instr->neg_lo[i] = neg_hi[i];
3084    }
3085    if (opsel_hi == 0) {
3086       instr->opsel_hi = tmp_lo;
3087       for (unsigned i = 0; i < 3; i++)
3088          instr->neg_hi[i] = neg_lo[i];
3089    }
3090 }
3091 
3092 void
combine_vop3p(opt_ctx & ctx,aco_ptr<Instruction> & instr)3093 combine_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3094 {
3095    VOP3P_instruction* vop3p = &instr->vop3p();
3096 
3097    /* apply clamp */
3098    if (instr->opcode == aco_opcode::v_pk_mul_f16 && instr->operands[1].constantEquals(0x3C00) &&
3099        vop3p->clamp && instr->operands[0].isTemp() && ctx.uses[instr->operands[0].tempId()] == 1) {
3100 
3101       ssa_info& info = ctx.info[instr->operands[0].tempId()];
3102       if (info.is_vop3p() && instr_info.can_use_output_modifiers[(int)info.instr->opcode]) {
3103          VOP3P_instruction* candidate = &ctx.info[instr->operands[0].tempId()].instr->vop3p();
3104          candidate->clamp = true;
3105          propagate_swizzles(candidate, vop3p->opsel_lo, vop3p->opsel_hi);
3106          instr->definitions[0].swapTemp(candidate->definitions[0]);
3107          ctx.info[candidate->definitions[0].tempId()].instr = candidate;
3108          ctx.uses[instr->definitions[0].tempId()]--;
3109          return;
3110       }
3111    }
3112 
3113    /* check for fneg modifiers */
3114    if (instr_info.can_use_input_modifiers[(int)instr->opcode]) {
3115       /* at this point, we only have 2-operand instructions */
3116       assert(instr->operands.size() == 2);
3117       for (unsigned i = 0; i < 2; i++) {
3118          Operand& op = instr->operands[i];
3119          if (!op.isTemp())
3120             continue;
3121 
3122          ssa_info& info = ctx.info[op.tempId()];
3123          if (info.is_vop3p() && info.instr->opcode == aco_opcode::v_pk_mul_f16 &&
3124              info.instr->operands[1].constantEquals(0xBC00)) {
3125             Operand ops[2] = {instr->operands[!i], info.instr->operands[0]};
3126             if (!check_vop3_operands(ctx, 2, ops))
3127                continue;
3128 
3129             VOP3P_instruction* fneg = &info.instr->vop3p();
3130             if (fneg->clamp)
3131                continue;
3132             instr->operands[i] = fneg->operands[0];
3133 
3134             /* opsel_lo/hi is either 0 or 1:
3135              * if 0 - pick selection from fneg->lo
3136              * if 1 - pick selection from fneg->hi
3137              */
3138             bool opsel_lo = (vop3p->opsel_lo >> i) & 1;
3139             bool opsel_hi = (vop3p->opsel_hi >> i) & 1;
3140             bool neg_lo = true ^ fneg->neg_lo[0] ^ fneg->neg_lo[1];
3141             bool neg_hi = true ^ fneg->neg_hi[0] ^ fneg->neg_hi[1];
3142             vop3p->neg_lo[i] ^= opsel_lo ? neg_hi : neg_lo;
3143             vop3p->neg_hi[i] ^= opsel_hi ? neg_hi : neg_lo;
3144             vop3p->opsel_lo ^= ((opsel_lo ? ~fneg->opsel_hi : fneg->opsel_lo) & 1) << i;
3145             vop3p->opsel_hi ^= ((opsel_hi ? ~fneg->opsel_hi : fneg->opsel_lo) & 1) << i;
3146 
3147             if (--ctx.uses[fneg->definitions[0].tempId()])
3148                ctx.uses[fneg->operands[0].tempId()]++;
3149          }
3150       }
3151    }
3152 
3153    if (instr->opcode == aco_opcode::v_pk_add_f16 || instr->opcode == aco_opcode::v_pk_add_u16) {
3154       bool fadd = instr->opcode == aco_opcode::v_pk_add_f16;
3155       if (fadd && instr->definitions[0].isPrecise())
3156          return;
3157 
3158       Instruction* mul_instr = nullptr;
3159       unsigned add_op_idx = 0;
3160       uint8_t opsel_lo = 0, opsel_hi = 0;
3161       uint32_t uses = UINT32_MAX;
3162 
3163       /* find the 'best' mul instruction to combine with the add */
3164       for (unsigned i = 0; i < 2; i++) {
3165          if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_vop3p())
3166             continue;
3167          ssa_info& info = ctx.info[instr->operands[i].tempId()];
3168          if (fadd) {
3169             if (info.instr->opcode != aco_opcode::v_pk_mul_f16 ||
3170                 info.instr->definitions[0].isPrecise())
3171                continue;
3172          } else {
3173             if (info.instr->opcode != aco_opcode::v_pk_mul_lo_u16)
3174                continue;
3175          }
3176 
3177          Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]};
3178          if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
3179             continue;
3180 
3181          /* no clamp allowed between mul and add */
3182          if (info.instr->vop3p().clamp)
3183             continue;
3184 
3185          mul_instr = info.instr;
3186          add_op_idx = 1 - i;
3187          opsel_lo = (vop3p->opsel_lo >> i) & 1;
3188          opsel_hi = (vop3p->opsel_hi >> i) & 1;
3189          uses = ctx.uses[instr->operands[i].tempId()];
3190       }
3191 
3192       if (!mul_instr)
3193          return;
3194 
3195       /* convert to mad */
3196       Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1], instr->operands[add_op_idx]};
3197       ctx.uses[mul_instr->definitions[0].tempId()]--;
3198       if (ctx.uses[mul_instr->definitions[0].tempId()]) {
3199          if (op[0].isTemp())
3200             ctx.uses[op[0].tempId()]++;
3201          if (op[1].isTemp())
3202             ctx.uses[op[1].tempId()]++;
3203       }
3204 
3205       /* turn packed mul+add into v_pk_fma_f16 */
3206       assert(mul_instr->isVOP3P());
3207       aco_opcode mad = fadd ? aco_opcode::v_pk_fma_f16 : aco_opcode::v_pk_mad_u16;
3208       aco_ptr<VOP3P_instruction> fma{
3209          create_instruction<VOP3P_instruction>(mad, Format::VOP3P, 3, 1)};
3210       VOP3P_instruction* mul = &mul_instr->vop3p();
3211       for (unsigned i = 0; i < 2; i++) {
3212          fma->operands[i] = op[i];
3213          fma->neg_lo[i] = mul->neg_lo[i];
3214          fma->neg_hi[i] = mul->neg_hi[i];
3215       }
3216       fma->operands[2] = op[2];
3217       fma->clamp = vop3p->clamp;
3218       fma->opsel_lo = mul->opsel_lo;
3219       fma->opsel_hi = mul->opsel_hi;
3220       propagate_swizzles(fma.get(), opsel_lo, opsel_hi);
3221       fma->opsel_lo |= (vop3p->opsel_lo << (2 - add_op_idx)) & 0x4;
3222       fma->opsel_hi |= (vop3p->opsel_hi << (2 - add_op_idx)) & 0x4;
3223       fma->neg_lo[2] = vop3p->neg_lo[add_op_idx];
3224       fma->neg_hi[2] = vop3p->neg_hi[add_op_idx];
3225       fma->neg_lo[1] = fma->neg_lo[1] ^ vop3p->neg_lo[1 - add_op_idx];
3226       fma->neg_hi[1] = fma->neg_hi[1] ^ vop3p->neg_hi[1 - add_op_idx];
3227       fma->definitions[0] = instr->definitions[0];
3228       instr = std::move(fma);
3229       ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
3230       return;
3231    }
3232 }
3233 
// TODO: we could possibly move the whole label_instruction pass to combine_instruction:
// this would mean that we'd have to fix the instruction uses during value propagation

/* Run the peephole combinations for a single instruction. This runs after
 * label_instruction() has attached ssa_info labels (ctx.info) to temporaries.
 *
 * Generic VALU folds (SDWA extracts, SGPR operands, omod/clamp, inserts) are
 * applied first. VOP3P instructions are then handed to combine_vop3p(), and the
 * remaining non-SDWA/non-DPP instructions go through opcode-specific
 * combinations. instr may be replaced by a new instruction; ctx.uses is updated
 * to reflect rewritten operands.
 */
void
combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* nothing to do if there is no result or is_dead() reports the instruction unused */
   if (instr->definitions.empty() || is_dead(ctx.uses, instr.get()))
      return;

   if (instr->isVALU()) {
      /* Apply SDWA. Do this after label_instruction() so it can remove
       * label_extract if not all instructions can take SDWA. */
      for (unsigned i = 0; i < instr->operands.size(); i++) {
         Operand& op = instr->operands[i];
         if (!op.isTemp())
            continue;
         ssa_info& info = ctx.info[op.tempId()];
         if (!info.is_extract())
            continue;
         /* if there are that many uses, there are likely better combinations */
         // TODO: delay applying extract to a point where we know better
         if (ctx.uses[op.tempId()] > 4) {
            info.label &= ~label_extract;
            continue;
         }
         if (info.is_extract() &&
             (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
              instr->operands[i].getTemp().type() == RegType::sgpr) &&
             can_apply_extract(ctx, instr, i, info)) {
            /* read the extract's source directly instead of the extract's result */
            apply_extract(ctx, instr, i, info);
            ctx.uses[instr->operands[i].tempId()]--;
            instr->operands[i].setTemp(info.instr->operands[0].getTemp());
         }
      }

      if (can_apply_sgprs(ctx, instr))
         apply_sgprs(ctx, instr);
      /* keep folding omod/clamp until apply_omod_clamp() reports no more progress */
      while (apply_omod_clamp(ctx, instr))
         ;
      apply_insert(ctx, instr);
   }

   /* VOP3P instructions have their own set of combinations */
   if (instr->isVOP3P())
      return combine_vop3p(ctx, instr);

   if (ctx.info[instr->definitions[0].tempId()].is_vcc_hint()) {
      instr->definitions[0].setHint(vcc);
   }

   /* the combinations below are not applied to SDWA or DPP instructions */
   if (instr->isSDWA() || instr->isDPP())
      return;

   if (instr->opcode == aco_opcode::p_extract)
      apply_ds_extract(ctx, instr);

   /* TODO: There are still some peephole optimizations that could be done:
    * - abs(a - b) -> s_absdiff_i32
    * - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
    * - patterns for v_alignbit_b32 and v_alignbyte_b32
    * These aren't probably too interesting though.
    * There are also patterns for v_cmp_class_f{16,32,64}. This is difficult but
    * probably more useful than the previously mentioned optimizations.
    * The various comparison optimizations also currently only work with 32-bit
    * floats. */

   /* neg(mul(a, b)) -> mul(neg(a), b) */
   /* NOTE(review): presumably operand 1 of the neg is the negated value, so this only
    * fires when the neg is the single user of the mul result. */
   if (ctx.info[instr->definitions[0].tempId()].is_neg() &&
       ctx.uses[instr->operands[1].tempId()] == 1) {
      Temp val = ctx.info[instr->definitions[0].tempId()].temp;

      if (!ctx.info[val.id()].is_mul())
         return;

      Instruction* mul_instr = ctx.info[val.id()].instr;

      if (mul_instr->operands[0].isLiteral())
         return;
      if (mul_instr->isVOP3() && mul_instr->vop3().clamp)
         return;
      if (mul_instr->isSDWA() || mul_instr->isDPP())
         return;

      /* convert to mul(neg(a), b) */
      ctx.uses[mul_instr->definitions[0].tempId()]--;
      Definition def = instr->definitions[0];
      /* neg(abs(mul(a, b))) -> mul(neg(abs(a)), abs(b)) */
      bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs();
      instr.reset(
         create_instruction<VOP3_instruction>(mul_instr->opcode, asVOP3(Format::VOP2), 2, 1));
      instr->operands[0] = mul_instr->operands[0];
      instr->operands[1] = mul_instr->operands[1];
      instr->definitions[0] = def;
      VOP3_instruction& new_mul = instr->vop3();
      /* keep the original mul's input modifiers and omod */
      if (mul_instr->isVOP3()) {
         VOP3_instruction& mul = mul_instr->vop3();
         new_mul.neg[0] = mul.neg[0];
         new_mul.neg[1] = mul.neg[1];
         new_mul.abs[0] = mul.abs[0];
         new_mul.abs[1] = mul.abs[1];
         new_mul.omod = mul.omod;
      }
      if (is_abs) {
         new_mul.neg[0] = new_mul.neg[1] = false;
         new_mul.abs[0] = new_mul.abs[1] = true;
      }
      new_mul.neg[0] ^= true;
      new_mul.clamp = false;

      ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
      return;
   }

   /* combine mul+add -> mad */
   bool mad32 = instr->opcode == aco_opcode::v_add_f32 || instr->opcode == aco_opcode::v_sub_f32 ||
                instr->opcode == aco_opcode::v_subrev_f32;
   bool mad16 = instr->opcode == aco_opcode::v_add_f16 || instr->opcode == aco_opcode::v_sub_f16 ||
                instr->opcode == aco_opcode::v_subrev_f16;
   bool mad64 = instr->opcode == aco_opcode::v_add_f64;
   if (mad16 || mad32 || mad64) {
      /* choose between an unfused v_mad_* and a fused v_fma_*; when a fused op is
       * required, precise adds/muls must not be combined */
      bool need_fma =
         mad32 ? (ctx.fp_mode.denorm32 != 0 || ctx.program->chip_class >= GFX10_3)
               : (ctx.fp_mode.denorm16_64 != 0 || ctx.program->chip_class >= GFX10 || mad64);
      if (need_fma && instr->definitions[0].isPrecise())
         return;
      if (need_fma && mad32 && !ctx.program->dev.has_fast_fma32)
         return;

      Instruction* mul_instr = nullptr;
      unsigned add_op_idx = 0;
      uint32_t uses = UINT32_MAX;
      /* find the 'best' mul instruction to combine with the add */
      /* ('best' = the eligible mul whose result has the fewest uses) */
      for (unsigned i = 0; i < 2; i++) {
         if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
            continue;
         /* check precision requirements */
         ssa_info& info = ctx.info[instr->operands[i].tempId()];
         if (need_fma && info.instr->definitions[0].isPrecise())
            continue;

         /* no clamp/omod allowed between mul and add */
         if (info.instr->isVOP3() && (info.instr->vop3().clamp || info.instr->vop3().omod))
            continue;

         Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]};
         if (info.instr->isSDWA() || info.instr->isDPP() || !check_vop3_operands(ctx, 3, op) ||
             ctx.uses[instr->operands[i].tempId()] >= uses)
            continue;

         mul_instr = info.instr;
         add_op_idx = 1 - i;
         uses = ctx.uses[instr->operands[i].tempId()];
      }

      if (mul_instr) {
         /* turn mul+add into v_mad/v_fma */
         Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1],
                          instr->operands[add_op_idx]};
         /* if the mul stays alive, its operands gain a use from the mad */
         ctx.uses[mul_instr->definitions[0].tempId()]--;
         if (ctx.uses[mul_instr->definitions[0].tempId()]) {
            if (op[0].isTemp())
               ctx.uses[op[0].tempId()]++;
            if (op[1].isTemp())
               ctx.uses[op[1].tempId()]++;
         }

         bool neg[3] = {false, false, false};
         bool abs[3] = {false, false, false};
         unsigned omod = 0;
         bool clamp = false;

         if (mul_instr->isVOP3()) {
            VOP3_instruction& vop3 = mul_instr->vop3();
            neg[0] = vop3.neg[0];
            neg[1] = vop3.neg[1];
            abs[0] = vop3.abs[0];
            abs[1] = vop3.abs[1];
         }

         if (instr->isVOP3()) {
            VOP3_instruction& vop3 = instr->vop3();
            neg[2] = vop3.neg[add_op_idx];
            abs[2] = vop3.abs[add_op_idx];
            omod = vop3.omod;
            clamp = vop3.clamp;
            /* abs of the multiplication result */
            if (vop3.abs[1 - add_op_idx]) {
               neg[0] = false;
               neg[1] = false;
               abs[0] = true;
               abs[1] = true;
            }
            /* neg of the multiplication result */
            neg[1] = neg[1] ^ vop3.neg[1 - add_op_idx];
         }
         /* subtractions: negate whichever of the mul result (via neg[1]) or the
          * addend (neg[2]) was the subtrahend */
         if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == aco_opcode::v_sub_f16)
            neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
         else if (instr->opcode == aco_opcode::v_subrev_f32 ||
                  instr->opcode == aco_opcode::v_subrev_f16)
            neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;

         aco_opcode mad_op = need_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
         if (mad16)
            mad_op = need_fma ? (ctx.program->chip_class == GFX8 ? aco_opcode::v_fma_legacy_f16
                                                                 : aco_opcode::v_fma_f16)
                              : (ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_f16
                                                                 : aco_opcode::v_mad_f16);
         if (mad64)
            mad_op = aco_opcode::v_fma_f64;

         aco_ptr<VOP3_instruction> mad{
            create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
         for (unsigned i = 0; i < 3; i++) {
            mad->operands[i] = op[i];
            mad->neg[i] = neg[i];
            mad->abs[i] = abs[i];
         }
         mad->omod = omod;
         mad->clamp = clamp;
         mad->definitions[0] = instr->definitions[0];

         /* mark this ssa_def to be re-checked for profitability and literals */
         /* (the original add is kept in ctx.mad_infos together with the mul's temp id) */
         ctx.mad_infos.emplace_back(std::move(instr), mul_instr->definitions[0].tempId());
         ctx.info[mad->definitions[0].tempId()].set_mad(mad.get(), ctx.mad_infos.size() - 1);
         instr = std::move(mad);
         return;
      }
   }
   /* v_mul_f32(v_cndmask_b32(0, 1.0, cond), a) -> v_cndmask_b32(0, a, cond) */
   else if (instr->opcode == aco_opcode::v_mul_f32 && !instr->isVOP3()) {
      for (unsigned i = 0; i < 2; i++) {
         if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2f() &&
             ctx.uses[instr->operands[i].tempId()] == 1 && instr->operands[!i].isTemp() &&
             instr->operands[!i].getTemp().type() == RegType::vgpr) {
            /* the b2f result loses its only use; its condition temp is read instead */
            ctx.uses[instr->operands[i].tempId()]--;
            ctx.uses[ctx.info[instr->operands[i].tempId()].temp.id()]++;

            aco_ptr<VOP2_instruction> new_instr{
               create_instruction<VOP2_instruction>(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1)};
            new_instr->operands[0] = Operand::zero();
            new_instr->operands[1] = instr->operands[!i];
            new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
            new_instr->definitions[0] = instr->definitions[0];
            instr = std::move(new_instr);
            ctx.info[instr->definitions[0].tempId()].label = 0;
            return;
         }
      }
   } else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->chip_class >= GFX9) {
      /* try each pattern in order; the empty bodies stop at the first one that succeeds */
      if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012",
                                1 | 2)) {
      } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32,
                                       "012", 1 | 2)) {
      } else if (combine_add_or_then_and_lshl(ctx, instr)) {
      }
   } else if (instr->opcode == aco_opcode::v_xor_b32 && ctx.program->chip_class >= GFX10) {
      if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xor3_b32, "012",
                                1 | 2)) {
      } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xor3_b32,
                                       "012", 1 | 2)) {
      }
   } else if (instr->opcode == aco_opcode::v_add_u16) {
      combine_three_valu_op(
         ctx, instr, aco_opcode::v_mul_lo_u16,
         ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_u16 : aco_opcode::v_mad_u16,
         "120", 1 | 2);
   } else if (instr->opcode == aco_opcode::v_add_u16_e64) {
      combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16_e64, aco_opcode::v_mad_u16, "120",
                            1 | 2);
   } else if (instr->opcode == aco_opcode::v_add_u32) {
      if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
      } else if (combine_add_bcnt(ctx, instr)) {
      } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
                                       aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
      } else if (ctx.program->chip_class >= GFX9 && !instr->usesModifiers()) {
         if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xad_u32, "120",
                                   1 | 2)) {
         } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xad_u32,
                                          "120", 1 | 2)) {
         } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32,
                                          "012", 1 | 2)) {
         } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32,
                                          "012", 1 | 2)) {
         } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32,
                                          "012", 1 | 2)) {
         } else if (combine_add_or_then_and_lshl(ctx, instr)) {
         }
      }
   } else if (instr->opcode == aco_opcode::v_add_co_u32 ||
              instr->opcode == aco_opcode::v_add_co_u32_e64) {
      /* combinations guarded by !carry_out are only tried when definitions[1] (the carry-out)
       * has no uses */
      bool carry_out = ctx.uses[instr->definitions[1].tempId()] > 0;
      if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
      } else if (!carry_out && combine_add_bcnt(ctx, instr)) {
      } else if (!carry_out && combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
                                                     aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
      } else if (!carry_out && combine_add_lshl(ctx, instr, false)) {
      }
   } else if (instr->opcode == aco_opcode::v_sub_u32 || instr->opcode == aco_opcode::v_sub_co_u32 ||
              instr->opcode == aco_opcode::v_sub_co_u32_e64) {
      /* v_sub_u32 has no carry-out definition */
      bool carry_out =
         instr->opcode != aco_opcode::v_sub_u32 && ctx.uses[instr->definitions[1].tempId()] > 0;
      if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 2)) {
      } else if (!carry_out && combine_add_lshl(ctx, instr, true)) {
      }
   } else if (instr->opcode == aco_opcode::v_subrev_u32 ||
              instr->opcode == aco_opcode::v_subrev_co_u32 ||
              instr->opcode == aco_opcode::v_subrev_co_u32_e64) {
      combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 1);
   } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && ctx.program->chip_class >= GFX9) {
      combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add_lshl_u32, "120",
                            2);
   } else if ((instr->opcode == aco_opcode::s_add_u32 || instr->opcode == aco_opcode::s_add_i32) &&
              ctx.program->chip_class >= GFX9) {
      combine_salu_lshl_add(ctx, instr);
   } else if (instr->opcode == aco_opcode::s_not_b32 || instr->opcode == aco_opcode::s_not_b64) {
      combine_salu_not_bitwise(ctx, instr);
   } else if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_or_b32 ||
              instr->opcode == aco_opcode::s_and_b64 || instr->opcode == aco_opcode::s_or_b64) {
      if (combine_ordering_test(ctx, instr)) {
      } else if (combine_comparison_ordering(ctx, instr)) {
      } else if (combine_constant_comparison_ordering(ctx, instr)) {
      } else if (combine_salu_n2(ctx, instr)) {
      }
   } else if (instr->opcode == aco_opcode::v_and_b32) {
      combine_and_subbrev(ctx, instr);
   } else {
      /* min/max: first try min(max())/max(min()) -> 3-operand form, then clamp patterns */
      aco_opcode min, max, min3, max3, med3;
      bool some_gfx9_only;
      if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &some_gfx9_only) &&
          (!some_gfx9_only || ctx.program->chip_class >= GFX9)) {
         if (combine_minmax(ctx, instr, instr->opcode == min ? max : min,
                            instr->opcode == min ? min3 : max3)) {
         } else {
            combine_clamp(ctx, instr, min, max, med3);
         }
      }
   }

   /* do this after combine_salu_n2() */
   if (instr->opcode == aco_opcode::s_andn2_b32 || instr->opcode == aco_opcode::s_andn2_b64)
      combine_inverse_comparison(ctx, instr);
}
3575 
3576 bool
to_uniform_bool_instr(opt_ctx & ctx,aco_ptr<Instruction> & instr)3577 to_uniform_bool_instr(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3578 {
3579    /* Check every operand to make sure they are suitable. */
3580    for (Operand& op : instr->operands) {
3581       if (!op.isTemp())
3582          return false;
3583       if (!ctx.info[op.tempId()].is_uniform_bool() && !ctx.info[op.tempId()].is_uniform_bitwise())
3584          return false;
3585    }
3586 
3587    switch (instr->opcode) {
3588    case aco_opcode::s_and_b32:
3589    case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_and_b32; break;
3590    case aco_opcode::s_or_b32:
3591    case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_or_b32; break;
3592    case aco_opcode::s_xor_b32:
3593    case aco_opcode::s_xor_b64: instr->opcode = aco_opcode::s_absdiff_i32; break;
3594    default:
3595       /* Don't transform other instructions. They are very unlikely to appear here. */
3596       return false;
3597    }
3598 
3599    for (Operand& op : instr->operands) {
3600       ctx.uses[op.tempId()]--;
3601 
3602       if (ctx.info[op.tempId()].is_uniform_bool()) {
3603          /* Just use the uniform boolean temp. */
3604          op.setTemp(ctx.info[op.tempId()].temp);
3605       } else if (ctx.info[op.tempId()].is_uniform_bitwise()) {
3606          /* Use the SCC definition of the predecessor instruction.
3607           * This allows the predecessor to get picked up by the same optimization (if it has no
3608           * divergent users), and it also makes sure that the current instruction will keep working
3609           * even if the predecessor won't be transformed.
3610           */
3611          Instruction* pred_instr = ctx.info[op.tempId()].instr;
3612          assert(pred_instr->definitions.size() >= 2);
3613          assert(pred_instr->definitions[1].isFixed() &&
3614                 pred_instr->definitions[1].physReg() == scc);
3615          op.setTemp(pred_instr->definitions[1].getTemp());
3616       } else {
3617          unreachable("Invalid operand on uniform bitwise instruction.");
3618       }
3619 
3620       ctx.uses[op.tempId()]++;
3621    }
3622 
3623    instr->definitions[0].setTemp(Temp(instr->definitions[0].tempId(), s1));
3624    assert(instr->operands[0].regClass() == s1);
3625    assert(instr->operands[1].regClass() == s1);
3626    return true;
3627 }
3628 
/* Backward-pass instruction selection. optimize() calls this on every instruction in reverse
 * block / reverse instruction order, after dead_code_analysis() has filled ctx.uses.
 *
 * Depending on the instruction, this:
 *  - deletes it (instr.reset()) when is_dead() reports it as dead;
 *  - turns a p_split_vector with exactly one used definition into a copy (single-operand
 *    p_create_vector) or a p_extract_vector;
 *  - for MAD/FMA-labelled instructions, either restores the saved add instruction (when the
 *    multiplication result is still used elsewhere) or selects an operand whose literal
 *    apply_literals() will fold into a VOP2 *ak/*mk opcode;
 *  - records SCC-needed labels consumed by the uniform-bool transformation;
 *  - narrows uniform bitwise boolean instructions via to_uniform_bool_instr();
 *  - folds a DPP mov into the VALU instruction that consumes it;
 *  - selects a literal to inline into other SALU/VALU instructions.
 *
 * Literals are never inlined here. This function only decrements ctx.uses of the chosen
 * literal temps; apply_literals() later inlines a literal operand whose temp reached zero uses.
 */
void
select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* Literal temps with this many or more remaining uses are not inlined (except for the
    * v_mad -> v_madak case below), to limit code-size growth. */
   const uint32_t threshold = 4;

   if (is_dead(ctx.uses, instr.get())) {
      instr.reset();
      return;
   }

   /* convert split_vector into a copy or extract_vector if only one definition is ever used */
   if (instr->opcode == aco_opcode::p_split_vector) {
      /* Find how many definitions are used; remember the index and byte offset of the last
       * used one (the only one when num_used == 1). */
      unsigned num_used = 0;
      unsigned idx = 0;
      unsigned split_offset = 0;
      for (unsigned i = 0, offset = 0; i < instr->definitions.size();
           offset += instr->definitions[i++].bytes()) {
         if (ctx.uses[instr->definitions[i].tempId()]) {
            num_used++;
            idx = i;
            split_offset = offset;
         }
      }
      bool done = false;
      /* If the split source is a vec whose only use is this split, forward the matching vec
       * operand directly instead of going through the vector. */
      if (num_used == 1 && ctx.info[instr->operands[0].tempId()].is_vec() &&
          ctx.uses[instr->operands[0].tempId()] == 1) {
         Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;

         unsigned off = 0;
         Operand op;
         for (Operand& vec_op : vec->operands) {
            if (off == split_offset) {
               op = vec_op;
               break;
            }
            off += vec_op.bytes();
         }
         /* off ends up at the full vector size when no vec operand starts at split_offset
          * (assuming the vec operands' sizes sum to the vector size — TODO confirm); the
          * operand must also match the used definition's size exactly. */
         if (off != instr->operands[0].bytes() && op.bytes() == instr->definitions[idx].bytes()) {
            /* The vector loses its only use, so the vec instruction's operands lose a use
             * each; the forwarded operand gains one through the new copy. */
            ctx.uses[instr->operands[0].tempId()]--;
            for (Operand& vec_op : vec->operands) {
               if (vec_op.isTemp())
                  ctx.uses[vec_op.tempId()]--;
            }
            if (op.isTemp())
               ctx.uses[op.tempId()]++;

            /* A single-operand p_create_vector acts as a copy. */
            aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(
               aco_opcode::p_create_vector, Format::PSEUDO, 1, 1)};
            extract->operands[0] = op;
            extract->definitions[0] = instr->definitions[idx];
            instr = std::move(extract);

            done = true;
         }
      }

      /* Otherwise, if the used definition is size-aligned within the source, express it as
       * an extract of element split_offset / size. */
      if (!done && num_used == 1 &&
          instr->operands[0].bytes() % instr->definitions[idx].bytes() == 0 &&
          split_offset % instr->definitions[idx].bytes() == 0) {
         aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(
            aco_opcode::p_extract_vector, Format::PSEUDO, 2, 1)};
         extract->operands[0] = instr->operands[0];
         extract->operands[1] =
            Operand::c32((uint32_t)split_offset / instr->definitions[idx].bytes());
         extract->definitions[0] = instr->definitions[idx];
         instr = std::move(extract);
      }
   }

   mad_info* mad_info = NULL;
   if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
      mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].instr->pass_flags];
      /* re-check mad instructions */
      /* If the multiplication result is still used elsewhere and the original add was saved,
       * undo the fusion: the add uses the mul result again (+1), and operands 0/1 of the mad
       * (presumably the multiplicands — confirm against combine code) lose the use held by
       * the mad (-1). */
      if (ctx.uses[mad_info->mul_temp_id] && mad_info->add_instr) {
         ctx.uses[mad_info->mul_temp_id]++;
         if (instr->operands[0].isTemp())
            ctx.uses[instr->operands[0].tempId()]--;
         if (instr->operands[1].isTemp())
            ctx.uses[instr->operands[1].tempId()]--;
         instr.swap(mad_info->add_instr);
         mad_info = NULL;
      }
      /* check literals */
      else if (!instr->usesModifiers() && instr->opcode != aco_opcode::v_fma_f64) {
         /* FMA can only take literals on GFX10+ */
         if ((instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) &&
             ctx.program->chip_class < GFX10)
            return;
         /* There are no v_fmaak_legacy_f16/v_fmamk_legacy_f16 and on chips where VOP3 can take
          * literals (GFX10+), these instructions don't exist.
          */
         if (instr->opcode == aco_opcode::v_fma_legacy_f16)
            return;

         /* Pick the literal-labelled operand with the fewest remaining uses.
          * literal_uses == UINT32_MAX means "no usable literal". */
         bool sgpr_used = false;
         uint32_t literal_idx = 0;
         uint32_t literal_uses = UINT32_MAX;
         for (unsigned i = 0; i < instr->operands.size(); i++) {
            /* An inline constant in operand 1 or 2 rules out the VOP2 *ak/*mk forms. */
            if (instr->operands[i].isConstant() && i > 0) {
               literal_uses = UINT32_MAX;
               break;
            }
            if (!instr->operands[i].isTemp())
               continue;
            unsigned bits = get_operand_size(instr, i);
            /* if one of the operands is sgpr, we cannot add a literal somewhere else on pre-GFX10
             * or operands other than the 1st */
            if (instr->operands[i].getTemp().type() == RegType::sgpr &&
                (i > 0 || ctx.program->chip_class < GFX10)) {
               if (!sgpr_used && ctx.info[instr->operands[i].tempId()].is_literal(bits)) {
                  literal_uses = ctx.uses[instr->operands[i].tempId()];
                  literal_idx = i;
               } else {
                  literal_uses = UINT32_MAX;
               }
               sgpr_used = true;
               /* don't break because we still need to check constants */
            } else if (!sgpr_used && ctx.info[instr->operands[i].tempId()].is_literal(bits) &&
                       ctx.uses[instr->operands[i].tempId()] < literal_uses) {
               literal_uses = ctx.uses[instr->operands[i].tempId()];
               literal_idx = i;
            }
         }

         /* Limit the number of literals to apply to not increase the code
          * size too much, but always apply literals for v_mad->v_madak
          * because both instructions are 64-bit and this doesn't increase
          * code size.
          * TODO: try to apply the literals earlier to lower the number of
          * uses below threshold
          */
         /* The decrement below and check_literal/literal_idx are consumed by apply_literals(),
          * which emits the VOP2 *ak/*mk instruction. */
         if (literal_uses < threshold || literal_idx == 2) {
            ctx.uses[instr->operands[literal_idx].tempId()]--;
            mad_info->check_literal = true;
            mad_info->literal_idx = literal_idx;
            return;
         }
      }
   }

   /* Mark SCC needed, so the uniform boolean transformation won't swap the definitions
    * when it isn't beneficial */
   if (instr->isBranch() && instr->operands.size() && instr->operands[0].isTemp() &&
       instr->operands[0].isFixed() && instr->operands[0].physReg() == scc) {
      ctx.info[instr->operands[0].tempId()].set_scc_needed();
      return;
   } else if ((instr->opcode == aco_opcode::s_cselect_b64 ||
               instr->opcode == aco_opcode::s_cselect_b32) &&
              instr->operands[2].isTemp()) {
      ctx.info[instr->operands[2].tempId()].set_scc_needed();
   } else if (instr->opcode == aco_opcode::p_wqm && instr->operands[0].isTemp() &&
              ctx.info[instr->definitions[0].tempId()].is_scc_needed()) {
      /* Propagate label so it is correctly detected by the uniform bool transform */
      ctx.info[instr->operands[0].tempId()].set_scc_needed();

      /* Fix definition to SCC, this will prevent RA from adding superfluous moves */
      instr->definitions[0].setFixed(scc);
   }

   /* check for literals */
   if (!instr->isSALU() && !instr->isVALU())
      return;

   /* Transform uniform bitwise boolean operations to 32-bit when there are no divergent uses. */
   /* ctx.uses of definition 0 being zero means the (lane-mask) result itself has no users. */
   if (instr->definitions.size() && ctx.uses[instr->definitions[0].tempId()] == 0 &&
       ctx.info[instr->definitions[0].tempId()].is_uniform_bitwise()) {
      bool transform_done = to_uniform_bool_instr(ctx, instr);

      if (transform_done && !ctx.info[instr->definitions[1].tempId()].is_scc_needed()) {
         /* Swap the two definition IDs in order to avoid overusing the SCC.
          * This reduces extra moves generated by RA. */
         uint32_t def0_id = instr->definitions[0].getTemp().id();
         uint32_t def1_id = instr->definitions[1].getTemp().id();
         instr->definitions[0].setTemp(Temp(def1_id, s1));
         instr->definitions[1].setTemp(Temp(def0_id, s1));
      }

      return;
   }

   /* Combine DPP copies into VALU. This should be done after creating MAD/FMA. */
   if (instr->isVALU()) {
      for (unsigned i = 0; i < instr->operands.size(); i++) {
         if (!instr->operands[i].isTemp())
            continue;
         ssa_info info = ctx.info[instr->operands[i].tempId()];

         /* Only operand 0 can carry DPP, so a DPP value in another operand is usable only if
          * the operands can be swapped. The DPP mov must have the same pass_flags. */
         aco_opcode swapped_op;
         if (info.is_dpp() && info.instr->pass_flags == instr->pass_flags &&
             (i == 0 || can_swap_operands(instr, &swapped_op)) && can_use_DPP(instr, true) &&
             !instr->isDPP()) {
            convert_to_DPP(instr);
            DPP_instruction* dpp = static_cast<DPP_instruction*>(instr.get());
            if (i) {
               instr->opcode = swapped_op;
               std::swap(instr->operands[0], instr->operands[1]);
               std::swap(dpp->neg[0], dpp->neg[1]);
               std::swap(dpp->abs[0], dpp->abs[1]);
            }
            /* This instruction stops using the DPP mov's result and reads the mov's source
             * instead. If the mov still has other users, its source gains a use; otherwise
             * the mov becomes dead and the source's use simply transfers here. */
            if (--ctx.uses[info.instr->definitions[0].tempId()])
               ctx.uses[info.instr->operands[0].tempId()]++;
            instr->operands[0].setTemp(info.instr->operands[0].getTemp());
            dpp->dpp_ctrl = info.instr->dpp().dpp_ctrl;
            dpp->bound_ctrl = info.instr->dpp().bound_ctrl;
            /* abs is applied before neg, so the mov's neg only survives if this instruction
             * does not already take abs of operand 0. */
            dpp->neg[0] ^= info.instr->dpp().neg[0] && !dpp->abs[0];
            dpp->abs[0] |= info.instr->dpp().abs[0];
            break;
         }
      }
   }

   if (instr->isSDWA() || (instr->isVOP3() && ctx.program->chip_class < GFX10) ||
       (instr->isVOP3P() && ctx.program->chip_class < GFX10))
      return; /* some encodings can't ever take literals */

   /* we do not apply the literals yet as we don't know if it is profitable */
   /* Literal already encoded in the instruction, if any (only one literal value fits). */
   Operand current_literal(s1);

   unsigned literal_id = 0;
   unsigned literal_uses = UINT32_MAX;
   Operand literal(s1);
   /* Number of leading operands that may receive a literal: only src0 unless the encoding
    * allows a literal anywhere (SALU, or VOP3/VOP3P on GFX10+). */
   unsigned num_operands = 1;
   if (instr->isSALU() ||
       (ctx.program->chip_class >= GFX10 && (can_use_VOP3(ctx, instr) || instr->isVOP3P())))
      num_operands = instr->operands.size();
   /* catch VOP2 with a 3rd SGPR operand (e.g. v_cndmask_b32, v_addc_co_u32) */
   else if (instr->isVALU() && instr->operands.size() >= 3)
      return;

   /* Up to two distinct SGPR temps read by a VALU instruction (for the constant bus check). */
   unsigned sgpr_ids[2] = {0, 0};
   bool is_literal_sgpr = false;
   /* Bitmask of operand indices that read the chosen literal temp. */
   uint32_t mask = 0;

   /* choose a literal to apply */
   for (unsigned i = 0; i < num_operands; i++) {
      Operand op = instr->operands[i];
      unsigned bits = get_operand_size(instr, i);

      if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
          op.tempId() != sgpr_ids[0])
         sgpr_ids[!!sgpr_ids[0]] = op.tempId();

      if (op.isLiteral()) {
         current_literal = op;
         continue;
      } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal(bits)) {
         continue;
      }

      if (!alu_can_accept_constant(instr->opcode, i))
         continue;

      /* Prefer the literal temp with the fewest remaining uses. */
      if (ctx.uses[op.tempId()] < literal_uses) {
         is_literal_sgpr = op.getTemp().type() == RegType::sgpr;
         mask = 0;
         literal = Operand::c32(ctx.info[op.tempId()].val);
         literal_uses = ctx.uses[op.tempId()];
         literal_id = op.tempId();
      }

      mask |= (op.tempId() == literal_id) << i;
   }

   /* don't go over the constant bus limit */
   /* Only VALU instructions populate sgpr_ids, so for SALU num_sgprs is 0 and this check
    * never triggers. */
   bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
                     instr->opcode == aco_opcode::v_lshrrev_b64 ||
                     instr->opcode == aco_opcode::v_ashrrev_i64;
   unsigned const_bus_limit = instr->isVALU() ? 1 : UINT32_MAX;
   if (ctx.program->chip_class >= GFX10 && !is_shift64)
      const_bus_limit = 2;

   /* With the bus already full, a literal can only be added if it replaces an SGPR operand. */
   unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
   if (num_sgprs == const_bus_limit && !is_literal_sgpr)
      return;

   /* Apply only if the instruction has no literal yet, or already uses the same value. */
   if (literal_id && literal_uses < threshold &&
       (current_literal.isUndefined() ||
        (current_literal.size() == literal.size() &&
         current_literal.constantValue() == literal.constantValue()))) {
      /* mark the literal to be applied */
      /* apply_literals() inlines operands whose temp reaches zero uses. */
      while (mask) {
         unsigned i = u_bit_scan(&mask);
         if (instr->operands[i].isTemp() && instr->operands[i].tempId() == literal_id)
            ctx.uses[instr->operands[i].tempId()]--;
      }
   }
}
3916 
3917 void
apply_literals(opt_ctx & ctx,aco_ptr<Instruction> & instr)3918 apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3919 {
3920    /* Cleanup Dead Instructions */
3921    if (!instr)
3922       return;
3923 
3924    /* apply literals on MAD */
3925    if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
3926       mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].instr->pass_flags];
3927       if (info->check_literal &&
3928           (ctx.uses[instr->operands[info->literal_idx].tempId()] == 0 || info->literal_idx == 2)) {
3929          aco_ptr<Instruction> new_mad;
3930 
3931          aco_opcode new_op =
3932             info->literal_idx == 2 ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32;
3933          if (instr->opcode == aco_opcode::v_fma_f32)
3934             new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32;
3935          else if (instr->opcode == aco_opcode::v_mad_f16 ||
3936                   instr->opcode == aco_opcode::v_mad_legacy_f16)
3937             new_op = info->literal_idx == 2 ? aco_opcode::v_madak_f16 : aco_opcode::v_madmk_f16;
3938          else if (instr->opcode == aco_opcode::v_fma_f16)
3939             new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f16 : aco_opcode::v_fmamk_f16;
3940 
3941          new_mad.reset(create_instruction<VOP2_instruction>(new_op, Format::VOP2, 3, 1));
3942          if (info->literal_idx == 2) { /* add literal -> madak */
3943             new_mad->operands[0] = instr->operands[0];
3944             new_mad->operands[1] = instr->operands[1];
3945          } else { /* mul literal -> madmk */
3946             new_mad->operands[0] = instr->operands[1 - info->literal_idx];
3947             new_mad->operands[1] = instr->operands[2];
3948          }
3949          new_mad->operands[2] =
3950             Operand::c32(ctx.info[instr->operands[info->literal_idx].tempId()].val);
3951          new_mad->definitions[0] = instr->definitions[0];
3952          ctx.instructions.emplace_back(std::move(new_mad));
3953          return;
3954       }
3955    }
3956 
3957    /* apply literals on other SALU/VALU */
3958    if (instr->isSALU() || instr->isVALU()) {
3959       for (unsigned i = 0; i < instr->operands.size(); i++) {
3960          Operand op = instr->operands[i];
3961          unsigned bits = get_operand_size(instr, i);
3962          if (op.isTemp() && ctx.info[op.tempId()].is_literal(bits) && ctx.uses[op.tempId()] == 0) {
3963             Operand literal = Operand::c32(ctx.info[op.tempId()].val);
3964             instr->format = withoutDPP(instr->format);
3965             if (instr->isVALU() && i > 0 && instr->format != Format::VOP3P)
3966                to_VOP3(ctx, instr);
3967             instr->operands[i] = literal;
3968          }
3969       }
3970    }
3971 
3972    ctx.instructions.emplace_back(std::move(instr));
3973 }
3974 
3975 void
optimize(Program * program)3976 optimize(Program* program)
3977 {
3978    opt_ctx ctx;
3979    ctx.program = program;
3980    std::vector<ssa_info> info(program->peekAllocationId());
3981    ctx.info = info.data();
3982 
3983    /* 1. Bottom-Up DAG pass (forward) to label all ssa-defs */
3984    for (Block& block : program->blocks) {
3985       ctx.fp_mode = block.fp_mode;
3986       for (aco_ptr<Instruction>& instr : block.instructions)
3987          label_instruction(ctx, instr);
3988    }
3989 
3990    ctx.uses = dead_code_analysis(program);
3991 
3992    /* 2. Combine v_mad, omod, clamp and propagate sgpr on VALU instructions */
3993    for (Block& block : program->blocks) {
3994       ctx.fp_mode = block.fp_mode;
3995       for (aco_ptr<Instruction>& instr : block.instructions)
3996          combine_instruction(ctx, instr);
3997    }
3998 
3999    /* 3. Top-Down DAG pass (backward) to select instructions (includes DCE) */
4000    for (auto block_rit = program->blocks.rbegin(); block_rit != program->blocks.rend();
4001         ++block_rit) {
4002       Block* block = &(*block_rit);
4003       ctx.fp_mode = block->fp_mode;
4004       for (auto instr_rit = block->instructions.rbegin(); instr_rit != block->instructions.rend();
4005            ++instr_rit)
4006          select_instruction(ctx, *instr_rit);
4007    }
4008 
4009    /* 4. Add literals to instructions */
4010    for (Block& block : program->blocks) {
4011       ctx.instructions.clear();
4012       ctx.fp_mode = block.fp_mode;
4013       for (aco_ptr<Instruction>& instr : block.instructions)
4014          apply_literals(ctx, instr);
4015       block.instructions.swap(ctx.instructions);
4016    }
4017 }
4018 
4019 } // namespace aco
4020