1 /*
2  * Copyright © 2018 Valve Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  */
24 
25 #include "aco_builder.h"
26 #include "aco_ir.h"
27 
28 #include "util/half_float.h"
29 #include "util/memstream.h"
30 
31 #include <algorithm>
32 #include <array>
33 #include <vector>
34 
35 namespace aco {
36 
37 #ifndef NDEBUG
38 void
perfwarn(Program * program,bool cond,const char * msg,Instruction * instr)39 perfwarn(Program* program, bool cond, const char* msg, Instruction* instr)
40 {
41    if (cond) {
42       char* out;
43       size_t outsize;
44       struct u_memstream mem;
45       u_memstream_open(&mem, &out, &outsize);
46       FILE* const memf = u_memstream_get(&mem);
47 
48       fprintf(memf, "%s: ", msg);
49       aco_print_instr(instr, memf);
50       u_memstream_close(&mem);
51 
52       aco_perfwarn(program, out);
53       free(out);
54 
55       if (debug_flags & DEBUG_PERFWARN)
56          exit(1);
57    }
58 }
59 #endif
60 
61 /**
62  * The optimizer works in 4 phases:
63  * (1) The first pass collects information for each ssa-def,
64  *     propagates reg->reg operands of the same type, inline constants
65  *     and neg/abs input modifiers.
66  * (2) The second pass combines instructions like mad, omod, clamp and
67  *     propagates sgpr's on VALU instructions.
68  *     This pass depends on information collected in the first pass.
69  * (3) The third pass goes backwards, and selects instructions,
70  *     i.e. decides if a mad instruction is profitable and eliminates dead code.
71  * (4) The fourth pass cleans up the sequence: literals get applied and dead
72  *     instructions are removed from the sequence.
73  */
74 
75 struct mad_info {
76    aco_ptr<Instruction> add_instr;
77    uint32_t mul_temp_id;
78    uint16_t literal_idx;
79    bool check_literal;
80 
mad_infoaco::mad_info81    mad_info(aco_ptr<Instruction> instr, uint32_t id)
82        : add_instr(std::move(instr)), mul_temp_id(id), literal_idx(0), check_literal(false)
83    {}
84 };
85 
86 enum Label {
87    label_vec = 1 << 0,
88    label_constant_32bit = 1 << 1,
89    /* label_{abs,neg,mul,omod2,omod4,omod5,clamp} are used for both 16 and
90     * 32-bit operations but this shouldn't cause any issues because we don't
91     * look through any conversions */
92    label_abs = 1 << 2,
93    label_neg = 1 << 3,
94    label_mul = 1 << 4,
95    label_temp = 1 << 5,
96    label_literal = 1 << 6,
97    label_mad = 1 << 7,
98    label_omod2 = 1 << 8,
99    label_omod4 = 1 << 9,
100    label_omod5 = 1 << 10,
101    label_clamp = 1 << 12,
102    label_undefined = 1 << 14,
103    label_vcc = 1 << 15,
104    label_b2f = 1 << 16,
105    label_add_sub = 1 << 17,
106    label_bitwise = 1 << 18,
107    label_minmax = 1 << 19,
108    label_vopc = 1 << 20,
109    label_uniform_bool = 1 << 21,
110    label_constant_64bit = 1 << 22,
111    label_uniform_bitwise = 1 << 23,
112    label_scc_invert = 1 << 24,
113    label_vcc_hint = 1 << 25,
114    label_scc_needed = 1 << 26,
115    label_b2i = 1 << 27,
116    label_fcanonicalize = 1 << 28,
117    label_constant_16bit = 1 << 29,
118    label_usedef = 1 << 30,   /* generic label */
119    label_vop3p = 1ull << 31, /* 1ull to prevent sign extension */
120    label_canonicalized = 1ull << 32,
121    label_extract = 1ull << 33,
122    label_insert = 1ull << 34,
123    label_dpp16 = 1ull << 35,
124    label_dpp8 = 1ull << 36,
125    label_f2f32 = 1ull << 37,
126    label_f2f16 = 1ull << 38,
127 };
128 
129 static constexpr uint64_t instr_usedef_labels =
130    label_vec | label_mul | label_mad | label_add_sub | label_vop3p | label_bitwise |
131    label_uniform_bitwise | label_minmax | label_vopc | label_usedef | label_extract | label_dpp16 |
132    label_dpp8 | label_f2f32;
133 static constexpr uint64_t instr_mod_labels =
134    label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert | label_f2f16;
135 
136 static constexpr uint64_t instr_labels = instr_usedef_labels | instr_mod_labels;
137 static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f |
138                                         label_uniform_bool | label_scc_invert | label_b2i |
139                                         label_fcanonicalize;
140 static constexpr uint32_t val_labels =
141    label_constant_32bit | label_constant_64bit | label_constant_16bit | label_literal;
142 
143 static_assert((instr_labels & temp_labels) == 0, "labels cannot intersect");
144 static_assert((instr_labels & val_labels) == 0, "labels cannot intersect");
145 static_assert((temp_labels & val_labels) == 0, "labels cannot intersect");
146 
147 struct ssa_info {
148    uint64_t label;
149    union {
150       uint32_t val;
151       Temp temp;
152       Instruction* instr;
153    };
154 
ssa_infoaco::ssa_info155    ssa_info() : label(0) {}
156 
add_labelaco::ssa_info157    void add_label(Label new_label)
158    {
159       /* Since all the instr_usedef_labels use instr for the same thing
160        * (indicating the defining instruction), there is usually no need to
161        * clear any other instr labels. */
162       if (new_label & instr_usedef_labels)
163          label &= ~(instr_mod_labels | temp_labels | val_labels); /* instr, temp and val alias */
164 
165       if (new_label & instr_mod_labels) {
166          label &= ~instr_labels;
167          label &= ~(temp_labels | val_labels); /* instr, temp and val alias */
168       }
169 
170       if (new_label & temp_labels) {
171          label &= ~temp_labels;
172          label &= ~(instr_labels | val_labels); /* instr, temp and val alias */
173       }
174 
175       uint32_t const_labels =
176          label_literal | label_constant_32bit | label_constant_64bit | label_constant_16bit;
177       if (new_label & const_labels) {
178          label &= ~val_labels | const_labels;
179          label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
180       } else if (new_label & val_labels) {
181          label &= ~val_labels;
182          label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
183       }
184 
185       label |= new_label;
186    }
187 
set_vecaco::ssa_info188    void set_vec(Instruction* vec)
189    {
190       add_label(label_vec);
191       instr = vec;
192    }
193 
is_vecaco::ssa_info194    bool is_vec() { return label & label_vec; }
195 
set_constantaco::ssa_info196    void set_constant(chip_class chip, uint64_t constant)
197    {
198       Operand op16 = Operand::c16(constant);
199       Operand op32 = Operand::get_const(chip, constant, 4);
200       add_label(label_literal);
201       val = constant;
202 
203       /* check that no upper bits are lost in case of packed 16bit constants */
204       if (chip >= GFX8 && !op16.isLiteral() && op16.constantValue64() == constant)
205          add_label(label_constant_16bit);
206 
207       if (!op32.isLiteral())
208          add_label(label_constant_32bit);
209 
210       if (Operand::is_constant_representable(constant, 8))
211          add_label(label_constant_64bit);
212 
213       if (label & label_constant_64bit) {
214          val = Operand::c64(constant).constantValue();
215          if (val != constant)
216             label &= ~(label_literal | label_constant_16bit | label_constant_32bit);
217       }
218    }
219 
is_constantaco::ssa_info220    bool is_constant(unsigned bits)
221    {
222       switch (bits) {
223       case 8: return label & label_literal;
224       case 16: return label & label_constant_16bit;
225       case 32: return label & label_constant_32bit;
226       case 64: return label & label_constant_64bit;
227       }
228       return false;
229    }
230 
is_literalaco::ssa_info231    bool is_literal(unsigned bits)
232    {
233       bool is_lit = label & label_literal;
234       switch (bits) {
235       case 8: return false;
236       case 16: return is_lit && ~(label & label_constant_16bit);
237       case 32: return is_lit && ~(label & label_constant_32bit);
238       case 64: return false;
239       }
240       return false;
241    }
242 
is_constant_or_literalaco::ssa_info243    bool is_constant_or_literal(unsigned bits)
244    {
245       if (bits == 64)
246          return label & label_constant_64bit;
247       else
248          return label & label_literal;
249    }
250 
set_absaco::ssa_info251    void set_abs(Temp abs_temp)
252    {
253       add_label(label_abs);
254       temp = abs_temp;
255    }
256 
is_absaco::ssa_info257    bool is_abs() { return label & label_abs; }
258 
set_negaco::ssa_info259    void set_neg(Temp neg_temp)
260    {
261       add_label(label_neg);
262       temp = neg_temp;
263    }
264 
is_negaco::ssa_info265    bool is_neg() { return label & label_neg; }
266 
set_neg_absaco::ssa_info267    void set_neg_abs(Temp neg_abs_temp)
268    {
269       add_label((Label)((uint32_t)label_abs | (uint32_t)label_neg));
270       temp = neg_abs_temp;
271    }
272 
set_mulaco::ssa_info273    void set_mul(Instruction* mul)
274    {
275       add_label(label_mul);
276       instr = mul;
277    }
278 
is_mulaco::ssa_info279    bool is_mul() { return label & label_mul; }
280 
set_tempaco::ssa_info281    void set_temp(Temp tmp)
282    {
283       add_label(label_temp);
284       temp = tmp;
285    }
286 
is_tempaco::ssa_info287    bool is_temp() { return label & label_temp; }
288 
set_madaco::ssa_info289    void set_mad(Instruction* mad, uint32_t mad_info_idx)
290    {
291       add_label(label_mad);
292       mad->pass_flags = mad_info_idx;
293       instr = mad;
294    }
295 
is_madaco::ssa_info296    bool is_mad() { return label & label_mad; }
297 
set_omod2aco::ssa_info298    void set_omod2(Instruction* mul)
299    {
300       add_label(label_omod2);
301       instr = mul;
302    }
303 
is_omod2aco::ssa_info304    bool is_omod2() { return label & label_omod2; }
305 
set_omod4aco::ssa_info306    void set_omod4(Instruction* mul)
307    {
308       add_label(label_omod4);
309       instr = mul;
310    }
311 
is_omod4aco::ssa_info312    bool is_omod4() { return label & label_omod4; }
313 
set_omod5aco::ssa_info314    void set_omod5(Instruction* mul)
315    {
316       add_label(label_omod5);
317       instr = mul;
318    }
319 
is_omod5aco::ssa_info320    bool is_omod5() { return label & label_omod5; }
321 
set_clampaco::ssa_info322    void set_clamp(Instruction* med3)
323    {
324       add_label(label_clamp);
325       instr = med3;
326    }
327 
is_clampaco::ssa_info328    bool is_clamp() { return label & label_clamp; }
329 
set_f2f16aco::ssa_info330    void set_f2f16(Instruction* conv)
331    {
332       add_label(label_f2f16);
333       instr = conv;
334    }
335 
is_f2f16aco::ssa_info336    bool is_f2f16() { return label & label_f2f16; }
337 
set_undefinedaco::ssa_info338    void set_undefined() { add_label(label_undefined); }
339 
is_undefinedaco::ssa_info340    bool is_undefined() { return label & label_undefined; }
341 
set_vccaco::ssa_info342    void set_vcc(Temp vcc_val)
343    {
344       add_label(label_vcc);
345       temp = vcc_val;
346    }
347 
is_vccaco::ssa_info348    bool is_vcc() { return label & label_vcc; }
349 
set_b2faco::ssa_info350    void set_b2f(Temp b2f_val)
351    {
352       add_label(label_b2f);
353       temp = b2f_val;
354    }
355 
is_b2faco::ssa_info356    bool is_b2f() { return label & label_b2f; }
357 
set_add_subaco::ssa_info358    void set_add_sub(Instruction* add_sub_instr)
359    {
360       add_label(label_add_sub);
361       instr = add_sub_instr;
362    }
363 
is_add_subaco::ssa_info364    bool is_add_sub() { return label & label_add_sub; }
365 
set_bitwiseaco::ssa_info366    void set_bitwise(Instruction* bitwise_instr)
367    {
368       add_label(label_bitwise);
369       instr = bitwise_instr;
370    }
371 
is_bitwiseaco::ssa_info372    bool is_bitwise() { return label & label_bitwise; }
373 
set_uniform_bitwiseaco::ssa_info374    void set_uniform_bitwise() { add_label(label_uniform_bitwise); }
375 
is_uniform_bitwiseaco::ssa_info376    bool is_uniform_bitwise() { return label & label_uniform_bitwise; }
377 
set_minmaxaco::ssa_info378    void set_minmax(Instruction* minmax_instr)
379    {
380       add_label(label_minmax);
381       instr = minmax_instr;
382    }
383 
is_minmaxaco::ssa_info384    bool is_minmax() { return label & label_minmax; }
385 
set_vopcaco::ssa_info386    void set_vopc(Instruction* vopc_instr)
387    {
388       add_label(label_vopc);
389       instr = vopc_instr;
390    }
391 
is_vopcaco::ssa_info392    bool is_vopc() { return label & label_vopc; }
393 
set_scc_neededaco::ssa_info394    void set_scc_needed() { add_label(label_scc_needed); }
395 
is_scc_neededaco::ssa_info396    bool is_scc_needed() { return label & label_scc_needed; }
397 
set_scc_invertaco::ssa_info398    void set_scc_invert(Temp scc_inv)
399    {
400       add_label(label_scc_invert);
401       temp = scc_inv;
402    }
403 
is_scc_invertaco::ssa_info404    bool is_scc_invert() { return label & label_scc_invert; }
405 
set_uniform_boolaco::ssa_info406    void set_uniform_bool(Temp uniform_bool)
407    {
408       add_label(label_uniform_bool);
409       temp = uniform_bool;
410    }
411 
is_uniform_boolaco::ssa_info412    bool is_uniform_bool() { return label & label_uniform_bool; }
413 
set_vcc_hintaco::ssa_info414    void set_vcc_hint() { add_label(label_vcc_hint); }
415 
is_vcc_hintaco::ssa_info416    bool is_vcc_hint() { return label & label_vcc_hint; }
417 
set_b2iaco::ssa_info418    void set_b2i(Temp b2i_val)
419    {
420       add_label(label_b2i);
421       temp = b2i_val;
422    }
423 
is_b2iaco::ssa_info424    bool is_b2i() { return label & label_b2i; }
425 
set_usedefaco::ssa_info426    void set_usedef(Instruction* label_instr)
427    {
428       add_label(label_usedef);
429       instr = label_instr;
430    }
431 
is_usedefaco::ssa_info432    bool is_usedef() { return label & label_usedef; }
433 
set_vop3paco::ssa_info434    void set_vop3p(Instruction* vop3p_instr)
435    {
436       add_label(label_vop3p);
437       instr = vop3p_instr;
438    }
439 
is_vop3paco::ssa_info440    bool is_vop3p() { return label & label_vop3p; }
441 
set_fcanonicalizeaco::ssa_info442    void set_fcanonicalize(Temp tmp)
443    {
444       add_label(label_fcanonicalize);
445       temp = tmp;
446    }
447 
is_fcanonicalizeaco::ssa_info448    bool is_fcanonicalize() { return label & label_fcanonicalize; }
449 
set_canonicalizedaco::ssa_info450    void set_canonicalized() { add_label(label_canonicalized); }
451 
is_canonicalizedaco::ssa_info452    bool is_canonicalized() { return label & label_canonicalized; }
453 
set_f2f32aco::ssa_info454    void set_f2f32(Instruction* cvt)
455    {
456       add_label(label_f2f32);
457       instr = cvt;
458    }
459 
is_f2f32aco::ssa_info460    bool is_f2f32() { return label & label_f2f32; }
461 
set_extractaco::ssa_info462    void set_extract(Instruction* extract)
463    {
464       add_label(label_extract);
465       instr = extract;
466    }
467 
is_extractaco::ssa_info468    bool is_extract() { return label & label_extract; }
469 
set_insertaco::ssa_info470    void set_insert(Instruction* insert)
471    {
472       add_label(label_insert);
473       instr = insert;
474    }
475 
is_insertaco::ssa_info476    bool is_insert() { return label & label_insert; }
477 
set_dpp16aco::ssa_info478    void set_dpp16(Instruction* mov)
479    {
480       add_label(label_dpp16);
481       instr = mov;
482    }
483 
set_dpp8aco::ssa_info484    void set_dpp8(Instruction* mov)
485    {
486       add_label(label_dpp8);
487       instr = mov;
488    }
489 
is_dppaco::ssa_info490    bool is_dpp() { return label & (label_dpp16 | label_dpp8); }
is_dpp16aco::ssa_info491    bool is_dpp16() { return label & label_dpp16; }
is_dpp8aco::ssa_info492    bool is_dpp8() { return label & label_dpp8; }
493 };
494 
495 struct opt_ctx {
496    Program* program;
497    float_mode fp_mode;
498    std::vector<aco_ptr<Instruction>> instructions;
499    ssa_info* info;
500    std::pair<uint32_t, Temp> last_literal;
501    std::vector<mad_info> mad_infos;
502    std::vector<uint16_t> uses;
503 };
504 
505 bool
can_use_VOP3(opt_ctx & ctx,const aco_ptr<Instruction> & instr)506 can_use_VOP3(opt_ctx& ctx, const aco_ptr<Instruction>& instr)
507 {
508    if (instr->isVOP3())
509       return true;
510 
511    if (instr->isVOP3P())
512       return false;
513 
514    if (instr->operands.size() && instr->operands[0].isLiteral() && ctx.program->chip_class < GFX10)
515       return false;
516 
517    if (instr->isDPP() || instr->isSDWA())
518       return false;
519 
520    return instr->opcode != aco_opcode::v_madmk_f32 && instr->opcode != aco_opcode::v_madak_f32 &&
521           instr->opcode != aco_opcode::v_madmk_f16 && instr->opcode != aco_opcode::v_madak_f16 &&
522           instr->opcode != aco_opcode::v_fmamk_f32 && instr->opcode != aco_opcode::v_fmaak_f32 &&
523           instr->opcode != aco_opcode::v_fmamk_f16 && instr->opcode != aco_opcode::v_fmaak_f16 &&
524           instr->opcode != aco_opcode::v_readlane_b32 &&
525           instr->opcode != aco_opcode::v_writelane_b32 &&
526           instr->opcode != aco_opcode::v_readfirstlane_b32;
527 }
528 
529 bool
pseudo_propagate_temp(opt_ctx & ctx,aco_ptr<Instruction> & instr,Temp temp,unsigned index)530 pseudo_propagate_temp(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp temp, unsigned index)
531 {
532    if (instr->definitions.empty())
533       return false;
534 
535    const bool vgpr =
536       instr->opcode == aco_opcode::p_as_uniform ||
537       std::all_of(instr->definitions.begin(), instr->definitions.end(),
538                   [](const Definition& def) { return def.regClass().type() == RegType::vgpr; });
539 
540    /* don't propagate VGPRs into SGPR instructions */
541    if (temp.type() == RegType::vgpr && !vgpr)
542       return false;
543 
544    bool can_accept_sgpr =
545       ctx.program->chip_class >= GFX9 ||
546       std::none_of(instr->definitions.begin(), instr->definitions.end(),
547                    [](const Definition& def) { return def.regClass().is_subdword(); });
548 
549    switch (instr->opcode) {
550    case aco_opcode::p_phi:
551    case aco_opcode::p_linear_phi:
552    case aco_opcode::p_parallelcopy:
553    case aco_opcode::p_create_vector:
554       if (temp.bytes() != instr->operands[index].bytes())
555          return false;
556       break;
557    case aco_opcode::p_extract_vector:
558    case aco_opcode::p_extract:
559       if (temp.type() == RegType::sgpr && !can_accept_sgpr)
560          return false;
561       break;
562    case aco_opcode::p_split_vector: {
563       if (temp.type() == RegType::sgpr && !can_accept_sgpr)
564          return false;
565       /* don't increase the vector size */
566       if (temp.bytes() > instr->operands[index].bytes())
567          return false;
568       /* We can decrease the vector size as smaller temporaries are only
569        * propagated by p_as_uniform instructions.
570        * If this propagation leads to invalid IR or hits the assertion below,
571        * it means that some undefined bytes within a dword are begin accessed
572        * and a bug in instruction_selection is likely. */
573       int decrease = instr->operands[index].bytes() - temp.bytes();
574       while (decrease > 0) {
575          decrease -= instr->definitions.back().bytes();
576          instr->definitions.pop_back();
577       }
578       assert(decrease == 0);
579       break;
580    }
581    case aco_opcode::p_as_uniform:
582       if (temp.regClass() == instr->definitions[0].regClass())
583          instr->opcode = aco_opcode::p_parallelcopy;
584       break;
585    default: return false;
586    }
587 
588    instr->operands[index].setTemp(temp);
589    return true;
590 }
591 
592 /* This expects the DPP modifier to be removed. */
593 bool
can_apply_sgprs(opt_ctx & ctx,aco_ptr<Instruction> & instr)594 can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
595 {
596    if (instr->isSDWA() && ctx.program->chip_class < GFX9)
597       return false;
598    return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
599           instr->opcode != aco_opcode::v_readlane_b32 &&
600           instr->opcode != aco_opcode::v_readlane_b32_e64 &&
601           instr->opcode != aco_opcode::v_writelane_b32 &&
602           instr->opcode != aco_opcode::v_writelane_b32_e64 &&
603           instr->opcode != aco_opcode::v_permlane16_b32 &&
604           instr->opcode != aco_opcode::v_permlanex16_b32;
605 }
606 
607 void
to_VOP3(opt_ctx & ctx,aco_ptr<Instruction> & instr)608 to_VOP3(opt_ctx& ctx, aco_ptr<Instruction>& instr)
609 {
610    if (instr->isVOP3())
611       return;
612 
613    aco_ptr<Instruction> tmp = std::move(instr);
614    Format format = asVOP3(tmp->format);
615    instr.reset(create_instruction<VOP3_instruction>(tmp->opcode, format, tmp->operands.size(),
616                                                     tmp->definitions.size()));
617    std::copy(tmp->operands.cbegin(), tmp->operands.cend(), instr->operands.begin());
618    for (unsigned i = 0; i < instr->definitions.size(); i++) {
619       instr->definitions[i] = tmp->definitions[i];
620       if (instr->definitions[i].isTemp()) {
621          ssa_info& info = ctx.info[instr->definitions[i].tempId()];
622          if (info.label & instr_usedef_labels && info.instr == tmp.get())
623             info.instr = instr.get();
624       }
625    }
626    /* we don't need to update any instr_mod_labels because they either haven't
627     * been applied yet or this instruction isn't dead and so they've been ignored */
628 
629    instr->pass_flags = tmp->pass_flags;
630 }
631 
632 bool
is_operand_vgpr(Operand op)633 is_operand_vgpr(Operand op)
634 {
635    return op.isTemp() && op.getTemp().type() == RegType::vgpr;
636 }
637 
638 void
to_SDWA(opt_ctx & ctx,aco_ptr<Instruction> & instr)639 to_SDWA(opt_ctx& ctx, aco_ptr<Instruction>& instr)
640 {
641    aco_ptr<Instruction> tmp = convert_to_SDWA(ctx.program->chip_class, instr);
642    if (!tmp)
643       return;
644 
645    for (unsigned i = 0; i < instr->definitions.size(); i++) {
646       ssa_info& info = ctx.info[instr->definitions[i].tempId()];
647       if (info.label & instr_labels && info.instr == tmp.get())
648          info.instr = instr.get();
649    }
650 }
651 
652 /* only covers special cases */
653 bool
alu_can_accept_constant(aco_opcode opcode,unsigned operand)654 alu_can_accept_constant(aco_opcode opcode, unsigned operand)
655 {
656    switch (opcode) {
657    case aco_opcode::v_interp_p2_f32:
658    case aco_opcode::v_mac_f32:
659    case aco_opcode::v_writelane_b32:
660    case aco_opcode::v_writelane_b32_e64:
661    case aco_opcode::v_cndmask_b32: return operand != 2;
662    case aco_opcode::s_addk_i32:
663    case aco_opcode::s_mulk_i32:
664    case aco_opcode::p_wqm:
665    case aco_opcode::p_extract_vector:
666    case aco_opcode::p_split_vector:
667    case aco_opcode::v_readlane_b32:
668    case aco_opcode::v_readlane_b32_e64:
669    case aco_opcode::v_readfirstlane_b32:
670    case aco_opcode::p_extract:
671    case aco_opcode::p_insert: return operand != 0;
672    default: return true;
673    }
674 }
675 
676 bool
valu_can_accept_vgpr(aco_ptr<Instruction> & instr,unsigned operand)677 valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
678 {
679    if (instr->opcode == aco_opcode::v_readlane_b32 ||
680        instr->opcode == aco_opcode::v_readlane_b32_e64 ||
681        instr->opcode == aco_opcode::v_writelane_b32 ||
682        instr->opcode == aco_opcode::v_writelane_b32_e64)
683       return operand != 1;
684    if (instr->opcode == aco_opcode::v_permlane16_b32 ||
685        instr->opcode == aco_opcode::v_permlanex16_b32)
686       return operand == 0;
687    return true;
688 }
689 
690 /* check constant bus and literal limitations */
691 bool
check_vop3_operands(opt_ctx & ctx,unsigned num_operands,Operand * operands)692 check_vop3_operands(opt_ctx& ctx, unsigned num_operands, Operand* operands)
693 {
694    int limit = ctx.program->chip_class >= GFX10 ? 2 : 1;
695    Operand literal32(s1);
696    Operand literal64(s2);
697    unsigned num_sgprs = 0;
698    unsigned sgpr[] = {0, 0};
699 
700    for (unsigned i = 0; i < num_operands; i++) {
701       Operand op = operands[i];
702 
703       if (op.hasRegClass() && op.regClass().type() == RegType::sgpr) {
704          /* two reads of the same SGPR count as 1 to the limit */
705          if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
706             if (num_sgprs < 2)
707                sgpr[num_sgprs++] = op.tempId();
708             limit--;
709             if (limit < 0)
710                return false;
711          }
712       } else if (op.isLiteral()) {
713          if (ctx.program->chip_class < GFX10)
714             return false;
715 
716          if (!literal32.isUndefined() && literal32.constantValue() != op.constantValue())
717             return false;
718          if (!literal64.isUndefined() && literal64.constantValue() != op.constantValue())
719             return false;
720 
721          /* Any number of 32-bit literals counts as only 1 to the limit. Same
722           * (but separately) for 64-bit literals. */
723          if (op.size() == 1 && literal32.isUndefined()) {
724             limit--;
725             literal32 = op;
726          } else if (op.size() == 2 && literal64.isUndefined()) {
727             limit--;
728             literal64 = op;
729          }
730 
731          if (limit < 0)
732             return false;
733       }
734    }
735 
736    return true;
737 }
738 
739 bool
parse_base_offset(opt_ctx & ctx,Instruction * instr,unsigned op_index,Temp * base,uint32_t * offset,bool prevent_overflow)740 parse_base_offset(opt_ctx& ctx, Instruction* instr, unsigned op_index, Temp* base, uint32_t* offset,
741                   bool prevent_overflow)
742 {
743    Operand op = instr->operands[op_index];
744 
745    if (!op.isTemp())
746       return false;
747    Temp tmp = op.getTemp();
748    if (!ctx.info[tmp.id()].is_add_sub())
749       return false;
750 
751    Instruction* add_instr = ctx.info[tmp.id()].instr;
752 
753    switch (add_instr->opcode) {
754    case aco_opcode::v_add_u32:
755    case aco_opcode::v_add_co_u32:
756    case aco_opcode::v_add_co_u32_e64:
757    case aco_opcode::s_add_i32:
758    case aco_opcode::s_add_u32: break;
759    default: return false;
760    }
761    if (prevent_overflow && !add_instr->definitions[0].isNUW())
762       return false;
763 
764    if (add_instr->usesModifiers())
765       return false;
766 
767    for (unsigned i = 0; i < 2; i++) {
768       if (add_instr->operands[i].isConstant()) {
769          *offset = add_instr->operands[i].constantValue();
770       } else if (add_instr->operands[i].isTemp() &&
771                  ctx.info[add_instr->operands[i].tempId()].is_constant_or_literal(32)) {
772          *offset = ctx.info[add_instr->operands[i].tempId()].val;
773       } else {
774          continue;
775       }
776       if (!add_instr->operands[!i].isTemp())
777          continue;
778 
779       uint32_t offset2 = 0;
780       if (parse_base_offset(ctx, add_instr, !i, base, &offset2, prevent_overflow)) {
781          *offset += offset2;
782       } else {
783          *base = add_instr->operands[!i].getTemp();
784       }
785       return true;
786    }
787 
788    return false;
789 }
790 
791 void
skip_smem_offset_align(opt_ctx & ctx,SMEM_instruction * smem)792 skip_smem_offset_align(opt_ctx& ctx, SMEM_instruction* smem)
793 {
794    bool soe = smem->operands.size() >= (!smem->definitions.empty() ? 3 : 4);
795    if (soe && !smem->operands[1].isConstant())
796       return;
797    /* We don't need to check the constant offset because the address seems to be calculated with
798     * (offset&-4 + const_offset&-4), not (offset+const_offset)&-4.
799     */
800 
801    Operand& op = smem->operands[soe ? smem->operands.size() - 1 : 1];
802    if (!op.isTemp() || !ctx.info[op.tempId()].is_bitwise())
803       return;
804 
805    Instruction* bitwise_instr = ctx.info[op.tempId()].instr;
806    if (bitwise_instr->opcode != aco_opcode::s_and_b32)
807       return;
808 
809    if (bitwise_instr->operands[0].constantEquals(-4) &&
810        bitwise_instr->operands[1].isOfType(op.regClass().type()))
811       op.setTemp(bitwise_instr->operands[1].getTemp());
812    else if (bitwise_instr->operands[1].constantEquals(-4) &&
813             bitwise_instr->operands[0].isOfType(op.regClass().type()))
814       op.setTemp(bitwise_instr->operands[0].getTemp());
815 }
816 
817 void
smem_combine(opt_ctx & ctx,aco_ptr<Instruction> & instr)818 smem_combine(opt_ctx& ctx, aco_ptr<Instruction>& instr)
819 {
820    /* skip &-4 before offset additions: load((a + 16) & -4, 0) */
821    if (!instr->operands.empty())
822       skip_smem_offset_align(ctx, &instr->smem());
823 
824    /* propagate constants and combine additions */
825    if (!instr->operands.empty() && instr->operands[1].isTemp()) {
826       SMEM_instruction& smem = instr->smem();
827       ssa_info info = ctx.info[instr->operands[1].tempId()];
828 
829       Temp base;
830       uint32_t offset;
831       bool prevent_overflow = smem.operands[0].size() > 2 || smem.prevent_overflow;
832       if (info.is_constant_or_literal(32) &&
833           ((ctx.program->chip_class == GFX6 && info.val <= 0x3FF) ||
834            (ctx.program->chip_class == GFX7 && info.val <= 0xFFFFFFFF) ||
835            (ctx.program->chip_class >= GFX8 && info.val <= 0xFFFFF))) {
836          instr->operands[1] = Operand::c32(info.val);
837       } else if (parse_base_offset(ctx, instr.get(), 1, &base, &offset, prevent_overflow) &&
838                  base.regClass() == s1 && offset <= 0xFFFFF && ctx.program->chip_class >= GFX9 &&
839                  offset % 4u == 0) {
840          bool soe = smem.operands.size() >= (!smem.definitions.empty() ? 3 : 4);
841          if (soe) {
842             if (ctx.info[smem.operands.back().tempId()].is_constant_or_literal(32) &&
843                 ctx.info[smem.operands.back().tempId()].val == 0) {
844                smem.operands[1] = Operand::c32(offset);
845                smem.operands.back() = Operand(base);
846             }
847          } else {
848             SMEM_instruction* new_instr = create_instruction<SMEM_instruction>(
849                smem.opcode, Format::SMEM, smem.operands.size() + 1, smem.definitions.size());
850             new_instr->operands[0] = smem.operands[0];
851             new_instr->operands[1] = Operand::c32(offset);
852             if (smem.definitions.empty())
853                new_instr->operands[2] = smem.operands[2];
854             new_instr->operands.back() = Operand(base);
855             if (!smem.definitions.empty())
856                new_instr->definitions[0] = smem.definitions[0];
857             new_instr->sync = smem.sync;
858             new_instr->glc = smem.glc;
859             new_instr->dlc = smem.dlc;
860             new_instr->nv = smem.nv;
861             new_instr->disable_wqm = smem.disable_wqm;
862             instr.reset(new_instr);
863          }
864       }
865    }
866 
867    /* skip &-4 after offset additions: load(a & -4, 16) */
868    if (!instr->operands.empty())
869       skip_smem_offset_align(ctx, &instr->smem());
870 }
871 
872 unsigned
get_operand_size(aco_ptr<Instruction> & instr,unsigned index)873 get_operand_size(aco_ptr<Instruction>& instr, unsigned index)
874 {
875    if (instr->isPseudo())
876       return instr->operands[index].bytes() * 8u;
877    else if (instr->opcode == aco_opcode::v_mad_u64_u32 ||
878             instr->opcode == aco_opcode::v_mad_i64_i32)
879       return index == 2 ? 64 : 32;
880    else if (instr->opcode == aco_opcode::v_fma_mix_f32 ||
881             instr->opcode == aco_opcode::v_fma_mixlo_f16)
882       return instr->vop3p().opsel_hi & (1u << index) ? 16 : 32;
883    else if (instr->isVALU() || instr->isSALU())
884       return instr_info.operand_size[(int)instr->opcode];
885    else
886       return 0;
887 }
888 
889 Operand
get_constant_op(opt_ctx & ctx,ssa_info info,uint32_t bits)890 get_constant_op(opt_ctx& ctx, ssa_info info, uint32_t bits)
891 {
892    if (bits == 64)
893       return Operand::c32_or_c64(info.val, true);
894    return Operand::get_const(ctx.program->chip_class, info.val, bits / 8u);
895 }
896 
897 void
propagate_constants_vop3p(opt_ctx & ctx,aco_ptr<Instruction> & instr,ssa_info & info,unsigned i)898 propagate_constants_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned i)
899 {
900    if (!info.is_constant_or_literal(32))
901       return;
902 
903    assert(instr->operands[i].isTemp());
904    unsigned bits = get_operand_size(instr, i);
905    if (info.is_constant(bits)) {
906       instr->operands[i] = get_constant_op(ctx, info, bits);
907       return;
908    }
909 
910    /* try to fold inline constants */
911    VOP3P_instruction* vop3p = &instr->vop3p();
912    Operand const_lo = Operand::c16(info.val);
913    Operand const_hi = Operand::c16(info.val >> 16);
914    bool opsel_lo = (vop3p->opsel_lo >> i) & 1;
915    bool opsel_hi = (vop3p->opsel_hi >> i) & 1;
916 
917    if (const_hi.isLiteral() && (opsel_lo || opsel_hi))
918       return;
919    if (const_lo.isLiteral() && !(opsel_lo && opsel_hi))
920       return;
921 
922    if (opsel_lo == opsel_hi) {
923       /* use the single 16bit value */
924       instr->operands[i] = opsel_lo ? const_hi : const_lo;
925 
926       /* opsel must point to lo for both halves */
927       vop3p->opsel_lo &= ~(1 << i);
928       vop3p->opsel_hi &= ~(1 << i);
929    } else if (const_lo == const_hi) {
930       /* both constants are the same */
931       instr->operands[i] = const_lo;
932 
933       /* opsel must point to lo for both halves */
934       vop3p->opsel_lo &= ~(1 << i);
935       vop3p->opsel_hi &= ~(1 << i);
936    } else if (const_lo == Operand::c16(0)) {
937       /* don't inline FP constants into integer instructions */
938       // TODO: check if negative integers are zero- or sign-extended
939       if (bits == 32 && const_hi.constantValue() > 64u)
940          return;
941 
942       instr->operands[i] = const_hi;
943 
944       /* redirect opsel selection */
945       vop3p->opsel_lo ^= (1 << i);
946       vop3p->opsel_hi ^= (1 << i);
947    } else if (bits == 16 && const_lo.constantValue() == (const_hi.constantValue() ^ (1 << 15))) {
948       /* const_lo == -const_hi */
949       if (!instr_info.can_use_input_modifiers[(int)instr->opcode])
950          return;
951 
952       instr->operands[i] = Operand::c16(const_lo.constantValue() & 0x7FFF);
953       bool neg_lo = const_lo.constantValue() & (1 << 15);
954       vop3p->neg_lo[i] ^= opsel_lo ^ neg_lo;
955       vop3p->neg_hi[i] ^= opsel_hi ^ neg_lo;
956 
957       /* opsel must point to lo for both operands */
958       vop3p->opsel_lo &= ~(1 << i);
959       vop3p->opsel_hi &= ~(1 << i);
960    }
961 }
962 
963 bool
fixed_to_exec(Operand op)964 fixed_to_exec(Operand op)
965 {
966    return op.isFixed() && op.physReg() == exec;
967 }
968 
969 SubdwordSel
parse_extract(Instruction * instr)970 parse_extract(Instruction* instr)
971 {
972    if (instr->opcode == aco_opcode::p_extract) {
973       unsigned size = instr->operands[2].constantValue() / 8;
974       unsigned offset = instr->operands[1].constantValue() * size;
975       bool sext = instr->operands[3].constantEquals(1);
976       return SubdwordSel(size, offset, sext);
977    } else if (instr->opcode == aco_opcode::p_insert && instr->operands[1].constantEquals(0)) {
978       return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
979    } else if (instr->opcode == aco_opcode::p_extract_vector) {
980       unsigned size = instr->definitions[0].bytes();
981       unsigned offset = instr->operands[1].constantValue() * size;
982       if (size <= 2)
983          return SubdwordSel(size, offset, false);
984    } else if (instr->opcode == aco_opcode::p_split_vector) {
985       assert(instr->operands[0].bytes() == 4 && instr->definitions[1].bytes() == 2);
986       return SubdwordSel(2, 2, false);
987    }
988 
989    return SubdwordSel();
990 }
991 
992 SubdwordSel
parse_insert(Instruction * instr)993 parse_insert(Instruction* instr)
994 {
995    if (instr->opcode == aco_opcode::p_extract && instr->operands[3].constantEquals(0) &&
996        instr->operands[1].constantEquals(0)) {
997       return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
998    } else if (instr->opcode == aco_opcode::p_insert) {
999       unsigned size = instr->operands[2].constantValue() / 8;
1000       unsigned offset = instr->operands[1].constantValue() * size;
1001       return SubdwordSel(size, offset, false);
1002    } else {
1003       return SubdwordSel();
1004    }
1005 }
1006 
1007 bool
can_apply_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr,unsigned idx,ssa_info & info)1008 can_apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1009 {
1010    if (idx >= 2)
1011       return false;
1012 
1013    Temp tmp = info.instr->operands[0].getTemp();
1014    SubdwordSel sel = parse_extract(info.instr);
1015 
1016    if (!sel) {
1017       return false;
1018    } else if (sel.size() == 4) {
1019       return true;
1020    } else if (instr->opcode == aco_opcode::v_cvt_f32_u32 && sel.size() == 1 && !sel.sign_extend()) {
1021       return true;
1022    } else if (can_use_SDWA(ctx.program->chip_class, instr, true) &&
1023               (tmp.type() == RegType::vgpr || ctx.program->chip_class >= GFX9)) {
1024       if (instr->isSDWA() && instr->sdwa().sel[idx] != SubdwordSel::dword)
1025          return false;
1026       return true;
1027    } else if (instr->isVOP3() && sel.size() == 2 &&
1028               can_use_opsel(ctx.program->chip_class, instr->opcode, idx) &&
1029               !(instr->vop3().opsel & (1 << idx))) {
1030       return true;
1031    } else if (instr->opcode == aco_opcode::p_extract) {
1032       SubdwordSel instrSel = parse_extract(instr.get());
1033 
1034       /* the outer offset must be within extracted range */
1035       if (instrSel.offset() >= sel.size())
1036          return false;
1037 
1038       /* don't remove the sign-extension when increasing the size further */
1039       if (instrSel.size() > sel.size() && !instrSel.sign_extend() && sel.sign_extend())
1040          return false;
1041 
1042       return true;
1043    }
1044 
1045    return false;
1046 }
1047 
1048 /* Combine an p_extract (or p_insert, in some cases) instruction with instr.
1049  * instr(p_extract(...)) -> instr()
1050  */
1051 void
apply_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr,unsigned idx,ssa_info & info)1052 apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1053 {
1054    Temp tmp = info.instr->operands[0].getTemp();
1055    SubdwordSel sel = parse_extract(info.instr);
1056    assert(sel);
1057 
1058    instr->operands[idx].set16bit(false);
1059    instr->operands[idx].set24bit(false);
1060 
1061    ctx.info[tmp.id()].label &= ~label_insert;
1062 
1063    if (sel.size() == 4) {
1064       /* full dword selection */
1065    } else if (instr->opcode == aco_opcode::v_cvt_f32_u32 && sel.size() == 1 && !sel.sign_extend()) {
1066       switch (sel.offset()) {
1067       case 0: instr->opcode = aco_opcode::v_cvt_f32_ubyte0; break;
1068       case 1: instr->opcode = aco_opcode::v_cvt_f32_ubyte1; break;
1069       case 2: instr->opcode = aco_opcode::v_cvt_f32_ubyte2; break;
1070       case 3: instr->opcode = aco_opcode::v_cvt_f32_ubyte3; break;
1071       }
1072    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
1073               sel.offset() == 0 &&
1074               ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
1075                (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
1076       /* The undesireable upper bits are already shifted out. */
1077       return;
1078    } else if (can_use_SDWA(ctx.program->chip_class, instr, true) &&
1079               (tmp.type() == RegType::vgpr || ctx.program->chip_class >= GFX9)) {
1080       to_SDWA(ctx, instr);
1081       static_cast<SDWA_instruction*>(instr.get())->sel[idx] = sel;
1082    } else if (instr->isVOP3()) {
1083       if (sel.offset())
1084          instr->vop3().opsel |= 1 << idx;
1085    } else if (instr->opcode == aco_opcode::p_extract) {
1086       SubdwordSel instrSel = parse_extract(instr.get());
1087 
1088       unsigned size = std::min(sel.size(), instrSel.size());
1089       unsigned offset = sel.offset() + instrSel.offset();
1090       unsigned sign_extend =
1091          instrSel.sign_extend() && (sel.sign_extend() || instrSel.size() <= sel.size());
1092 
1093       instr->operands[1] = Operand::c32(offset / size);
1094       instr->operands[2] = Operand::c32(size * 8u);
1095       instr->operands[3] = Operand::c32(sign_extend);
1096       return;
1097    }
1098 
1099    /* Output modifier, label_vopc and label_f2f32 seem to be the only one worth keeping at the
1100     * moment
1101     */
1102    for (Definition& def : instr->definitions)
1103       ctx.info[def.tempId()].label &= (label_vopc | label_f2f32 | instr_mod_labels);
1104 }
1105 
1106 void
check_sdwa_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr)1107 check_sdwa_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1108 {
1109    for (unsigned i = 0; i < instr->operands.size(); i++) {
1110       Operand op = instr->operands[i];
1111       if (!op.isTemp())
1112          continue;
1113       ssa_info& info = ctx.info[op.tempId()];
1114       if (info.is_extract() && (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
1115                                 op.getTemp().type() == RegType::sgpr)) {
1116          if (!can_apply_extract(ctx, instr, i, info))
1117             info.label &= ~label_extract;
1118       }
1119    }
1120 }
1121 
1122 bool
does_fp_op_flush_denorms(opt_ctx & ctx,aco_opcode op)1123 does_fp_op_flush_denorms(opt_ctx& ctx, aco_opcode op)
1124 {
1125    if (ctx.program->chip_class <= GFX8) {
1126       switch (op) {
1127       case aco_opcode::v_min_f32:
1128       case aco_opcode::v_max_f32:
1129       case aco_opcode::v_med3_f32:
1130       case aco_opcode::v_min3_f32:
1131       case aco_opcode::v_max3_f32:
1132       case aco_opcode::v_min_f16:
1133       case aco_opcode::v_max_f16: return false;
1134       default: break;
1135       }
1136    }
1137    return op != aco_opcode::v_cndmask_b32;
1138 }
1139 
1140 bool
can_eliminate_fcanonicalize(opt_ctx & ctx,aco_ptr<Instruction> & instr,Temp tmp)1141 can_eliminate_fcanonicalize(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp tmp)
1142 {
1143    float_mode* fp = &ctx.fp_mode;
1144    if (ctx.info[tmp.id()].is_canonicalized() ||
1145        (tmp.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1146       return true;
1147 
1148    aco_opcode op = instr->opcode;
1149    return instr_info.can_use_input_modifiers[(int)op] && does_fp_op_flush_denorms(ctx, op);
1150 }
1151 
1152 bool
is_copy_label(opt_ctx & ctx,aco_ptr<Instruction> & instr,ssa_info & info)1153 is_copy_label(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info)
1154 {
1155    return info.is_temp() ||
1156           (info.is_fcanonicalize() && can_eliminate_fcanonicalize(ctx, instr, info.temp));
1157 }
1158 
1159 bool
is_op_canonicalized(opt_ctx & ctx,Operand op)1160 is_op_canonicalized(opt_ctx& ctx, Operand op)
1161 {
1162    float_mode* fp = &ctx.fp_mode;
1163    if ((op.isTemp() && ctx.info[op.tempId()].is_canonicalized()) ||
1164        (op.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1165       return true;
1166 
1167    if (op.isConstant() || (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(32))) {
1168       uint32_t val = op.isTemp() ? ctx.info[op.tempId()].val : op.constantValue();
1169       if (op.bytes() == 2)
1170          return (val & 0x7fff) == 0 || (val & 0x7fff) > 0x3ff;
1171       else if (op.bytes() == 4)
1172          return (val & 0x7fffffff) == 0 || (val & 0x7fffffff) > 0x7fffff;
1173    }
1174    return false;
1175 }
1176 
1177 void
label_instruction(opt_ctx & ctx,aco_ptr<Instruction> & instr)1178 label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1179 {
1180    if (instr->isSALU() || instr->isVALU() || instr->isPseudo()) {
1181       ASSERTED bool all_const = false;
1182       for (Operand& op : instr->operands)
1183          all_const =
1184             all_const && (!op.isTemp() || ctx.info[op.tempId()].is_constant_or_literal(32));
1185       perfwarn(ctx.program, all_const, "All instruction operands are constant", instr.get());
1186 
1187       ASSERTED bool is_copy = instr->opcode == aco_opcode::s_mov_b32 ||
1188                               instr->opcode == aco_opcode::s_mov_b64 ||
1189                               instr->opcode == aco_opcode::v_mov_b32;
1190       perfwarn(ctx.program, is_copy && !instr->usesModifiers(), "Use p_parallelcopy instead",
1191                instr.get());
1192    }
1193 
1194    if (instr->isSMEM())
1195       smem_combine(ctx, instr);
1196 
1197    for (unsigned i = 0; i < instr->operands.size(); i++) {
1198       if (!instr->operands[i].isTemp())
1199          continue;
1200 
1201       ssa_info info = ctx.info[instr->operands[i].tempId()];
1202       /* propagate undef */
1203       if (info.is_undefined() && is_phi(instr))
1204          instr->operands[i] = Operand(instr->operands[i].regClass());
1205       /* propagate reg->reg of same type */
1206       while (info.is_temp() && info.temp.regClass() == instr->operands[i].getTemp().regClass()) {
1207          instr->operands[i].setTemp(ctx.info[instr->operands[i].tempId()].temp);
1208          info = ctx.info[info.temp.id()];
1209       }
1210 
1211       /* PSEUDO: propagate temporaries */
1212       if (instr->isPseudo()) {
1213          while (info.is_temp()) {
1214             pseudo_propagate_temp(ctx, instr, info.temp, i);
1215             info = ctx.info[info.temp.id()];
1216          }
1217       }
1218 
1219       /* SALU / PSEUDO: propagate inline constants */
1220       if (instr->isSALU() || instr->isPseudo()) {
1221          unsigned bits = get_operand_size(instr, i);
1222          if ((info.is_constant(bits) || (info.is_literal(bits) && instr->isPseudo())) &&
1223              !instr->operands[i].isFixed() && alu_can_accept_constant(instr->opcode, i)) {
1224             instr->operands[i] = get_constant_op(ctx, info, bits);
1225             continue;
1226          }
1227       }
1228 
1229       /* VALU: propagate neg, abs & inline constants */
1230       else if (instr->isVALU()) {
1231          if (is_copy_label(ctx, instr, info) && info.temp.type() == RegType::vgpr &&
1232              valu_can_accept_vgpr(instr, i)) {
1233             instr->operands[i].setTemp(info.temp);
1234             info = ctx.info[info.temp.id()];
1235          }
1236          /* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */
1237          if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(ctx, instr) &&
1238              instr->operands.size() == 1) {
1239             instr->format = withoutDPP(instr->format);
1240             instr->operands[i].setTemp(info.temp);
1241             info = ctx.info[info.temp.id()];
1242          }
1243 
1244          /* for instructions other than v_cndmask_b32, the size of the instruction should match the
1245           * operand size */
1246          unsigned can_use_mod =
1247             instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4;
1248          can_use_mod = can_use_mod && instr_info.can_use_input_modifiers[(int)instr->opcode];
1249 
1250          if (instr->isSDWA())
1251             can_use_mod = can_use_mod && instr->sdwa().sel[i].size() == 4;
1252          else
1253             can_use_mod = can_use_mod && (instr->isDPP16() || can_use_VOP3(ctx, instr));
1254 
1255          unsigned bits = get_operand_size(instr, i);
1256          bool mod_bitsize_compat = instr->operands[i].bytes() * 8 == bits;
1257 
1258          if (info.is_neg() && instr->opcode == aco_opcode::v_add_f32 && mod_bitsize_compat) {
1259             instr->opcode = i ? aco_opcode::v_sub_f32 : aco_opcode::v_subrev_f32;
1260             instr->operands[i].setTemp(info.temp);
1261          } else if (info.is_neg() && instr->opcode == aco_opcode::v_add_f16 && mod_bitsize_compat) {
1262             instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16;
1263             instr->operands[i].setTemp(info.temp);
1264          } else if (info.is_neg() && can_use_mod && mod_bitsize_compat &&
1265                     can_eliminate_fcanonicalize(ctx, instr, info.temp)) {
1266             if (!instr->isDPP() && !instr->isSDWA())
1267                to_VOP3(ctx, instr);
1268             instr->operands[i].setTemp(info.temp);
1269             if (instr->isDPP16() && !instr->dpp16().abs[i])
1270                instr->dpp16().neg[i] = true;
1271             else if (instr->isSDWA() && !instr->sdwa().abs[i])
1272                instr->sdwa().neg[i] = true;
1273             else if (instr->isVOP3() && !instr->vop3().abs[i])
1274                instr->vop3().neg[i] = true;
1275          }
1276          if (info.is_abs() && can_use_mod && mod_bitsize_compat &&
1277              can_eliminate_fcanonicalize(ctx, instr, info.temp)) {
1278             if (!instr->isDPP() && !instr->isSDWA())
1279                to_VOP3(ctx, instr);
1280             instr->operands[i] = Operand(info.temp);
1281             if (instr->isDPP16())
1282                instr->dpp16().abs[i] = true;
1283             else if (instr->isSDWA())
1284                instr->sdwa().abs[i] = true;
1285             else
1286                instr->vop3().abs[i] = true;
1287             continue;
1288          }
1289 
1290          if (instr->isVOP3P()) {
1291             propagate_constants_vop3p(ctx, instr, info, i);
1292             continue;
1293          }
1294 
1295          if (info.is_constant(bits) && alu_can_accept_constant(instr->opcode, i) &&
1296              (!instr->isSDWA() || ctx.program->chip_class >= GFX9)) {
1297             Operand op = get_constant_op(ctx, info, bits);
1298             perfwarn(ctx.program, instr->opcode == aco_opcode::v_cndmask_b32 && i == 2,
1299                      "v_cndmask_b32 with a constant selector", instr.get());
1300             if (i == 0 || instr->isSDWA() || instr->opcode == aco_opcode::v_readlane_b32 ||
1301                 instr->opcode == aco_opcode::v_writelane_b32) {
1302                instr->format = withoutDPP(instr->format);
1303                instr->operands[i] = op;
1304                continue;
1305             } else if (!instr->isVOP3() && can_swap_operands(instr, &instr->opcode)) {
1306                instr->operands[i] = instr->operands[0];
1307                instr->operands[0] = op;
1308                continue;
1309             } else if (can_use_VOP3(ctx, instr)) {
1310                to_VOP3(ctx, instr);
1311                instr->operands[i] = op;
1312                continue;
1313             }
1314          }
1315       }
1316 
1317       /* MUBUF: propagate constants and combine additions */
1318       else if (instr->isMUBUF()) {
1319          MUBUF_instruction& mubuf = instr->mubuf();
1320          Temp base;
1321          uint32_t offset;
1322          while (info.is_temp())
1323             info = ctx.info[info.temp.id()];
1324 
1325          /* According to AMDGPUDAGToDAGISel::SelectMUBUFScratchOffen(), vaddr
1326           * overflow for scratch accesses works only on GFX9+ and saddr overflow
1327           * never works. Since swizzling is the only thing that separates
1328           * scratch accesses and other accesses and swizzling changing how
1329           * addressing works significantly, this probably applies to swizzled
1330           * MUBUF accesses. */
1331          bool vaddr_prevent_overflow = mubuf.swizzled && ctx.program->chip_class < GFX9;
1332          bool saddr_prevent_overflow = mubuf.swizzled;
1333 
1334          if (mubuf.offen && i == 1 && info.is_constant_or_literal(32) &&
1335              mubuf.offset + info.val < 4096) {
1336             assert(!mubuf.idxen);
1337             instr->operands[1] = Operand(v1);
1338             mubuf.offset += info.val;
1339             mubuf.offen = false;
1340             continue;
1341          } else if (i == 2 && info.is_constant_or_literal(32) && mubuf.offset + info.val < 4096) {
1342             instr->operands[2] = Operand::c32(0);
1343             mubuf.offset += info.val;
1344             continue;
1345          } else if (mubuf.offen && i == 1 &&
1346                     parse_base_offset(ctx, instr.get(), i, &base, &offset,
1347                                       vaddr_prevent_overflow) &&
1348                     base.regClass() == v1 && mubuf.offset + offset < 4096) {
1349             assert(!mubuf.idxen);
1350             instr->operands[1].setTemp(base);
1351             mubuf.offset += offset;
1352             continue;
1353          } else if (i == 2 &&
1354                     parse_base_offset(ctx, instr.get(), i, &base, &offset,
1355                                       saddr_prevent_overflow) &&
1356                     base.regClass() == s1 && mubuf.offset + offset < 4096) {
1357             instr->operands[i].setTemp(base);
1358             mubuf.offset += offset;
1359             continue;
1360          }
1361       }
1362 
1363       /* DS: combine additions */
1364       else if (instr->isDS()) {
1365 
1366          DS_instruction& ds = instr->ds();
1367          Temp base;
1368          uint32_t offset;
1369          bool has_usable_ds_offset = ctx.program->chip_class >= GFX7;
1370          if (has_usable_ds_offset && i == 0 &&
1371              parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1372              base.regClass() == instr->operands[i].regClass() &&
1373              instr->opcode != aco_opcode::ds_swizzle_b32) {
1374             if (instr->opcode == aco_opcode::ds_write2_b32 ||
1375                 instr->opcode == aco_opcode::ds_read2_b32 ||
1376                 instr->opcode == aco_opcode::ds_write2_b64 ||
1377                 instr->opcode == aco_opcode::ds_read2_b64) {
1378                unsigned mask = (instr->opcode == aco_opcode::ds_write2_b64 ||
1379                                 instr->opcode == aco_opcode::ds_read2_b64)
1380                                   ? 0x7
1381                                   : 0x3;
1382                unsigned shifts = (instr->opcode == aco_opcode::ds_write2_b64 ||
1383                                   instr->opcode == aco_opcode::ds_read2_b64)
1384                                     ? 3
1385                                     : 2;
1386 
1387                if ((offset & mask) == 0 && ds.offset0 + (offset >> shifts) <= 255 &&
1388                    ds.offset1 + (offset >> shifts) <= 255) {
1389                   instr->operands[i].setTemp(base);
1390                   ds.offset0 += offset >> shifts;
1391                   ds.offset1 += offset >> shifts;
1392                }
1393             } else {
1394                if (ds.offset0 + offset <= 65535) {
1395                   instr->operands[i].setTemp(base);
1396                   ds.offset0 += offset;
1397                }
1398             }
1399          }
1400       }
1401 
1402       else if (instr->isBranch()) {
1403          if (ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
1404             /* Flip the branch instruction to get rid of the scc_invert instruction */
1405             instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz
1406                                                                      : aco_opcode::p_cbranch_z;
1407             instr->operands[0].setTemp(ctx.info[instr->operands[0].tempId()].temp);
1408          }
1409       }
1410    }
1411 
1412    /* if this instruction doesn't define anything, return */
1413    if (instr->definitions.empty()) {
1414       check_sdwa_extract(ctx, instr);
1415       return;
1416    }
1417 
1418    if (instr->isVALU() || instr->isVINTRP()) {
1419       if (instr_info.can_use_output_modifiers[(int)instr->opcode] || instr->isVINTRP() ||
1420           instr->opcode == aco_opcode::v_cndmask_b32) {
1421          bool canonicalized = true;
1422          if (!does_fp_op_flush_denorms(ctx, instr->opcode)) {
1423             unsigned ops = instr->opcode == aco_opcode::v_cndmask_b32 ? 2 : instr->operands.size();
1424             for (unsigned i = 0; canonicalized && (i < ops); i++)
1425                canonicalized = is_op_canonicalized(ctx, instr->operands[i]);
1426          }
1427          if (canonicalized)
1428             ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1429       }
1430 
1431       if (instr->isVOPC()) {
1432          ctx.info[instr->definitions[0].tempId()].set_vopc(instr.get());
1433          check_sdwa_extract(ctx, instr);
1434          return;
1435       }
1436       if (instr->isVOP3P()) {
1437          ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
1438          return;
1439       }
1440    }
1441 
1442    switch (instr->opcode) {
1443    case aco_opcode::p_create_vector: {
1444       bool copy_prop = instr->operands.size() == 1 && instr->operands[0].isTemp() &&
1445                        instr->operands[0].regClass() == instr->definitions[0].regClass();
1446       if (copy_prop) {
1447          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1448          break;
1449       }
1450 
1451       /* expand vector operands */
1452       std::vector<Operand> ops;
1453       unsigned offset = 0;
1454       for (const Operand& op : instr->operands) {
1455          /* ensure that any expanded operands are properly aligned */
1456          bool aligned = offset % 4 == 0 || op.bytes() < 4;
1457          offset += op.bytes();
1458          if (aligned && op.isTemp() && ctx.info[op.tempId()].is_vec()) {
1459             Instruction* vec = ctx.info[op.tempId()].instr;
1460             for (const Operand& vec_op : vec->operands)
1461                ops.emplace_back(vec_op);
1462          } else {
1463             ops.emplace_back(op);
1464          }
1465       }
1466 
1467       /* combine expanded operands to new vector */
1468       if (ops.size() != instr->operands.size()) {
1469          assert(ops.size() > instr->operands.size());
1470          Definition def = instr->definitions[0];
1471          instr.reset(create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector,
1472                                                             Format::PSEUDO, ops.size(), 1));
1473          for (unsigned i = 0; i < ops.size(); i++) {
1474             if (ops[i].isTemp() && ctx.info[ops[i].tempId()].is_temp() &&
1475                 ops[i].regClass() == ctx.info[ops[i].tempId()].temp.regClass())
1476                ops[i].setTemp(ctx.info[ops[i].tempId()].temp);
1477             instr->operands[i] = ops[i];
1478          }
1479          instr->definitions[0] = def;
1480       } else {
1481          for (unsigned i = 0; i < ops.size(); i++) {
1482             assert(instr->operands[i] == ops[i]);
1483          }
1484       }
1485       ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1486       break;
1487    }
1488    case aco_opcode::p_split_vector: {
1489       ssa_info& info = ctx.info[instr->operands[0].tempId()];
1490 
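      /* splitting a constant: label every definition with its slice of the value */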
1491       if (info.is_constant_or_literal(32)) {
1492          uint64_t val = info.val;
1493          for (Definition def : instr->definitions) {
1494             uint32_t mask = u_bit_consecutive(0, def.bytes() * 8u);
1495             ctx.info[def.tempId()].set_constant(ctx.program->chip_class, val & mask);
1496             val >>= def.bytes() * 8u;
1497          }
1498          break;
1499       } else if (!info.is_vec()) {
1500          if (instr->operands[0].bytes() == 4 && instr->definitions.size() == 2) {
1501             /* D16 subdword split */
1502             ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1503             if (instr->definitions[1].bytes() == 2)
1504                ctx.info[instr->definitions[1].tempId()].set_extract(instr.get());
1505          }
1506          break;
1507       }
1508 
1509       Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1510       unsigned split_offset = 0;
1511       unsigned vec_offset = 0;
1512       unsigned vec_index = 0;
1513       for (unsigned i = 0; i < instr->definitions.size();
1514            split_offset += instr->definitions[i++].bytes()) {
1515          while (vec_offset < split_offset && vec_index < vec->operands.size())
1516             vec_offset += vec->operands[vec_index++].bytes();
1517 
1518          if (vec_offset != split_offset ||
1519              vec->operands[vec_index].bytes() != instr->definitions[i].bytes())
1520             continue;
1521 
1522          Operand vec_op = vec->operands[vec_index];
1523          if (vec_op.isConstant()) {
1524             ctx.info[instr->definitions[i].tempId()].set_constant(ctx.program->chip_class,
1525                                                                   vec_op.constantValue64());
1526          } else if (vec_op.isUndefined()) {
1527             ctx.info[instr->definitions[i].tempId()].set_undefined();
1528          } else {
1529             assert(vec_op.isTemp());
1530             ctx.info[instr->definitions[i].tempId()].set_temp(vec_op.getTemp());
1531          }
1532       }
1533       break;
1534    }
1535    case aco_opcode::p_extract_vector: { /* mov */
1536       ssa_info& info = ctx.info[instr->operands[0].tempId()];
1537       const unsigned index = instr->operands[1].constantValue();
1538       const unsigned dst_offset = index * instr->definitions[0].bytes();
1539 
1540       if (info.is_vec()) {
1541          /* check if we index directly into a vector element */
1542          Instruction* vec = info.instr;
1543          unsigned offset = 0;
1544 
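         /* walk the vector's operands up to dst_offset; reuse an operand only if
          * it starts exactly there and has the same size as the definition */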
1545          for (const Operand& op : vec->operands) {
1546             if (offset < dst_offset) {
1547                offset += op.bytes();
1548                continue;
1549             } else if (offset != dst_offset || op.bytes() != instr->definitions[0].bytes()) {
1550                break;
1551             }
1552             instr->operands[0] = op;
1553             break;
1554          }
1555       } else if (info.is_constant_or_literal(32)) {
1556          /* propagate constants */
1557          uint32_t mask = u_bit_consecutive(0, instr->definitions[0].bytes() * 8u);
1558          uint32_t val = (info.val >> (dst_offset * 8u)) & mask;
1559          instr->operands[0] =
1560             Operand::get_const(ctx.program->chip_class, val, instr->definitions[0].bytes());
1562       }
1563 
1564       if (instr->operands[0].bytes() != instr->definitions[0].bytes()) {
1565          if (instr->operands[0].size() != 1)
1566             break;
1567 
1568          if (index == 0)
1569             ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1570          else
1571             ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1572          break;
1573       }
1574 
1575       /* convert this extract into a copy instruction */
1576       instr->opcode = aco_opcode::p_parallelcopy;
1577       instr->operands.pop_back();
1578       FALLTHROUGH;
1579    }
1580    case aco_opcode::p_parallelcopy: /* propagate */
1581       if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_vec() &&
1582           instr->operands[0].regClass() != instr->definitions[0].regClass()) {
1583          /* We might not be able to copy-propagate if it's an SGPR->VGPR copy, so
1584           * duplicate the vector instead.
1585           */
1586          Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1587          aco_ptr<Instruction> old_copy = std::move(instr);
1588 
1589          instr.reset(create_instruction<Pseudo_instruction>(
1590             aco_opcode::p_create_vector, Format::PSEUDO, vec->operands.size(), 1));
1591          instr->definitions[0] = old_copy->definitions[0];
1592          std::copy(vec->operands.begin(), vec->operands.end(), instr->operands.begin());
1593          for (unsigned i = 0; i < vec->operands.size(); i++) {
1594             Operand& op = instr->operands[i];
1595             if (op.isTemp() && ctx.info[op.tempId()].is_temp() &&
1596                 ctx.info[op.tempId()].temp.type() == instr->definitions[0].regClass().type())
1597                op.setTemp(ctx.info[op.tempId()].temp);
1598          }
1599          ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1600          break;
1601       }
1602       FALLTHROUGH;
1603    case aco_opcode::p_as_uniform:
1604       if (instr->definitions[0].isFixed()) {
1605          /* don't copy-propagate copies into fixed registers */
1606       } else if (instr->usesModifiers()) {
1607          // TODO
1608       } else if (instr->operands[0].isConstant()) {
1609          ctx.info[instr->definitions[0].tempId()].set_constant(
1610             ctx.program->chip_class, instr->operands[0].constantValue64());
1611       } else if (instr->operands[0].isTemp()) {
1612          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1613          if (ctx.info[instr->operands[0].tempId()].is_canonicalized())
1614             ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1615       } else {
1616          assert(instr->operands[0].isFixed());
1617       }
1618       break;
1619    case aco_opcode::v_mov_b32:
1620       if (instr->isDPP16()) {
1621          /* anything else doesn't make sense in SSA */
1622          assert(instr->dpp16().row_mask == 0xf && instr->dpp16().bank_mask == 0xf);
1623          ctx.info[instr->definitions[0].tempId()].set_dpp16(instr.get());
1624       } else if (instr->isDPP8()) {
1625          ctx.info[instr->definitions[0].tempId()].set_dpp8(instr.get());
1626       }
1627       break;
1628    case aco_opcode::p_is_helper:
1629       if (!ctx.program->needs_wqm)
1630          ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->chip_class, 0u);
1631       break;
1632    case aco_opcode::v_mul_f64: ctx.info[instr->definitions[0].tempId()].set_mul(instr.get()); break;
1633    case aco_opcode::v_mul_f16:
1634    case aco_opcode::v_mul_f32:
1635    case aco_opcode::v_mul_legacy_f32: { /* omod */
1636       ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
1637 
1638       /* TODO: try to move the negate/abs modifier to the consumer instead */
1639       bool uses_mods = instr->usesModifiers();
1640       bool fp16 = instr->opcode == aco_opcode::v_mul_f16;
1641 
1642       for (unsigned i = 0; i < 2; i++) {
1643          if (instr->operands[!i].isConstant() && instr->operands[i].isTemp()) {
1644             if (!instr->isDPP() && !instr->isSDWA() &&
1645                 (instr->operands[!i].constantEquals(fp16 ? 0x3c00 : 0x3f800000) ||   /* 1.0 */
1646                  instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u))) { /* -1.0 */
1647                bool neg1 = instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u);
1648 
1649                VOP3_instruction* vop3 = instr->isVOP3() ? &instr->vop3() : NULL;
1650                if (vop3 && (vop3->abs[!i] || vop3->neg[!i] || vop3->clamp || vop3->omod))
1651                   continue;
1652 
1653                bool abs = vop3 && vop3->abs[i];
1654                bool neg = neg1 ^ (vop3 && vop3->neg[i]);
1655 
1656                Temp other = instr->operands[i].getTemp();
1657                if (abs && neg && other.type() == RegType::vgpr)
1658                   ctx.info[instr->definitions[0].tempId()].set_neg_abs(other);
1659                else if (abs && !neg && other.type() == RegType::vgpr)
1660                   ctx.info[instr->definitions[0].tempId()].set_abs(other);
1661                else if (!abs && neg && other.type() == RegType::vgpr)
1662                   ctx.info[instr->definitions[0].tempId()].set_neg(other);
1663                else if (!abs && !neg)
1664                   ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other);
1665             } else if (uses_mods) {
1666                continue;
1667             } else if (instr->operands[!i].constantValue() ==
1668                        (fp16 ? 0x4000 : 0x40000000)) { /* 2.0 */
1669                ctx.info[instr->operands[i].tempId()].set_omod2(instr.get());
1670             } else if (instr->operands[!i].constantValue() ==
1671                        (fp16 ? 0x4400 : 0x40800000)) { /* 4.0 */
1672                ctx.info[instr->operands[i].tempId()].set_omod4(instr.get());
1673             } else if (instr->operands[!i].constantValue() ==
1674                        (fp16 ? 0x3800 : 0x3f000000)) { /* 0.5 */
1675                ctx.info[instr->operands[i].tempId()].set_omod5(instr.get());
1676             } else if (instr->operands[!i].constantValue() == 0u &&
1677                        (!(fp16 ? ctx.fp_mode.preserve_signed_zero_inf_nan16_64
1678                                : ctx.fp_mode.preserve_signed_zero_inf_nan32) ||
1679                         instr->opcode == aco_opcode::v_mul_legacy_f32)) { /* 0.0 */
1680                ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->chip_class, 0u);
1681             } else {
1682                continue;
1683             }
1684             break;
1685          }
1686       }
1687       break;
1688    }
1689    case aco_opcode::v_mul_lo_u16:
1690    case aco_opcode::v_mul_lo_u16_e64:
1691    case aco_opcode::v_mul_u32_u24:
1692       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1693       break;
1694    case aco_opcode::v_med3_f16:
1695    case aco_opcode::v_med3_f32: { /* clamp */
1696       VOP3_instruction& vop3 = instr->vop3();
1697       if (vop3.abs[0] || vop3.abs[1] || vop3.abs[2] || vop3.neg[0] || vop3.neg[1] || vop3.neg[2] ||
1698           vop3.omod != 0 || vop3.opsel != 0)
1699          break;
1700 
1701       unsigned idx = 0;
1702       bool found_zero = false, found_one = false;
1703       bool is_fp16 = instr->opcode == aco_opcode::v_med3_f16;
1704       for (unsigned i = 0; i < 3; i++) {
1705          if (instr->operands[i].constantEquals(0))
1706             found_zero = true;
1707          else if (instr->operands[i].constantEquals(is_fp16 ? 0x3c00 : 0x3f800000)) /* 1.0 */
1708             found_one = true;
1709          else
1710             idx = i;
1711       }
1712       if (found_zero && found_one && instr->operands[idx].isTemp())
1713          ctx.info[instr->operands[idx].tempId()].set_clamp(instr.get());
1714       break;
1715    }
1716    case aco_opcode::v_cndmask_b32:
1717       if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(0xFFFFFFFF))
1718          ctx.info[instr->definitions[0].tempId()].set_vcc(instr->operands[2].getTemp());
1719       else if (instr->operands[0].constantEquals(0) &&
1720                instr->operands[1].constantEquals(0x3f800000u))
1721          ctx.info[instr->definitions[0].tempId()].set_b2f(instr->operands[2].getTemp());
1722       else if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(1))
1723          ctx.info[instr->definitions[0].tempId()].set_b2i(instr->operands[2].getTemp());
1724 
1725       ctx.info[instr->operands[2].tempId()].set_vcc_hint();
1726       break;
1727    case aco_opcode::v_cmp_lg_u32:
1728       if (instr->format == Format::VOPC && /* don't optimize VOP3 / SDWA / DPP */
1729           instr->operands[0].constantEquals(0) && instr->operands[1].isTemp() &&
1730           ctx.info[instr->operands[1].tempId()].is_vcc())
1731          ctx.info[instr->definitions[0].tempId()].set_temp(
1732             ctx.info[instr->operands[1].tempId()].temp);
1733       break;
1734    case aco_opcode::p_linear_phi: {
1735       /* lower_bool_phis() can create phis like this */
1736       bool all_same_temp = instr->operands[0].isTemp();
1737       /* this check is needed when moving uniform loop counters out of a divergent loop */
1738       if (all_same_temp)
1739          all_same_temp = instr->definitions[0].regClass() == instr->operands[0].regClass();
1740       for (unsigned i = 1; all_same_temp && (i < instr->operands.size()); i++) {
1741          if (!instr->operands[i].isTemp() ||
1742              instr->operands[i].tempId() != instr->operands[0].tempId())
1743             all_same_temp = false;
1744       }
1745       if (all_same_temp) {
1746          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1747       } else {
1748          bool all_undef = instr->operands[0].isUndefined();
1749          for (unsigned i = 1; all_undef && (i < instr->operands.size()); i++) {
1750             if (!instr->operands[i].isUndefined())
1751                all_undef = false;
1752          }
1753          if (all_undef)
1754             ctx.info[instr->definitions[0].tempId()].set_undefined();
1755       }
1756       break;
1757    }
1758    case aco_opcode::v_add_u32:
1759    case aco_opcode::v_add_co_u32:
1760    case aco_opcode::v_add_co_u32_e64:
1761    case aco_opcode::s_add_i32:
1762    case aco_opcode::s_add_u32:
1763    case aco_opcode::v_subbrev_co_u32:
1764       ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
1765       break;
1766    case aco_opcode::s_not_b32:
1767    case aco_opcode::s_not_b64:
1768       if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
1769          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1770          ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1771             ctx.info[instr->operands[0].tempId()].temp);
1772       } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
1773          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1774          ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1775             ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1776       }
1777       ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1778       break;
1779    case aco_opcode::s_and_b32:
1780    case aco_opcode::s_and_b64:
1781       if (fixed_to_exec(instr->operands[1]) && instr->operands[0].isTemp()) {
1782          if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
1783             /* Try to get rid of the superfluous s_cselect + s_and_b64 that comes from turning a
1784              * uniform bool into a divergent one */
1785             ctx.info[instr->definitions[1].tempId()].set_temp(
1786                ctx.info[instr->operands[0].tempId()].temp);
1787             ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
1788                ctx.info[instr->operands[0].tempId()].temp);
1789             break;
1790          } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
1791             /* Try to get rid of the superfluous s_and_b64, since the uniform bitwise instruction
1792              * already produces the same SCC */
1793             ctx.info[instr->definitions[1].tempId()].set_temp(
1794                ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1795             ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
1796                ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1797             break;
1798          } else if ((ctx.program->stage.num_sw_stages() > 1 ||
1799                      ctx.program->stage.hw == HWStage::NGG) &&
1800                     instr->pass_flags == 1) {
1801             /* In case of merged shaders, pass_flags=1 means that all lanes are active (exec=-1), so
1802              * s_and is unnecessary. */
1803             ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1804             break;
1805          } else if (ctx.info[instr->operands[0].tempId()].is_vopc()) {
1806             Instruction* vopc_instr = ctx.info[instr->operands[0].tempId()].instr;
1807             /* Remove superfluous s_and when the VOPC instruction uses the same exec and thus
1808              * already produces the same result */
1809             if (vopc_instr->pass_flags == instr->pass_flags) {
1810                assert(instr->pass_flags > 0);
1811                ctx.info[instr->definitions[0].tempId()].set_temp(
1812                   vopc_instr->definitions[0].getTemp());
1813                break;
1814             }
1815          }
1816       }
1817       FALLTHROUGH;
1818    case aco_opcode::s_or_b32:
1819    case aco_opcode::s_or_b64:
1820    case aco_opcode::s_xor_b32:
1821    case aco_opcode::s_xor_b64:
1822       if (std::all_of(instr->operands.begin(), instr->operands.end(),
1823                       [&ctx](const Operand& op)
1824                       {
1825                          return op.isTemp() && (ctx.info[op.tempId()].is_uniform_bool() ||
1826                                                 ctx.info[op.tempId()].is_uniform_bitwise());
1827                       })) {
1828          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1829       }
1830       FALLTHROUGH;
1831    case aco_opcode::s_lshl_b32:
1832    case aco_opcode::v_or_b32:
1833    case aco_opcode::v_lshlrev_b32:
1834    case aco_opcode::v_bcnt_u32_b32:
1835    case aco_opcode::v_and_b32:
1836    case aco_opcode::v_xor_b32:
1837       ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1838       break;
1839    case aco_opcode::v_min_f32:
1840    case aco_opcode::v_min_f16:
1841    case aco_opcode::v_min_u32:
1842    case aco_opcode::v_min_i32:
1843    case aco_opcode::v_min_u16:
1844    case aco_opcode::v_min_i16:
1845    case aco_opcode::v_max_f32:
1846    case aco_opcode::v_max_f16:
1847    case aco_opcode::v_max_u32:
1848    case aco_opcode::v_max_i32:
1849    case aco_opcode::v_max_u16:
1850    case aco_opcode::v_max_i16:
1851       ctx.info[instr->definitions[0].tempId()].set_minmax(instr.get());
1852       break;
1853    case aco_opcode::s_cselect_b64:
1854    case aco_opcode::s_cselect_b32:
1855       if (instr->operands[0].constantEquals((unsigned)-1) && instr->operands[1].constantEquals(0)) {
1856          /* Found a cselect that operates on a uniform bool that comes from eg. s_cmp */
1857          ctx.info[instr->definitions[0].tempId()].set_uniform_bool(instr->operands[2].getTemp());
1858       }
1859       if (instr->operands[2].isTemp() && ctx.info[instr->operands[2].tempId()].is_scc_invert()) {
1860          /* Flip the operands to get rid of the scc_invert instruction */
1861          std::swap(instr->operands[0], instr->operands[1]);
1862          instr->operands[2].setTemp(ctx.info[instr->operands[2].tempId()].temp);
1863       }
1864       break;
1865    case aco_opcode::p_wqm:
1866       if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
1867          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1868       }
1869       break;
1870    case aco_opcode::s_mul_i32:
1871       /* Testing every uint32_t shows that 0x3f800000*n is never a denormal.
1872        * This pattern is created from a uniform nir_op_b2f. */
1873       if (instr->operands[0].constantEquals(0x3f800000u))
1874          ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1875       break;
1876    case aco_opcode::p_extract: {
1877       if (instr->definitions[0].bytes() == 4) {
1878          ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1879          if (instr->operands[0].regClass() == v1 && parse_insert(instr.get()))
1880             ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
1881       }
1882       break;
1883    }
1884    case aco_opcode::p_insert: {
1885       if (instr->operands[0].bytes() == 4) {
1886          if (instr->operands[0].regClass() == v1)
1887             ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
1888          if (parse_extract(instr.get()))
1889             ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1890          ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1891       }
1892       break;
1893    }
1894    case aco_opcode::ds_read_u8:
1895    case aco_opcode::ds_read_u8_d16:
1896    case aco_opcode::ds_read_u16:
1897    case aco_opcode::ds_read_u16_d16: {
1898       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1899       break;
1900    }
1901    case aco_opcode::v_cvt_f16_f32: {
1902       if (instr->operands[0].isTemp())
1903          ctx.info[instr->operands[0].tempId()].set_f2f16(instr.get());
1904       break;
1905    }
1906    case aco_opcode::v_cvt_f32_f16: {
1907       if (instr->operands[0].isTemp())
1908          ctx.info[instr->definitions[0].tempId()].set_f2f32(instr.get());
1909       break;
1910    }
1911    default: break;
1912    }
1913 
1914    /* Don't remove label_extract if we can't apply the extract to
1915     * neg/abs instructions because we'll likely combine it into another valu. */
1916    if (!(ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)))
1917       check_sdwa_extract(ctx, instr);
1918 }
1919 
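/* Follow a copy-propagation label back to the original temp id so that two
 * operands can be compared for value equality. */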
1920 unsigned
1921 original_temp_id(opt_ctx& ctx, Temp tmp)
1922 {
1923    if (ctx.info[tmp.id()].is_temp())
1924       return ctx.info[tmp.id()].temp.id();
1925    else
1926       return tmp.id();
1927 }
1928 
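/* Decrease the use count of the instruction's result; if it drops to zero, the
 * instruction is dead and the uses of its operands are released as well. */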
1929 void
1930 decrease_uses(opt_ctx& ctx, Instruction* instr)
1931 {
1932    if (!--ctx.uses[instr->definitions[0].tempId()]) {
1933       for (const Operand& op : instr->operands) {
1934          if (op.isTemp())
1935             ctx.uses[op.tempId()]--;
1936       }
1937    }
1938 }
1939 
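/* Return the instruction producing op if it carries an instruction-usedef label
 * and (unless ignore_uses is set) its result is only used once. A second
 * definition (e.g. SCC) that is still in use prevents following the operand. */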
1940 Instruction*
1941 follow_operand(opt_ctx& ctx, Operand op, bool ignore_uses = false)
1942 {
1943    if (!op.isTemp() || !(ctx.info[op.tempId()].label & instr_usedef_labels))
1944       return nullptr;
1945    if (!ignore_uses && ctx.uses[op.tempId()] > 1)
1946       return nullptr;
1947 
1948    Instruction* instr = ctx.info[op.tempId()].instr;
1949 
1950    if (instr->definitions.size() == 2) {
1951       assert(instr->definitions[0].isTemp() && instr->definitions[0].tempId() == op.tempId());
1952       if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
1953          return nullptr;
1954    }
1955 
1956    return instr;
1957 }
1958 
1959 /* s_or_b64(neq(a, a), neq(b, b)) -> v_cmp_u_f32(a, b)
1960  * s_and_b64(eq(a, a), eq(b, b)) -> v_cmp_o_f32(a, b) */
1961 bool
1962 combine_ordering_test(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1963 {
1964    if (instr->definitions[0].regClass() != ctx.program->lane_mask)
1965       return false;
1966    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
1967       return false;
1968 
1969    bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
1970 
1971    bool neg[2] = {false, false};
1972    bool abs[2] = {false, false};
1973    uint8_t opsel = 0;
1974    Instruction* op_instr[2];
1975    Temp op[2];
1976 
1977    unsigned bitsize = 0;
1978    for (unsigned i = 0; i < 2; i++) {
1979       op_instr[i] = follow_operand(ctx, instr->operands[i], true);
1980       if (!op_instr[i])
1981          return false;
1982 
1983       aco_opcode expected_cmp = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
1984       unsigned op_bitsize = get_cmp_bitsize(op_instr[i]->opcode);
1985 
1986       if (get_f32_cmp(op_instr[i]->opcode) != expected_cmp)
1987          return false;
1988       if (bitsize && op_bitsize != bitsize)
1989          return false;
1990       if (!op_instr[i]->operands[0].isTemp() || !op_instr[i]->operands[1].isTemp())
1991          return false;
1992 
1993       if (op_instr[i]->isVOP3()) {
1994          VOP3_instruction& vop3 = op_instr[i]->vop3();
1995          if (vop3.neg[0] != vop3.neg[1] || vop3.abs[0] != vop3.abs[1] || vop3.opsel == 1 ||
1996              vop3.opsel == 2)
1997             return false;
1998          neg[i] = vop3.neg[0];
1999          abs[i] = vop3.abs[0];
2000          opsel |= (vop3.opsel & 1) << i;
2001       } else if (op_instr[i]->isSDWA()) {
2002          return false;
2003       }
2004 
2005       Temp op0 = op_instr[i]->operands[0].getTemp();
2006       Temp op1 = op_instr[i]->operands[1].getTemp();
2007       if (original_temp_id(ctx, op0) != original_temp_id(ctx, op1))
2008          return false;
2009 
2010       op[i] = op1;
2011       bitsize = op_bitsize;
2012    }
2013 
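   /* VALU comparisons can read at most one SGPR operand before GFX10 and two
    * from GFX10 onwards, so move an SGPR to the first slot and bail out if the
    * limit would be exceeded. */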
2014    if (op[1].type() == RegType::sgpr)
2015       std::swap(op[0], op[1]);
2016    unsigned num_sgprs = (op[0].type() == RegType::sgpr) + (op[1].type() == RegType::sgpr);
2017    if (num_sgprs > (ctx.program->chip_class >= GFX10 ? 2 : 1))
2018       return false;
2019 
2020    ctx.uses[op[0].id()]++;
2021    ctx.uses[op[1].id()]++;
2022    decrease_uses(ctx, op_instr[0]);
2023    decrease_uses(ctx, op_instr[1]);
2024 
2025    aco_opcode new_op = aco_opcode::num_opcodes;
2026    switch (bitsize) {
2027    case 16: new_op = is_or ? aco_opcode::v_cmp_u_f16 : aco_opcode::v_cmp_o_f16; break;
2028    case 32: new_op = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32; break;
2029    case 64: new_op = is_or ? aco_opcode::v_cmp_u_f64 : aco_opcode::v_cmp_o_f64; break;
2030    }
2031    Instruction* new_instr;
2032    if (neg[0] || neg[1] || abs[0] || abs[1] || opsel || num_sgprs > 1) {
2033       VOP3_instruction* vop3 =
2034          create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
2035       for (unsigned i = 0; i < 2; i++) {
2036          vop3->neg[i] = neg[i];
2037          vop3->abs[i] = abs[i];
2038       }
2039       vop3->opsel = opsel;
2040       new_instr = static_cast<Instruction*>(vop3);
2041    } else {
2042       new_instr = create_instruction<VOPC_instruction>(new_op, Format::VOPC, 2, 1);
2043       instr->definitions[0].setHint(vcc);
2044    }
2045    new_instr->operands[0] = Operand(op[0]);
2046    new_instr->operands[1] = Operand(op[1]);
2047    new_instr->definitions[0] = instr->definitions[0];
2048 
2049    ctx.info[instr->definitions[0].tempId()].label = 0;
2050    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2051 
2052    instr.reset(new_instr);
2053 
2054    return true;
2055 }
2056 
2057 /* s_or_b64(v_cmp_u_f32(a, b), cmp(a, b)) -> get_unordered(cmp)(a, b)
2058  * s_and_b64(v_cmp_o_f32(a, b), cmp(a, b)) -> get_ordered(cmp)(a, b) */
2059 bool
2060 combine_comparison_ordering(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2061 {
2062    if (instr->definitions[0].regClass() != ctx.program->lane_mask)
2063       return false;
2064    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2065       return false;
2066 
2067    bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
2068    aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32;
2069 
2070    Instruction* nan_test = follow_operand(ctx, instr->operands[0], true);
2071    Instruction* cmp = follow_operand(ctx, instr->operands[1], true);
2072    if (!nan_test || !cmp)
2073       return false;
2074    if (nan_test->isSDWA() || cmp->isSDWA())
2075       return false;
2076 
2077    if (get_f32_cmp(cmp->opcode) == expected_nan_test)
2078       std::swap(nan_test, cmp);
2079    else if (get_f32_cmp(nan_test->opcode) != expected_nan_test)
2080       return false;
2081 
2082    if (!is_cmp(cmp->opcode) || get_cmp_bitsize(cmp->opcode) != get_cmp_bitsize(nan_test->opcode))
2083       return false;
2084 
2085    if (!nan_test->operands[0].isTemp() || !nan_test->operands[1].isTemp())
2086       return false;
2087    if (!cmp->operands[0].isTemp() || !cmp->operands[1].isTemp())
2088       return false;
2089 
2090    unsigned prop_cmp0 = original_temp_id(ctx, cmp->operands[0].getTemp());
2091    unsigned prop_cmp1 = original_temp_id(ctx, cmp->operands[1].getTemp());
2092    unsigned prop_nan0 = original_temp_id(ctx, nan_test->operands[0].getTemp());
2093    unsigned prop_nan1 = original_temp_id(ctx, nan_test->operands[1].getTemp());
2094    if (prop_cmp0 != prop_nan0 && prop_cmp0 != prop_nan1)
2095       return false;
2096    if (prop_cmp1 != prop_nan0 && prop_cmp1 != prop_nan1)
2097       return false;
2098 
2099    ctx.uses[cmp->operands[0].tempId()]++;
2100    ctx.uses[cmp->operands[1].tempId()]++;
2101    decrease_uses(ctx, nan_test);
2102    decrease_uses(ctx, cmp);
2103 
2104    aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
2105    Instruction* new_instr;
2106    if (cmp->isVOP3()) {
2107       VOP3_instruction* new_vop3 =
2108          create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
2109       VOP3_instruction& cmp_vop3 = cmp->vop3();
2110       memcpy(new_vop3->abs, cmp_vop3.abs, sizeof(new_vop3->abs));
2111       memcpy(new_vop3->neg, cmp_vop3.neg, sizeof(new_vop3->neg));
2112       new_vop3->clamp = cmp_vop3.clamp;
2113       new_vop3->omod = cmp_vop3.omod;
2114       new_vop3->opsel = cmp_vop3.opsel;
2115       new_instr = new_vop3;
2116    } else {
2117       new_instr = create_instruction<VOPC_instruction>(new_op, Format::VOPC, 2, 1);
2118       instr->definitions[0].setHint(vcc);
2119    }
2120    new_instr->operands[0] = cmp->operands[0];
2121    new_instr->operands[1] = cmp->operands[1];
2122    new_instr->definitions[0] = instr->definitions[0];
2123 
2124    ctx.info[instr->definitions[0].tempId()].label = 0;
2125    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2126 
2127    instr.reset(new_instr);
2128 
2129    return true;
2130 }
2131 
2132 bool
2133 is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value)
2134 {
2135    if (op.isConstant()) {
2136       *value = op.constantValue64();
2137       return true;
2138    } else if (op.isTemp()) {
2139       unsigned id = original_temp_id(ctx, op.getTemp());
2140       if (!ctx.info[id].is_constant_or_literal(bit_size))
2141          return false;
2142       *value = get_constant_op(ctx, ctx.info[id], bit_size).constantValue64();
2143       return true;
2144    }
2145    return false;
2146 }
2147 
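/* A floating-point value is a NaN iff all exponent bits are set and the
 * mantissa is non-zero. */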
2148 bool
2149 is_constant_nan(uint64_t value, unsigned bit_size)
2150 {
2151    if (bit_size == 16)
2152       return ((value >> 10) & 0x1f) == 0x1f && (value & 0x3ff);
2153    else if (bit_size == 32)
2154       return ((value >> 23) & 0xff) == 0xff && (value & 0x7fffff);
2155    else
2156       return ((value >> 52) & 0x7ff) == 0x7ff && (value & 0xfffffffffffff);
2157 }
2158 
2159 /* s_or_b64(v_cmp_neq_f32(a, a), cmp(a, #b)) and b is not NaN -> get_unordered(cmp)(a, b)
2160  * s_and_b64(v_cmp_eq_f32(a, a), cmp(a, #b)) and b is not NaN -> get_ordered(cmp)(a, b) */
2161 bool
2162 combine_constant_comparison_ordering(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2163 {
2164    if (instr->definitions[0].regClass() != ctx.program->lane_mask)
2165       return false;
2166    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2167       return false;
2168 
2169    bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
2170 
2171    Instruction* nan_test = follow_operand(ctx, instr->operands[0], true);
2172    Instruction* cmp = follow_operand(ctx, instr->operands[1], true);
2173 
2174    if (!nan_test || !cmp || nan_test->isSDWA() || cmp->isSDWA())
2175       return false;
2178 
2179    aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
2180    if (get_f32_cmp(cmp->opcode) == expected_nan_test)
2181       std::swap(nan_test, cmp);
2182    else if (get_f32_cmp(nan_test->opcode) != expected_nan_test)
2183       return false;
2184 
2185    unsigned bit_size = get_cmp_bitsize(cmp->opcode);
2186    if (!is_cmp(cmp->opcode) || get_cmp_bitsize(nan_test->opcode) != bit_size)
2187       return false;
2188 
2189    if (!nan_test->operands[0].isTemp() || !nan_test->operands[1].isTemp())
2190       return false;
2191    if (!cmp->operands[0].isTemp() && !cmp->operands[1].isTemp())
2192       return false;
2193 
2194    unsigned prop_nan0 = original_temp_id(ctx, nan_test->operands[0].getTemp());
2195    unsigned prop_nan1 = original_temp_id(ctx, nan_test->operands[1].getTemp());
2196    if (prop_nan0 != prop_nan1)
2197       return false;
2198 
2199    if (nan_test->isVOP3()) {
2200       VOP3_instruction& vop3 = nan_test->vop3();
2201       if (vop3.neg[0] != vop3.neg[1] || vop3.abs[0] != vop3.abs[1] || vop3.opsel == 1 ||
2202           vop3.opsel == 2)
2203          return false;
2204    }
2205 
2206    int constant_operand = -1;
2207    for (unsigned i = 0; i < 2; i++) {
2208       if (cmp->operands[i].isTemp() &&
2209           original_temp_id(ctx, cmp->operands[i].getTemp()) == prop_nan0) {
2210          constant_operand = !i;
2211          break;
2212       }
2213    }
2214    if (constant_operand == -1)
2215       return false;
2216 
2217    uint64_t constant_value;
2218    if (!is_operand_constant(ctx, cmp->operands[constant_operand], bit_size, &constant_value))
2219       return false;
2220    if (is_constant_nan(constant_value, bit_size))
2221       return false;
2222 
2223    if (cmp->operands[0].isTemp())
2224       ctx.uses[cmp->operands[0].tempId()]++;
2225    if (cmp->operands[1].isTemp())
2226       ctx.uses[cmp->operands[1].tempId()]++;
2227    decrease_uses(ctx, nan_test);
2228    decrease_uses(ctx, cmp);
2229 
2230    aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
2231    Instruction* new_instr;
2232    if (cmp->isVOP3()) {
2233       VOP3_instruction* new_vop3 =
2234          create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
2235       VOP3_instruction& cmp_vop3 = cmp->vop3();
2236       memcpy(new_vop3->abs, cmp_vop3.abs, sizeof(new_vop3->abs));
2237       memcpy(new_vop3->neg, cmp_vop3.neg, sizeof(new_vop3->neg));
2238       new_vop3->clamp = cmp_vop3.clamp;
2239       new_vop3->omod = cmp_vop3.omod;
2240       new_vop3->opsel = cmp_vop3.opsel;
2241       new_instr = new_vop3;
2242    } else {
2243       new_instr = create_instruction<VOPC_instruction>(new_op, Format::VOPC, 2, 1);
2244       instr->definitions[0].setHint(vcc);
2245    }
2246    new_instr->operands[0] = cmp->operands[0];
2247    new_instr->operands[1] = cmp->operands[1];
2248    new_instr->definitions[0] = instr->definitions[0];
2249 
2250    ctx.info[instr->definitions[0].tempId()].label = 0;
2251    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2252 
2253    instr.reset(new_instr);
2254 
2255    return true;
2256 }
2257 
2258 /* s_andn2(exec, cmp(a, b)) -> get_inverse(cmp)(a, b) */
2259 bool
2260 combine_inverse_comparison(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2261 {
2262    if (!instr->operands[0].isFixed() || instr->operands[0].physReg() != exec)
2263       return false;
2264    if (ctx.uses[instr->definitions[1].tempId()])
2265       return false;
2266 
2267    Instruction* cmp = follow_operand(ctx, instr->operands[1]);
2268    if (!cmp)
2269       return false;
2270 
2271    aco_opcode new_opcode = get_inverse(cmp->opcode);
2272    if (new_opcode == aco_opcode::num_opcodes)
2273       return false;
2274 
2275    if (cmp->operands[0].isTemp())
2276       ctx.uses[cmp->operands[0].tempId()]++;
2277    if (cmp->operands[1].isTemp())
2278       ctx.uses[cmp->operands[1].tempId()]++;
2279    decrease_uses(ctx, cmp);
2280 
2281    /* This creates a new instruction instead of modifying the existing
2282     * comparison so that the comparison is done with the correct exec mask. */
2283    Instruction* new_instr;
2284    if (cmp->isVOP3()) {
2285       VOP3_instruction* new_vop3 =
2286          create_instruction<VOP3_instruction>(new_opcode, asVOP3(Format::VOPC), 2, 1);
2287       VOP3_instruction& cmp_vop3 = cmp->vop3();
2288       memcpy(new_vop3->abs, cmp_vop3.abs, sizeof(new_vop3->abs));
2289       memcpy(new_vop3->neg, cmp_vop3.neg, sizeof(new_vop3->neg));
2290       new_vop3->clamp = cmp_vop3.clamp;
2291       new_vop3->omod = cmp_vop3.omod;
2292       new_vop3->opsel = cmp_vop3.opsel;
2293       new_instr = new_vop3;
2294    } else if (cmp->isSDWA()) {
2295       SDWA_instruction* new_sdwa = create_instruction<SDWA_instruction>(
2296          new_opcode, (Format)((uint16_t)Format::SDWA | (uint16_t)Format::VOPC), 2, 1);
2297       SDWA_instruction& cmp_sdwa = cmp->sdwa();
2298       memcpy(new_sdwa->abs, cmp_sdwa.abs, sizeof(new_sdwa->abs));
2299       memcpy(new_sdwa->sel, cmp_sdwa.sel, sizeof(new_sdwa->sel));
2300       memcpy(new_sdwa->neg, cmp_sdwa.neg, sizeof(new_sdwa->neg));
2301       new_sdwa->dst_sel = cmp_sdwa.dst_sel;
2302       new_sdwa->clamp = cmp_sdwa.clamp;
2303       new_sdwa->omod = cmp_sdwa.omod;
2304       new_instr = new_sdwa;
2305    } else if (cmp->isDPP16()) {
2306       DPP16_instruction* new_dpp = create_instruction<DPP16_instruction>(
2307          new_opcode, (Format)((uint16_t)Format::DPP16 | (uint16_t)Format::VOPC), 2, 1);
2308       DPP16_instruction& cmp_dpp = cmp->dpp16();
2309       memcpy(new_dpp->abs, cmp_dpp.abs, sizeof(new_dpp->abs));
2310       memcpy(new_dpp->neg, cmp_dpp.neg, sizeof(new_dpp->neg));
2311       new_dpp->dpp_ctrl = cmp_dpp.dpp_ctrl;
2312       new_dpp->row_mask = cmp_dpp.row_mask;
2313       new_dpp->bank_mask = cmp_dpp.bank_mask;
2314       new_dpp->bound_ctrl = cmp_dpp.bound_ctrl;
2315       new_instr = new_dpp;
2316    } else if (cmp->isDPP8()) {
2317       DPP8_instruction* new_dpp = create_instruction<DPP8_instruction>(
2318          new_opcode, (Format)((uint16_t)Format::DPP8 | (uint16_t)Format::VOPC), 2, 1);
2319       DPP8_instruction& cmp_dpp = cmp->dpp8();
2320       memcpy(new_dpp->lane_sel, cmp_dpp.lane_sel, sizeof(new_dpp->lane_sel));
2321       new_instr = new_dpp;
2322    } else {
2323       new_instr = create_instruction<VOPC_instruction>(new_opcode, Format::VOPC, 2, 1);
2324       instr->definitions[0].setHint(vcc);
2325    }
2326    new_instr->operands[0] = cmp->operands[0];
2327    new_instr->operands[1] = cmp->operands[1];
2328    new_instr->definitions[0] = instr->definitions[0];
2329 
2330    ctx.info[instr->definitions[0].tempId()].label = 0;
2331    ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2332 
2333    instr.reset(new_instr);
2334 
2335    return true;
2336 }
2337 
2338 /* op1(op2(1, 2), 0) if swap = false
2339  * op1(0, op2(1, 2)) if swap = true */
2340 bool
2341 match_op3_for_vop3(opt_ctx& ctx, aco_opcode op1, aco_opcode op2, Instruction* op1_instr, bool swap,
2342                    const char* shuffle_str, Operand operands[3], bool neg[3], bool abs[3],
2343                    uint8_t* opsel, bool* op1_clamp, uint8_t* op1_omod, bool* inbetween_neg,
2344                    bool* inbetween_abs, bool* inbetween_opsel, bool* precise)
2345 {
2346    /* checks */
2347    if (op1_instr->opcode != op1)
2348       return false;
2349 
2350    Instruction* op2_instr = follow_operand(ctx, op1_instr->operands[swap]);
2351    if (!op2_instr || op2_instr->opcode != op2)
2352       return false;
2353    if (fixed_to_exec(op2_instr->operands[0]) || fixed_to_exec(op2_instr->operands[1]))
2354       return false;
2355 
2356    VOP3_instruction* op1_vop3 = op1_instr->isVOP3() ? &op1_instr->vop3() : NULL;
2357    VOP3_instruction* op2_vop3 = op2_instr->isVOP3() ? &op2_instr->vop3() : NULL;
2358 
2359    if (op1_instr->isSDWA() || op2_instr->isSDWA())
2360       return false;
2361    if (op1_instr->isDPP() || op2_instr->isDPP())
2362       return false;
2363 
2364    /* don't support inbetween clamp/omod */
2365    if (op2_vop3 && (op2_vop3->clamp || op2_vop3->omod))
2366       return false;
2367 
2368    /* get operands and modifiers and check inbetween modifiers */
2369    *op1_clamp = op1_vop3 ? op1_vop3->clamp : false;
2370    *op1_omod = op1_vop3 ? op1_vop3->omod : 0u;
2371 
2372    if (inbetween_neg)
2373       *inbetween_neg = op1_vop3 ? op1_vop3->neg[swap] : false;
2374    else if (op1_vop3 && op1_vop3->neg[swap])
2375       return false;
2376 
2377    if (inbetween_abs)
2378       *inbetween_abs = op1_vop3 ? op1_vop3->abs[swap] : false;
2379    else if (op1_vop3 && op1_vop3->abs[swap])
2380       return false;
2381 
2382    if (inbetween_opsel)
2383       *inbetween_opsel = op1_vop3 ? op1_vop3->opsel & (1 << (unsigned)swap) : false;
2384    else if (op1_vop3 && op1_vop3->opsel & (1 << (unsigned)swap))
2385       return false;
2386 
2387    *precise = op1_instr->definitions[0].isPrecise() || op2_instr->definitions[0].isPrecise();
2388 
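   /* The shuffle string maps sources onto the operands of the new VOP3: the
    * digit at position j selects what fills operand j ('0' = op1's remaining
    * operand, '1' and '2' = op2's two operands). */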
2389    int shuffle[3];
2390    shuffle[shuffle_str[0] - '0'] = 0;
2391    shuffle[shuffle_str[1] - '0'] = 1;
2392    shuffle[shuffle_str[2] - '0'] = 2;
2393 
2394    operands[shuffle[0]] = op1_instr->operands[!swap];
2395    neg[shuffle[0]] = op1_vop3 ? op1_vop3->neg[!swap] : false;
2396    abs[shuffle[0]] = op1_vop3 ? op1_vop3->abs[!swap] : false;
2397    if (op1_vop3 && (op1_vop3->opsel & (1 << (unsigned)!swap)))
2398       *opsel |= 1 << shuffle[0];
2399 
2400    for (unsigned i = 0; i < 2; i++) {
2401       operands[shuffle[i + 1]] = op2_instr->operands[i];
2402       neg[shuffle[i + 1]] = op2_vop3 ? op2_vop3->neg[i] : false;
2403       abs[shuffle[i + 1]] = op2_vop3 ? op2_vop3->abs[i] : false;
2404       if (op2_vop3 && op2_vop3->opsel & (1 << i))
2405          *opsel |= 1 << shuffle[i + 1];
2406    }
2407 
2408    /* check operands */
2409    if (!check_vop3_operands(ctx, 3, operands))
2410       return false;
2411 
2412    return true;
2413 }
2414 
2415 void
2416 create_vop3_for_op3(opt_ctx& ctx, aco_opcode opcode, aco_ptr<Instruction>& instr,
2417                     Operand operands[3], bool neg[3], bool abs[3], uint8_t opsel, bool clamp,
2418                     unsigned omod)
2419 {
2420    VOP3_instruction* new_instr = create_instruction<VOP3_instruction>(opcode, Format::VOP3, 3, 1);
2421    memcpy(new_instr->abs, abs, sizeof(bool[3]));
2422    memcpy(new_instr->neg, neg, sizeof(bool[3]));
2423    new_instr->clamp = clamp;
2424    new_instr->omod = omod;
2425    new_instr->opsel = opsel;
2426    new_instr->operands[0] = operands[0];
2427    new_instr->operands[1] = operands[1];
2428    new_instr->operands[2] = operands[2];
2429    new_instr->definitions[0] = instr->definitions[0];
2430    ctx.info[instr->definitions[0].tempId()].label = 0;
2431 
2432    instr.reset(new_instr);
2433 }
2434 
2435 bool
2436 combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode op2, aco_opcode new_op,
2437                       const char* shuffle, uint8_t ops)
2438 {
2439    for (unsigned swap = 0; swap < 2; swap++) {
2440       if (!((1 << swap) & ops))
2441          continue;
2442 
2443       Operand operands[3];
2444       bool neg[3], abs[3], clamp, precise;
2445       uint8_t opsel = 0, omod = 0;
2446       if (match_op3_for_vop3(ctx, instr->opcode, op2, instr.get(), swap, shuffle, operands, neg,
2447                              abs, &opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2448          ctx.uses[instr->operands[swap].tempId()]--;
2449          create_vop3_for_op3(ctx, new_op, instr, operands, neg, abs, opsel, clamp, omod);
2450          return true;
2451       }
2452    }
2453    return false;
2454 }
2455 
2456 /* creates v_lshl_add_u32, v_lshl_or_b32 or v_and_or_b32 */
2457 bool
2458 combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2459 {
2460    bool is_or = instr->opcode == aco_opcode::v_or_b32;
2461    aco_opcode new_op_lshl = is_or ? aco_opcode::v_lshl_or_b32 : aco_opcode::v_lshl_add_u32;
2462 
2463    if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32,
2464                                       "120", 1 | 2))
2465       return true;
2466    if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32,
2467                                       "120", 1 | 2))
2468       return true;
2469    if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, new_op_lshl, "120", 1 | 2))
2470       return true;
2471    if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, new_op_lshl, "210", 1 | 2))
2472       return true;
2473 
2474    if (instr->isSDWA() || instr->isDPP())
2475       return false;
2476 
2477    /* v_or_b32(p_extract(a, 0, 8/16, 0), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2478     * v_or_b32(p_insert(a, 0, 8/16), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2479     * v_or_b32(p_insert(a, 24/16, 8/16), b) -> v_lshl_or_b32(a, 24/16, b)
2480     * v_add_u32(p_insert(a, 24/16, 8/16), b) -> v_lshl_add_b32(a, 24/16, b)
2481     */
2482    for (unsigned i = 0; i < 2; i++) {
2483       Instruction* extins = follow_operand(ctx, instr->operands[i]);
2484       if (!extins)
2485          continue;
2486 
2487       aco_opcode op;
2488       Operand operands[3];
2489 
2490       if (extins->opcode == aco_opcode::p_insert &&
2491           (extins->operands[1].constantValue() + 1) * extins->operands[2].constantValue() == 32) {
2492          op = new_op_lshl;
2493          operands[1] =
2494             Operand::c32(extins->operands[1].constantValue() * extins->operands[2].constantValue());
2495       } else if (is_or &&
2496                  (extins->opcode == aco_opcode::p_insert ||
2497                   (extins->opcode == aco_opcode::p_extract &&
2498                    extins->operands[3].constantEquals(0))) &&
2499                  extins->operands[1].constantEquals(0)) {
2500          op = aco_opcode::v_and_or_b32;
2501          operands[1] = Operand::c32(extins->operands[2].constantEquals(8) ? 0xffu : 0xffffu);
2502       } else {
2503          continue;
2504       }
2505 
2506       operands[0] = extins->operands[0];
2507       operands[2] = instr->operands[!i];
2508 
2509       if (!check_vop3_operands(ctx, 3, operands))
2510          continue;
2511 
2512       bool neg[3] = {}, abs[3] = {};
2513       uint8_t opsel = 0, omod = 0;
2514       bool clamp = false;
2515       if (instr->isVOP3())
2516          clamp = instr->vop3().clamp;
2517 
2518       ctx.uses[instr->operands[i].tempId()]--;
2519       create_vop3_for_op3(ctx, op, instr, operands, neg, abs, opsel, clamp, omod);
2520       return true;
2521    }
2522 
2523    return false;
2524 }
2525 
2526 bool
2527 combine_minmax(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode opposite, aco_opcode minmax3)
2528 {
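   /* min(min(a, b), c) -> min3(a, b, c) and max(max(a, b), c) -> max3(a, b, c) */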
2529    /* TODO: this can handle SDWA min/max instructions by using opsel */
2530    if (combine_three_valu_op(ctx, instr, instr->opcode, minmax3, "012", 1 | 2))
2531       return true;
2532 
2533    /* min(-max(a, b), c) -> min3(c, -a, -b)
2534     * max(-min(a, b), c) -> max3(c, -a, -b) */
2535    for (unsigned swap = 0; swap < 2; swap++) {
2536       Operand operands[3];
2537       bool neg[3], abs[3], clamp, precise;
2538       uint8_t opsel = 0, omod = 0;
2539       bool inbetween_neg;
2540       if (match_op3_for_vop3(ctx, instr->opcode, opposite, instr.get(), swap, "012", operands, neg,
2541                              abs, &opsel, &clamp, &omod, &inbetween_neg, NULL, NULL, &precise) &&
2542           inbetween_neg) {
2543          ctx.uses[instr->operands[swap].tempId()]--;
2544          neg[1] = !neg[1];
2545          neg[2] = !neg[2];
2546          create_vop3_for_op3(ctx, minmax3, instr, operands, neg, abs, opsel, clamp, omod);
2547          return true;
2548       }
2549    }
2550    return false;
2551 }
2552 
2553 /* s_not_b32(s_and_b32(a, b)) -> s_nand_b32(a, b)
2554  * s_not_b32(s_or_b32(a, b)) -> s_nor_b32(a, b)
2555  * s_not_b32(s_xor_b32(a, b)) -> s_xnor_b32(a, b)
2556  * s_not_b64(s_and_b64(a, b)) -> s_nand_b64(a, b)
2557  * s_not_b64(s_or_b64(a, b)) -> s_nor_b64(a, b)
2558  * s_not_b64(s_xor_b64(a, b)) -> s_xnor_b64(a, b) */
2559 bool
2560 combine_salu_not_bitwise(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2561 {
2562    /* checks */
2563    if (!instr->operands[0].isTemp())
2564       return false;
2565    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2566       return false;
2567 
2568    Instruction* op2_instr = follow_operand(ctx, instr->operands[0]);
2569    if (!op2_instr)
2570       return false;
2571    switch (op2_instr->opcode) {
2572    case aco_opcode::s_and_b32:
2573    case aco_opcode::s_or_b32:
2574    case aco_opcode::s_xor_b32:
2575    case aco_opcode::s_and_b64:
2576    case aco_opcode::s_or_b64:
2577    case aco_opcode::s_xor_b64: break;
2578    default: return false;
2579    }
2580 
2581    /* create instruction */
2582    std::swap(instr->definitions[0], op2_instr->definitions[0]);
2583    std::swap(instr->definitions[1], op2_instr->definitions[1]);
2584    ctx.uses[instr->operands[0].tempId()]--;
2585    ctx.info[op2_instr->definitions[0].tempId()].label = 0;
2586 
2587    switch (op2_instr->opcode) {
2588    case aco_opcode::s_and_b32: op2_instr->opcode = aco_opcode::s_nand_b32; break;
2589    case aco_opcode::s_or_b32: op2_instr->opcode = aco_opcode::s_nor_b32; break;
2590    case aco_opcode::s_xor_b32: op2_instr->opcode = aco_opcode::s_xnor_b32; break;
2591    case aco_opcode::s_and_b64: op2_instr->opcode = aco_opcode::s_nand_b64; break;
2592    case aco_opcode::s_or_b64: op2_instr->opcode = aco_opcode::s_nor_b64; break;
2593    case aco_opcode::s_xor_b64: op2_instr->opcode = aco_opcode::s_xnor_b64; break;
2594    default: break;
2595    }
2596 
2597    return true;
2598 }
2599 
2600 /* s_and_b32(a, s_not_b32(b)) -> s_andn2_b32(a, b)
2601  * s_or_b32(a, s_not_b32(b)) -> s_orn2_b32(a, b)
2602  * s_and_b64(a, s_not_b64(b)) -> s_andn2_b64(a, b)
2603  * s_or_b64(a, s_not_b64(b)) -> s_orn2_b64(a, b) */
2604 bool
2605 combine_salu_n2(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2606 {
2607    if (instr->definitions[0].isTemp() && ctx.info[instr->definitions[0].tempId()].is_uniform_bool())
2608       return false;
2609 
2610    for (unsigned i = 0; i < 2; i++) {
2611       Instruction* op2_instr = follow_operand(ctx, instr->operands[i]);
2612       if (!op2_instr || (op2_instr->opcode != aco_opcode::s_not_b32 &&
2613                          op2_instr->opcode != aco_opcode::s_not_b64))
2614          continue;
2615       if (ctx.uses[op2_instr->definitions[1].tempId()] || fixed_to_exec(op2_instr->operands[0]))
2616          continue;
2617 
2618       if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2619           instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2620          continue;
2621 
2622       ctx.uses[instr->operands[i].tempId()]--;
2623       instr->operands[0] = instr->operands[!i];
2624       instr->operands[1] = op2_instr->operands[0];
2625       ctx.info[instr->definitions[0].tempId()].label = 0;
2626 
2627       switch (instr->opcode) {
2628       case aco_opcode::s_and_b32: instr->opcode = aco_opcode::s_andn2_b32; break;
2629       case aco_opcode::s_or_b32: instr->opcode = aco_opcode::s_orn2_b32; break;
2630       case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_andn2_b64; break;
2631       case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_orn2_b64; break;
2632       default: break;
2633       }
2634 
2635       return true;
2636    }
2637    return false;
2638 }
2639 
2640 /* s_add_{i32,u32}(a, s_lshl_b32(b, <n>)) -> s_lshl<n>_add_u32(b, a) */
2641 bool
2642 combine_salu_lshl_add(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2643 {
2644    if (instr->opcode == aco_opcode::s_add_i32 && ctx.uses[instr->definitions[1].tempId()])
2645       return false;
2646 
2647    for (unsigned i = 0; i < 2; i++) {
2648       Instruction* op2_instr = follow_operand(ctx, instr->operands[i], true);
2649       if (!op2_instr || op2_instr->opcode != aco_opcode::s_lshl_b32 ||
2650           ctx.uses[op2_instr->definitions[1].tempId()])
2651          continue;
2652       if (!op2_instr->operands[1].isConstant() || fixed_to_exec(op2_instr->operands[0]))
2653          continue;
2654 
2655       uint32_t shift = op2_instr->operands[1].constantValue();
2656       if (shift < 1 || shift > 4)
2657          continue;
2658 
2659       if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2660           instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2661          continue;
2662 
2663       ctx.uses[instr->operands[i].tempId()]--;
2664       instr->operands[1] = instr->operands[!i];
2665       instr->operands[0] = op2_instr->operands[0];
2666       ctx.info[instr->definitions[0].tempId()].label = 0;
2667 
2668       instr->opcode = std::array<aco_opcode, 4>{
2669          aco_opcode::s_lshl1_add_u32, aco_opcode::s_lshl2_add_u32, aco_opcode::s_lshl3_add_u32,
2670          aco_opcode::s_lshl4_add_u32}[shift - 1];
2671 
2672       return true;
2673    }
2674    return false;
2675 }
2676 
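/* Fold a b2i (bool -> 0/1) operand into a carry-in:
 * op(a, b2i(cond)) -> new_op(0, a, cond), with cond consumed as the carry-in
 * (callers presumably pass a carry-consuming opcode such as v_addc_co_u32). */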
2677 bool
2678 combine_add_sub_b2i(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode new_op, uint8_t ops)
2679 {
2680    if (instr->usesModifiers())
2681       return false;
2682 
2683    for (unsigned i = 0; i < 2; i++) {
2684       if (!((1 << i) & ops))
2685          continue;
2686       if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2i() &&
2687           ctx.uses[instr->operands[i].tempId()] == 1) {
2688 
2689          aco_ptr<Instruction> new_instr;
2690          if (instr->operands[!i].isTemp() &&
2691              instr->operands[!i].getTemp().type() == RegType::vgpr) {
2692             new_instr.reset(create_instruction<VOP2_instruction>(new_op, Format::VOP2, 3, 2));
2693          } else if (ctx.program->chip_class >= GFX10 ||
2694                     (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
2695             new_instr.reset(
2696                create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOP2), 3, 2));
2697          } else {
2698             return false;
2699          }
2700          ctx.uses[instr->operands[i].tempId()]--;
2701          new_instr->definitions[0] = instr->definitions[0];
2702          if (instr->definitions.size() == 2) {
2703             new_instr->definitions[1] = instr->definitions[1];
2704          } else {
2705             new_instr->definitions[1] =
2706                Definition(ctx.program->allocateTmp(ctx.program->lane_mask));
2707             /* Make sure the uses vector is large enough and the number of
2708              * uses is properly initialized to 0.
2709              */
2710             ctx.uses.push_back(0);
2711          }
2712          new_instr->definitions[1].setHint(vcc);
2713          new_instr->operands[0] = Operand::zero();
2714          new_instr->operands[1] = instr->operands[!i];
2715          new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
2716          instr = std::move(new_instr);
2717          ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
2718          return true;
2719       }
2720    }
2721 
2722    return false;
2723 }
2724 
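/* v_add_u32(a, v_bcnt_u32_b32(b, 0)) -> v_bcnt_u32_b32(b, a)
 * v_bcnt_u32_b32 adds its second source to the popcount, so a zero accumulator
 * can be replaced by the other addend. */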
2725 bool
2726 combine_add_bcnt(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2727 {
2728    if (instr->usesModifiers())
2729       return false;
2730 
2731    for (unsigned i = 0; i < 2; i++) {
2732       Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
2733       if (op_instr && op_instr->opcode == aco_opcode::v_bcnt_u32_b32 &&
2734           !op_instr->usesModifiers() && op_instr->operands[0].isTemp() &&
2735           op_instr->operands[0].getTemp().type() == RegType::vgpr &&
2736           op_instr->operands[1].constantEquals(0)) {
2737          aco_ptr<Instruction> new_instr{
2738             create_instruction<VOP3_instruction>(aco_opcode::v_bcnt_u32_b32, Format::VOP3, 2, 1)};
2739          ctx.uses[instr->operands[i].tempId()]--;
2740          new_instr->operands[0] = op_instr->operands[0];
2741          new_instr->operands[1] = instr->operands[!i];
2742          new_instr->definitions[0] = instr->definitions[0];
2743          instr = std::move(new_instr);
2744          ctx.info[instr->definitions[0].tempId()].label = 0;
2745 
2746          return true;
2747       }
2748    }
2749 
2750    return false;
2751 }
2752 
2753 bool
2754 get_minmax_info(aco_opcode op, aco_opcode* min, aco_opcode* max, aco_opcode* min3, aco_opcode* max3,
2755                 aco_opcode* med3, bool* some_gfx9_only)
2756 {
2757    switch (op) {
2758 #define MINMAX(type, gfx9)                                                                         \
2759    case aco_opcode::v_min_##type:                                                                  \
2760    case aco_opcode::v_max_##type:                                                                  \
2761    case aco_opcode::v_med3_##type:                                                                 \
2762       *min = aco_opcode::v_min_##type;                                                             \
2763       *max = aco_opcode::v_max_##type;                                                             \
2764       *med3 = aco_opcode::v_med3_##type;                                                           \
2765       *min3 = aco_opcode::v_min3_##type;                                                           \
2766       *max3 = aco_opcode::v_max3_##type;                                                           \
2767       *some_gfx9_only = gfx9;                                                                      \
2768       return true;
2769       MINMAX(f32, false)
2770       MINMAX(u32, false)
2771       MINMAX(i32, false)
2772       MINMAX(f16, true)
2773       MINMAX(u16, true)
2774       MINMAX(i16, true)
2775 #undef MINMAX
2776    default: return false;
2777    }
2778 }
2779 
2780 /* when ub > lb:
2781  * v_min_{f,u,i}{16,32}(v_max_{f,u,i}{16,32}(a, lb), ub) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
2782  * v_max_{f,u,i}{16,32}(v_min_{f,u,i}{16,32}(a, ub), lb) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
2783  */
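/* e.g. v_min_f32(v_max_f32(x, 0.0), 1.0) -> v_med3_f32(x, 0.0, 1.0) */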
2784 bool
2785 combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode min, aco_opcode max,
2786               aco_opcode med)
2787 {
2788    /* TODO: GLSL's clamp(x, minVal, maxVal) and SPIR-V's
2789     * FClamp(x, minVal, maxVal)/NClamp(x, minVal, maxVal) are undefined if
2790     * minVal > maxVal, which means we can always select it to a v_med3_f32 */
2791    aco_opcode other_op;
2792    if (instr->opcode == min)
2793       other_op = max;
2794    else if (instr->opcode == max)
2795       other_op = min;
2796    else
2797       return false;
2798 
2799    for (unsigned swap = 0; swap < 2; swap++) {
2800       Operand operands[3];
2801       bool neg[3], abs[3], clamp, precise;
2802       uint8_t opsel = 0, omod = 0;
2803       if (match_op3_for_vop3(ctx, instr->opcode, other_op, instr.get(), swap, "012", operands, neg,
2804                              abs, &opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2805          /* max(min(src, upper), lower) returns upper if src is NaN, but
2806           * med3(src, lower, upper) returns lower.
2807           */
2808          if (precise && instr->opcode != min)
2809             continue;
2810 
2811          int const0_idx = -1, const1_idx = -1;
2812          uint32_t const0 = 0, const1 = 0;
2813          for (int i = 0; i < 3; i++) {
2814             uint32_t val;
2815             if (operands[i].isConstant()) {
2816                val = operands[i].constantValue();
2817             } else if (operands[i].isTemp() &&
2818                        ctx.info[operands[i].tempId()].is_constant_or_literal(32)) {
2819                val = ctx.info[operands[i].tempId()].val;
2820             } else {
2821                continue;
2822             }
2823             if (const0_idx >= 0) {
2824                const1_idx = i;
2825                const1 = val;
2826             } else {
2827                const0_idx = i;
2828                const0 = val;
2829             }
2830          }
2831          if (const0_idx < 0 || const1_idx < 0)
2832             continue;
2833 
2834          if (opsel & (1 << const0_idx))
2835             const0 >>= 16;
2836          if (opsel & (1 << const1_idx))
2837             const1 >>= 16;
2838 
2839          int lower_idx = const0_idx;
2840          switch (min) {
2841          case aco_opcode::v_min_f32:
2842          case aco_opcode::v_min_f16: {
2843             float const0_f, const1_f;
2844             if (min == aco_opcode::v_min_f32) {
2845                memcpy(&const0_f, &const0, 4);
2846                memcpy(&const1_f, &const1, 4);
2847             } else {
2848                const0_f = _mesa_half_to_float(const0);
2849                const1_f = _mesa_half_to_float(const1);
2850             }
2851             if (abs[const0_idx])
2852                const0_f = fabsf(const0_f);
2853             if (abs[const1_idx])
2854                const1_f = fabsf(const1_f);
2855             if (neg[const0_idx])
2856                const0_f = -const0_f;
2857             if (neg[const1_idx])
2858                const1_f = -const1_f;
2859             lower_idx = const0_f < const1_f ? const0_idx : const1_idx;
2860             break;
2861          }
2862          case aco_opcode::v_min_u32: {
2863             lower_idx = const0 < const1 ? const0_idx : const1_idx;
2864             break;
2865          }
2866          case aco_opcode::v_min_u16: {
2867             lower_idx = (uint16_t)const0 < (uint16_t)const1 ? const0_idx : const1_idx;
2868             break;
2869          }
2870          case aco_opcode::v_min_i32: {
2871             int32_t const0_i =
2872                const0 & 0x80000000u ? -2147483648 + (int32_t)(const0 & 0x7fffffffu) : const0;
2873             int32_t const1_i =
2874                const1 & 0x80000000u ? -2147483648 + (int32_t)(const1 & 0x7fffffffu) : const1;
2875             lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
2876             break;
2877          }
2878          case aco_opcode::v_min_i16: {
2879             int16_t const0_i = const0 & 0x8000u ? -32768 + (int16_t)(const0 & 0x7fffu) : const0;
2880             int16_t const1_i = const1 & 0x8000u ? -32768 + (int16_t)(const1 & 0x7fffu) : const1;
2881             lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
2882             break;
2883          }
2884          default: break;
2885          }
2886          int upper_idx = lower_idx == const0_idx ? const1_idx : const0_idx;
2887 
2888          if (instr->opcode == min) {
2889             if (upper_idx != 0 || lower_idx == 0)
2890                return false;
2891          } else {
2892             if (upper_idx == 0 || lower_idx != 0)
2893                return false;
2894          }
2895 
2896          ctx.uses[instr->operands[swap].tempId()]--;
2897          create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod);
2898 
2899          return true;
2900       }
2901    }
2902 
2903    return false;
2904 }
2905 
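/* Replace VALU operands that are copies (or extracts) of SGPRs with the SGPR itself,
 * within the constant-bus limit: one SGPR (two on GFX10+, except for 64-bit shifts),
 * minus one if a literal is already used. */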
2906 void
2907 apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2908 {
2909    bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
2910                      instr->opcode == aco_opcode::v_lshrrev_b64 ||
2911                      instr->opcode == aco_opcode::v_ashrrev_i64;
2912 
2913    /* find candidates and create the set of sgprs already read */
2914    unsigned sgpr_ids[2] = {0, 0};
2915    uint32_t operand_mask = 0;
2916    bool has_literal = false;
2917    for (unsigned i = 0; i < instr->operands.size(); i++) {
2918       if (instr->operands[i].isLiteral())
2919          has_literal = true;
2920       if (!instr->operands[i].isTemp())
2921          continue;
2922       if (instr->operands[i].getTemp().type() == RegType::sgpr) {
2923          if (instr->operands[i].tempId() != sgpr_ids[0])
2924             sgpr_ids[!!sgpr_ids[0]] = instr->operands[i].tempId();
2925       }
2926       ssa_info& info = ctx.info[instr->operands[i].tempId()];
2927       if (is_copy_label(ctx, instr, info) && info.temp.type() == RegType::sgpr)
2928          operand_mask |= 1u << i;
2929       if (info.is_extract() && info.instr->operands[0].getTemp().type() == RegType::sgpr)
2930          operand_mask |= 1u << i;
2931    }
2932    unsigned max_sgprs = 1;
2933    if (ctx.program->chip_class >= GFX10 && !is_shift64)
2934       max_sgprs = 2;
2935    if (has_literal)
2936       max_sgprs--;
2937 
2938    unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
2939 
2940    /* keep on applying sgprs until there is nothing left to be done */
2941    while (operand_mask) {
2942       uint32_t sgpr_idx = 0;
2943       uint32_t sgpr_info_id = 0;
2944       uint32_t mask = operand_mask;
2945       /* choose a sgpr */
2946       while (mask) {
2947          unsigned i = u_bit_scan(&mask);
2948          uint16_t uses = ctx.uses[instr->operands[i].tempId()];
2949          if (sgpr_info_id == 0 || uses < ctx.uses[sgpr_info_id]) {
2950             sgpr_idx = i;
2951             sgpr_info_id = instr->operands[i].tempId();
2952          }
2953       }
2954       operand_mask &= ~(1u << sgpr_idx);
2955 
2956       ssa_info& info = ctx.info[sgpr_info_id];
2957 
2958       /* Applying two sgprs requires making it VOP3, so don't do it unless it's
2959        * definitively beneficial.
2960        * TODO: this is too conservative because later the use count could be reduced to 1 */
2961       if (!info.is_extract() && num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3() &&
2962           !instr->isSDWA() && instr->format != Format::VOP3P)
2963          break;
2964 
2965       Temp sgpr = info.is_extract() ? info.instr->operands[0].getTemp() : info.temp;
2966       bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1];
2967       if (new_sgpr && num_sgprs >= max_sgprs)
2968          continue;
2969 
2970       if (sgpr_idx == 0)
2971          instr->format = withoutDPP(instr->format);
2972 
2973       if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA() || instr->isVOP3P() ||
2974           info.is_extract()) {
2975          /* can_apply_extract() checks SGPR encoding restrictions */
2976          if (info.is_extract() && can_apply_extract(ctx, instr, sgpr_idx, info))
2977             apply_extract(ctx, instr, sgpr_idx, info);
2978          else if (info.is_extract())
2979             continue;
2980          instr->operands[sgpr_idx] = Operand(sgpr);
2981       } else if (can_swap_operands(instr, &instr->opcode)) {
2982          instr->operands[sgpr_idx] = instr->operands[0];
2983          instr->operands[0] = Operand(sgpr);
2984          /* swap bits 0 and 1 of operand_mask using a 4-entry LUT (4 bits per entry) */
2985          uint32_t swapped = (0x3120 >> ((operand_mask & 0x3) * 4)) & 0xf;
2986          operand_mask = (operand_mask & ~0x3) | swapped;
2987       } else if (can_use_VOP3(ctx, instr) && !info.is_extract()) {
2988          to_VOP3(ctx, instr);
2989          instr->operands[sgpr_idx] = Operand(sgpr);
2990       } else {
2991          continue;
2992       }
2993 
2994       if (new_sgpr)
2995          sgpr_ids[num_sgprs++] = sgpr.id();
2996       ctx.uses[sgpr_info_id]--;
2997       ctx.uses[sgpr.id()]++;
2998 
2999       /* TODO: handle when it's a VGPR */
3000       if ((ctx.info[sgpr.id()].label & (label_extract | label_temp)) &&
3001           ctx.info[sgpr.id()].temp.type() == RegType::sgpr)
3002          operand_mask |= 1u << sgpr_idx;
3003    }
3004 }
3005 
3006 template <typename T>
3007 bool
3008 apply_omod_clamp_helper(opt_ctx& ctx, T* instr, ssa_info& def_info)
3009 {
3010    if (!def_info.is_clamp() && (instr->clamp || instr->omod))
3011       return false;
3012 
3013    if (def_info.is_omod2())
3014       instr->omod = 1;
3015    else if (def_info.is_omod4())
3016       instr->omod = 2;
3017    else if (def_info.is_omod5())
3018       instr->omod = 3;
3019    else if (def_info.is_clamp())
3020       instr->clamp = true;
3021 
3022    return true;
3023 }
3024 
3025 /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
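/* e.g. if the single user multiplies this result by 2.0 (label_omod2), drop that
 * multiplication and set omod to *2 on this instruction instead; likewise for *4,
 * *0.5 and clamp. */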
3026 bool
3027 apply_omod_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3028 {
3029    if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1 ||
3030        !instr_info.can_use_output_modifiers[(int)instr->opcode])
3031       return false;
3032 
3033    bool can_vop3 = can_use_VOP3(ctx, instr);
3034    bool is_mad_mix =
3035       instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16;
3036    if (!instr->isSDWA() && !is_mad_mix && !can_vop3)
3037       return false;
3038 
3039    /* omod flushes -0 to +0 and has no effect if denormals are enabled. SDWA omod is GFX9+. */
3040    bool can_use_omod = (can_vop3 || ctx.program->chip_class >= GFX9) && !instr->isVOP3P();
3041    if (instr->definitions[0].bytes() == 4)
3042       can_use_omod =
3043          can_use_omod && ctx.fp_mode.denorm32 == 0 && !ctx.fp_mode.preserve_signed_zero_inf_nan32;
3044    else
3045       can_use_omod = can_use_omod && ctx.fp_mode.denorm16_64 == 0 &&
3046                      !ctx.fp_mode.preserve_signed_zero_inf_nan16_64;
3047 
3048    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3049 
3050    uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
3051    if (!def_info.is_clamp() && !(can_use_omod && (def_info.label & omod_labels)))
3052       return false;
3053    /* if the omod/clamp instruction is dead, then the single user of this
3054     * instruction is a different instruction */
3055    if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3056       return false;
3057 
3058    if (def_info.instr->definitions[0].bytes() != instr->definitions[0].bytes())
3059       return false;
3060 
3061    /* MADs/FMAs are created later, so we don't have to update the original add */
3062    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3063 
3064    if (instr->isSDWA()) {
3065       if (!apply_omod_clamp_helper(ctx, &instr->sdwa(), def_info))
3066          return false;
3067    } else if (instr->isVOP3P()) {
3068       assert(def_info.is_clamp());
3069       instr->vop3p().clamp = true;
3070    } else {
3071       to_VOP3(ctx, instr);
3072       if (!apply_omod_clamp_helper(ctx, &instr->vop3(), def_info))
3073          return false;
3074    }
3075 
3076    instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3077    ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_insert | label_f2f16;
3078    ctx.uses[def_info.instr->definitions[0].tempId()]--;
3079 
3080    return true;
3081 }
3082 
3083 /* Combine a p_insert (or p_extract, in some cases) instruction with instr.
3084  * p_insert(instr(...)) -> instr_insert().
3085  */
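/* e.g. an insert into the high 16 bits of the destination becomes opsel bit 3 on a
 * suitable VOP3, or an SDWA dst_sel otherwise. */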
3086 bool
3087 apply_insert(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3088 {
3089    if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1)
3090       return false;
3091 
3092    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3093    if (!def_info.is_insert())
3094       return false;
3095    /* if the insert instruction is dead, then the single user of this
3096     * instruction is a different instruction */
3097    if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3098       return false;
3099 
3100    /* MADs/FMAs are created later, so we don't have to update the original add */
3101    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3102 
3103    SubdwordSel sel = parse_insert(def_info.instr);
3104    assert(sel);
3105 
3106    if (instr->isVOP3() && sel.size() == 2 && !sel.sign_extend() &&
3107        can_use_opsel(ctx.program->chip_class, instr->opcode, -1)) {
3108       if (instr->vop3().opsel & (1 << 3))
3109          return false;
3110       if (sel.offset())
3111          instr->vop3().opsel |= 1 << 3;
3112    } else {
3113       if (!can_use_SDWA(ctx.program->chip_class, instr, true))
3114          return false;
3115 
3116       to_SDWA(ctx, instr);
3117       if (instr->sdwa().dst_sel.size() != 4)
3118          return false;
3119       static_cast<SDWA_instruction*>(instr.get())->dst_sel = sel;
3120    }
3121 
3122    instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3123    ctx.info[instr->definitions[0].tempId()].label = 0;
3124    ctx.uses[def_info.instr->definitions[0].tempId()]--;
3125 
3126    return true;
3127 }
3128 
3129 /* Remove superfluous extract after ds_read like so:
3130  * p_extract(ds_read_uN(), 0, N, 0) -> ds_read_uN()
3131  */
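/* The load is also shrunk when fewer bits are extracted,
 * e.g. p_extract(ds_read_u16(), 0, 8, 0) -> ds_read_u8(). */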
3132 bool
3133 apply_ds_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract)
3134 {
3135    /* Check if p_extract has a usedef operand and is the only user. */
3136    if (!ctx.info[extract->operands[0].tempId()].is_usedef() ||
3137        ctx.uses[extract->operands[0].tempId()] > 1)
3138       return false;
3139 
3140    /* Check if the usedef is a DS instruction. */
3141    Instruction* ds = ctx.info[extract->operands[0].tempId()].instr;
3142    if (ds->format != Format::DS)
3143       return false;
3144 
3145    unsigned extract_idx = extract->operands[1].constantValue();
3146    unsigned bits_extracted = extract->operands[2].constantValue();
3147    unsigned sign_ext = extract->operands[3].constantValue();
3148    unsigned dst_bitsize = extract->definitions[0].bytes() * 8u;
3149 
3150    /* TODO: These are doable, but probably don't occur too often. */
3151    if (extract_idx || sign_ext || dst_bitsize != 32)
3152       return false;
3153 
3154    unsigned bits_loaded = 0;
3155    if (ds->opcode == aco_opcode::ds_read_u8 || ds->opcode == aco_opcode::ds_read_u8_d16)
3156       bits_loaded = 8;
3157    else if (ds->opcode == aco_opcode::ds_read_u16 || ds->opcode == aco_opcode::ds_read_u16_d16)
3158       bits_loaded = 16;
3159    else
3160       return false;
3161 
3162    /* Shrink the DS load if the extracted bit size is smaller. */
3163    bits_loaded = MIN2(bits_loaded, bits_extracted);
3164 
3165    /* Change the DS opcode so it writes the full register. */
3166    if (bits_loaded == 8)
3167       ds->opcode = aco_opcode::ds_read_u8;
3168    else if (bits_loaded == 16)
3169       ds->opcode = aco_opcode::ds_read_u16;
3170    else
3171       unreachable("Forgot to add DS opcode above.");
3172 
3173    /* The DS now produces the exact same thing as the extract, remove the extract. */
3174    std::swap(ds->definitions[0], extract->definitions[0]);
3175    ctx.uses[extract->definitions[0].tempId()] = 0;
3176    ctx.info[ds->definitions[0].tempId()].label = 0;
3177    return true;
3178 }
3179 
3180 /* v_and(a, v_subbrev_co(0, 0, vcc)) -> v_cndmask(0, a, vcc) */
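/* v_subbrev_co_u32(0, 0, vcc) computes 0 - 0 - vcc, i.e. 0 or 0xffffffff per lane,
 * so the AND yields either a or 0, which is exactly v_cndmask_b32(0, a, vcc). */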
3181 bool
3182 combine_and_subbrev(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3183 {
3184    if (instr->usesModifiers())
3185       return false;
3186 
3187    for (unsigned i = 0; i < 2; i++) {
3188       Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3189       if (op_instr && op_instr->opcode == aco_opcode::v_subbrev_co_u32 &&
3190           op_instr->operands[0].constantEquals(0) && op_instr->operands[1].constantEquals(0) &&
3191           !op_instr->usesModifiers()) {
3192 
3193          aco_ptr<Instruction> new_instr;
3194          if (instr->operands[!i].isTemp() &&
3195              instr->operands[!i].getTemp().type() == RegType::vgpr) {
3196             new_instr.reset(
3197                create_instruction<VOP2_instruction>(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1));
3198          } else if (ctx.program->chip_class >= GFX10 ||
3199                     (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
3200             new_instr.reset(create_instruction<VOP3_instruction>(aco_opcode::v_cndmask_b32,
3201                                                                  asVOP3(Format::VOP2), 3, 1));
3202          } else {
3203             return false;
3204          }
3205 
3206          ctx.uses[instr->operands[i].tempId()]--;
3207          if (ctx.uses[instr->operands[i].tempId()])
3208             ctx.uses[op_instr->operands[2].tempId()]++;
3209 
3210          new_instr->operands[0] = Operand::zero();
3211          new_instr->operands[1] = instr->operands[!i];
3212          new_instr->operands[2] = Operand(op_instr->operands[2]);
3213          new_instr->definitions[0] = instr->definitions[0];
3214          instr = std::move(new_instr);
3215          ctx.info[instr->definitions[0].tempId()].label = 0;
3216          return true;
3217       }
3218    }
3219 
3220    return false;
3221 }
3222 
3223 /* v_add_co(c, s_lshl(a, b)) -> v_mad_u32_u24(a, 1<<b, c)
3224  * v_add_co(c, v_lshlrev(a, b)) -> v_mad_u32_u24(b, 1<<a, c)
3225  * v_sub(c, s_lshl(a, b)) -> v_mad_i32_i24(a, -(1<<b), c)
3226  * v_sub(c, v_lshlrev(a, b)) -> v_mad_i32_i24(b, -(1<<a), c)
3227  */
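/* The shift is turned into a multiplier (1 << n), which must fit the 24-bit unsigned
 * (or signed, for subtractions) operand range of v_mad_u32_u24/v_mad_i32_i24. */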
3228 bool
3229 combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool is_sub)
3230 {
3231    if (instr->usesModifiers())
3232       return false;
3233 
3234    /* Subtractions: start at operand 1 to avoid a mix-up such as
3235     * turning v_sub(v_lshlrev(a, b), c) into v_mad_i32_i24(b, -(1<<a), c)
3236     */
3237    unsigned start_op_idx = is_sub ? 1 : 0;
3238 
3239    /* Don't allow 24-bit operands on subtraction because
3240     * v_mad_i32_i24 applies a sign extension.
3241     */
3242    bool allow_24bit = !is_sub;
3243 
3244    for (unsigned i = start_op_idx; i < 2; i++) {
3245       Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
3246       if (!op_instr)
3247          continue;
3248 
3249       if (op_instr->opcode != aco_opcode::s_lshl_b32 &&
3250           op_instr->opcode != aco_opcode::v_lshlrev_b32)
3251          continue;
3252 
3253       int shift_op_idx = op_instr->opcode == aco_opcode::s_lshl_b32 ? 1 : 0;
3254 
3255       if (op_instr->operands[shift_op_idx].isConstant() &&
3256           ((allow_24bit && op_instr->operands[!shift_op_idx].is24bit()) ||
3257            op_instr->operands[!shift_op_idx].is16bit())) {
3258          uint32_t multiplier = 1 << (op_instr->operands[shift_op_idx].constantValue() % 32u);
3259          if (is_sub)
3260             multiplier = -multiplier;
3261          if (is_sub ? (multiplier < 0xff800000) : (multiplier > 0xffffff))
3262             continue;
3263 
3264          Operand ops[3] = {
3265             op_instr->operands[!shift_op_idx],
3266             Operand::c32(multiplier),
3267             instr->operands[!i],
3268          };
3269          if (!check_vop3_operands(ctx, 3, ops))
3270             return false;
3271 
3272          ctx.uses[instr->operands[i].tempId()]--;
3273 
3274          aco_opcode mad_op = is_sub ? aco_opcode::v_mad_i32_i24 : aco_opcode::v_mad_u32_u24;
3275          aco_ptr<VOP3_instruction> new_instr{
3276             create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
3277          for (unsigned op_idx = 0; op_idx < 3; ++op_idx)
3278             new_instr->operands[op_idx] = ops[op_idx];
3279          new_instr->definitions[0] = instr->definitions[0];
3280          instr = std::move(new_instr);
3281          ctx.info[instr->definitions[0].tempId()].label = 0;
3282          return true;
3283       }
3284    }
3285 
3286    return false;
3287 }
3288 
3289 void
3290 propagate_swizzles(VOP3P_instruction* instr, uint8_t opsel_lo, uint8_t opsel_hi)
3291 {
3292    /* propagate swizzles which apply to a result down to the instruction's operands:
3293     * result = a.xy + b.xx -> result.yx = a.yx + b.xx */
3294    assert((opsel_lo & 1) == opsel_lo);
3295    assert((opsel_hi & 1) == opsel_hi);
3296    uint8_t tmp_lo = instr->opsel_lo;
3297    uint8_t tmp_hi = instr->opsel_hi;
3298    bool neg_lo[3] = {instr->neg_lo[0], instr->neg_lo[1], instr->neg_lo[2]};
3299    bool neg_hi[3] = {instr->neg_hi[0], instr->neg_hi[1], instr->neg_hi[2]};
3300    if (opsel_lo == 1) {
3301       instr->opsel_lo = tmp_hi;
3302       for (unsigned i = 0; i < 3; i++)
3303          instr->neg_lo[i] = neg_hi[i];
3304    }
3305    if (opsel_hi == 0) {
3306       instr->opsel_hi = tmp_lo;
3307       for (unsigned i = 0; i < 3; i++)
3308          instr->neg_hi[i] = neg_lo[i];
3309    }
3310 }
3311 
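/* Packed (VOP3P) combines: fold a clamping v_pk_mul_f16(x, 1.0) into its producer,
 * fold packed fneg (v_pk_mul_f16 by 1.0 with neg modifiers) into its users, and fuse
 * v_pk_mul_f16/v_pk_mul_lo_u16 + v_pk_add into v_pk_fma_f16/v_pk_mad_u16. */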
3312 void
3313 combine_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3314 {
3315    VOP3P_instruction* vop3p = &instr->vop3p();
3316 
3317    /* apply clamp */
3318    if (instr->opcode == aco_opcode::v_pk_mul_f16 && instr->operands[1].constantEquals(0x3C00) &&
3319        vop3p->clamp && instr->operands[0].isTemp() && ctx.uses[instr->operands[0].tempId()] == 1) {
3320 
3321       ssa_info& info = ctx.info[instr->operands[0].tempId()];
3322       if (info.is_vop3p() && instr_info.can_use_output_modifiers[(int)info.instr->opcode]) {
3323          VOP3P_instruction* candidate = &ctx.info[instr->operands[0].tempId()].instr->vop3p();
3324          candidate->clamp = true;
3325          propagate_swizzles(candidate, vop3p->opsel_lo, vop3p->opsel_hi);
3326          instr->definitions[0].swapTemp(candidate->definitions[0]);
3327          ctx.info[candidate->definitions[0].tempId()].instr = candidate;
3328          ctx.uses[instr->definitions[0].tempId()]--;
3329          return;
3330       }
3331    }
3332 
3333    /* check for fneg modifiers */
3334    if (instr_info.can_use_input_modifiers[(int)instr->opcode]) {
3335       for (unsigned i = 0; i < instr->operands.size(); i++) {
3336          Operand& op = instr->operands[i];
3337          if (!op.isTemp())
3338             continue;
3339 
3340          ssa_info& info = ctx.info[op.tempId()];
3341          if (info.is_vop3p() && info.instr->opcode == aco_opcode::v_pk_mul_f16 &&
3342              info.instr->operands[1].constantEquals(0x3C00)) {
3343             Operand ops[3];
3344             for (unsigned j = 0; j < instr->operands.size(); j++)
3345                ops[j] = instr->operands[j];
3346             ops[i] = info.instr->operands[0];
3347             if (!check_vop3_operands(ctx, instr->operands.size(), ops))
3348                continue;
3349 
3350             VOP3P_instruction* fneg = &info.instr->vop3p();
3351             if (fneg->clamp)
3352                continue;
3353             instr->operands[i] = fneg->operands[0];
3354 
3355             /* opsel_lo/hi is either 0 or 1:
3356              * if 0 - pick selection from fneg->lo
3357              * if 1 - pick selection from fneg->hi
3358              */
3359             bool opsel_lo = (vop3p->opsel_lo >> i) & 1;
3360             bool opsel_hi = (vop3p->opsel_hi >> i) & 1;
3361             bool neg_lo = fneg->neg_lo[0] ^ fneg->neg_lo[1];
3362             bool neg_hi = fneg->neg_hi[0] ^ fneg->neg_hi[1];
3363             vop3p->neg_lo[i] ^= opsel_lo ? neg_hi : neg_lo;
3364             vop3p->neg_hi[i] ^= opsel_hi ? neg_hi : neg_lo;
3365             vop3p->opsel_lo ^= ((opsel_lo ? ~fneg->opsel_hi : fneg->opsel_lo) & 1) << i;
3366             vop3p->opsel_hi ^= ((opsel_hi ? ~fneg->opsel_hi : fneg->opsel_lo) & 1) << i;
3367 
3368             if (--ctx.uses[fneg->definitions[0].tempId()])
3369                ctx.uses[fneg->operands[0].tempId()]++;
3370          }
3371       }
3372    }
3373 
3374    if (instr->opcode == aco_opcode::v_pk_add_f16 || instr->opcode == aco_opcode::v_pk_add_u16) {
3375       bool fadd = instr->opcode == aco_opcode::v_pk_add_f16;
3376       if (fadd && instr->definitions[0].isPrecise())
3377          return;
3378 
3379       Instruction* mul_instr = nullptr;
3380       unsigned add_op_idx = 0;
3381       uint8_t opsel_lo = 0, opsel_hi = 0;
3382       uint32_t uses = UINT32_MAX;
3383 
3384       /* find the 'best' mul instruction to combine with the add */
3385       for (unsigned i = 0; i < 2; i++) {
3386          if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_vop3p())
3387             continue;
3388          ssa_info& info = ctx.info[instr->operands[i].tempId()];
3389          if (fadd) {
3390             if (info.instr->opcode != aco_opcode::v_pk_mul_f16 ||
3391                 info.instr->definitions[0].isPrecise())
3392                continue;
3393          } else {
3394             if (info.instr->opcode != aco_opcode::v_pk_mul_lo_u16)
3395                continue;
3396          }
3397 
3398          Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]};
3399          if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
3400             continue;
3401 
3402          /* no clamp allowed between mul and add */
3403          if (info.instr->vop3p().clamp)
3404             continue;
3405 
3406          mul_instr = info.instr;
3407          add_op_idx = 1 - i;
3408          opsel_lo = (vop3p->opsel_lo >> i) & 1;
3409          opsel_hi = (vop3p->opsel_hi >> i) & 1;
3410          uses = ctx.uses[instr->operands[i].tempId()];
3411       }
3412 
3413       if (!mul_instr)
3414          return;
3415 
3416       /* convert to mad */
3417       Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1], instr->operands[add_op_idx]};
3418       ctx.uses[mul_instr->definitions[0].tempId()]--;
3419       if (ctx.uses[mul_instr->definitions[0].tempId()]) {
3420          if (op[0].isTemp())
3421             ctx.uses[op[0].tempId()]++;
3422          if (op[1].isTemp())
3423             ctx.uses[op[1].tempId()]++;
3424       }
3425 
3426       /* turn packed mul+add into v_pk_fma_f16 */
3427       assert(mul_instr->isVOP3P());
3428       aco_opcode mad = fadd ? aco_opcode::v_pk_fma_f16 : aco_opcode::v_pk_mad_u16;
3429       aco_ptr<VOP3P_instruction> fma{
3430          create_instruction<VOP3P_instruction>(mad, Format::VOP3P, 3, 1)};
3431       VOP3P_instruction* mul = &mul_instr->vop3p();
3432       for (unsigned i = 0; i < 2; i++) {
3433          fma->operands[i] = op[i];
3434          fma->neg_lo[i] = mul->neg_lo[i];
3435          fma->neg_hi[i] = mul->neg_hi[i];
3436       }
3437       fma->operands[2] = op[2];
3438       fma->clamp = vop3p->clamp;
3439       fma->opsel_lo = mul->opsel_lo;
3440       fma->opsel_hi = mul->opsel_hi;
3441       propagate_swizzles(fma.get(), opsel_lo, opsel_hi);
3442       fma->opsel_lo |= (vop3p->opsel_lo << (2 - add_op_idx)) & 0x4;
3443       fma->opsel_hi |= (vop3p->opsel_hi << (2 - add_op_idx)) & 0x4;
3444       fma->neg_lo[2] = vop3p->neg_lo[add_op_idx];
3445       fma->neg_hi[2] = vop3p->neg_hi[add_op_idx];
3446       fma->neg_lo[1] = fma->neg_lo[1] ^ vop3p->neg_lo[1 - add_op_idx];
3447       fma->neg_hi[1] = fma->neg_hi[1] ^ vop3p->neg_hi[1 - add_op_idx];
3448       fma->definitions[0] = instr->definitions[0];
3449       instr = std::move(fma);
3450       ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
3451       return;
3452    }
3453 }
3454 
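/* Check whether instr can be represented as v_fma_mix_f32 (GFX9+), which allows
 * mixing f16 and f32 sources and destinations. */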
3455 bool
3456 can_use_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3457 {
3458    if (ctx.program->chip_class < GFX9)
3459       return false;
3460 
3461    switch (instr->opcode) {
3462    case aco_opcode::v_add_f32:
3463    case aco_opcode::v_sub_f32:
3464    case aco_opcode::v_subrev_f32:
3465    case aco_opcode::v_mul_f32:
3466    case aco_opcode::v_fma_f32: break;
3467    case aco_opcode::v_fma_mix_f32:
3468    case aco_opcode::v_fma_mixlo_f16: return true;
3469    default: return false;
3470    }
3471 
3472    if (instr->opcode == aco_opcode::v_fma_f32 && !ctx.program->dev.fused_mad_mix &&
3473        instr->definitions[0].isPrecise())
3474       return false;
3475 
3476    if (instr->isVOP3())
3477       return !instr->vop3().omod && !(instr->vop3().opsel & 0x8);
3478 
3479    return instr->format == Format::VOP2;
3480 }
3481 
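/* Rewrite a plain f32 add/sub/mul/fma as v_fma_mix_f32: a mul becomes a*b + (-0),
 * an add becomes 1.0*a + b, so that f16 sources or an f16 result can be folded in
 * later via opsel. */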
3482 void
3483 to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3484 {
3485    bool is_add = instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
3486 
3487    aco_ptr<VOP3P_instruction> vop3p{
3488       create_instruction<VOP3P_instruction>(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1)};
3489 
3490    vop3p->opsel_lo = instr->isVOP3() ? ((instr->vop3().opsel & 0x7) << (is_add ? 1 : 0)) : 0x0;
3491    vop3p->opsel_hi = 0x0;
3492    for (unsigned i = 0; i < instr->operands.size(); i++) {
3493       vop3p->operands[is_add + i] = instr->operands[i];
3494       vop3p->neg_lo[is_add + i] = instr->isVOP3() && instr->vop3().neg[i];
3495       vop3p->neg_lo[is_add + i] |= instr->isSDWA() && instr->sdwa().neg[i];
3496       vop3p->neg_hi[is_add + i] = instr->isVOP3() && instr->vop3().abs[i];
3497       vop3p->neg_hi[is_add + i] |= instr->isSDWA() && instr->sdwa().abs[i];
3498       vop3p->opsel_lo |= (instr->isSDWA() && instr->sdwa().sel[i].offset()) << (is_add + i);
3499    }
3500    if (instr->opcode == aco_opcode::v_mul_f32) {
3501       vop3p->opsel_hi &= 0x3;
3502       vop3p->operands[2] = Operand::zero();
3503       vop3p->neg_lo[2] = true;
3504    } else if (is_add) {
3505       vop3p->opsel_hi &= 0x6;
3506       vop3p->operands[0] = Operand::c32(0x3f800000);
3507       if (instr->opcode == aco_opcode::v_sub_f32)
3508          vop3p->neg_lo[2] ^= true;
3509       else if (instr->opcode == aco_opcode::v_subrev_f32)
3510          vop3p->neg_lo[1] ^= true;
3511    }
3512    vop3p->definitions[0] = instr->definitions[0];
3513    vop3p->clamp = instr->isVOP3() && instr->vop3().clamp;
3514    instr = std::move(vop3p);
3515 
3516    ctx.info[instr->definitions[0].tempId()].label &= label_f2f16 | label_clamp | label_mul;
3517    if (ctx.info[instr->definitions[0].tempId()].label & label_mul)
3518       ctx.info[instr->definitions[0].tempId()].instr = instr.get();
3519 }
3520 
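/* If the single user of this result is an f32->f16 conversion (label_f2f16),
 * turn the instruction into v_fma_mixlo_f16 so it writes the converted half directly. */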
3521 bool
3522 combine_output_conversion(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3523 {
3524    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3525    if (!def_info.is_f2f16())
3526       return false;
3527    Instruction* conv = def_info.instr;
3528 
3529    if (!can_use_mad_mix(ctx, instr) || ctx.uses[instr->definitions[0].tempId()] != 1)
3530       return false;
3531 
3532    if (!ctx.uses[conv->definitions[0].tempId()])
3533       return false;
3534 
3535    if (conv->usesModifiers())
3536       return false;
3537 
3538    if (!instr->isVOP3P())
3539       to_mad_mix(ctx, instr);
3540 
3541    instr->opcode = aco_opcode::v_fma_mixlo_f16;
3542    instr->definitions[0].swapTemp(conv->definitions[0]);
3543    if (conv->definitions[0].isPrecise())
3544       instr->definitions[0].setPrecise(true);
3545    ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
3546    ctx.uses[conv->definitions[0].tempId()]--;
3547 
3548    return true;
3549 }
3550 
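/* Fold f16->f32 conversions (label_f2f32) of the sources into v_fma_mix_f32/
 * v_fma_mixlo_f16, marking the now-16-bit source via opsel_hi. */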
3551 void
3552 combine_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3553 {
3554    if (!can_use_mad_mix(ctx, instr))
3555       return;
3556 
3557    for (unsigned i = 0; i < instr->operands.size(); i++) {
3558       if (!instr->operands[i].isTemp())
3559          continue;
3560       Temp tmp = instr->operands[i].getTemp();
3561       if (!ctx.info[tmp.id()].is_f2f32())
3562          continue;
3563 
3564       Instruction* conv = ctx.info[tmp.id()].instr;
3565       if (conv->isSDWA() && (conv->sdwa().dst_sel.size() != 4 || conv->sdwa().sel[0].size() != 2 ||
3566                              conv->sdwa().clamp || conv->sdwa().omod)) {
3567          continue;
3568       } else if (conv->isVOP3() && (conv->vop3().clamp || conv->vop3().omod)) {
3569          continue;
3570       } else if (conv->isDPP()) {
3571          continue;
3572       }
3573 
3574       if (get_operand_size(instr, i) != 32)
3575          continue;
3576 
3577       /* Conversion to VOP3P will add inline constant operands, but that shouldn't affect
3578        * check_vop3_operands(). */
3579       Operand op[3];
3580       for (unsigned j = 0; j < instr->operands.size(); j++)
3581          op[j] = instr->operands[j];
3582       op[i] = conv->operands[0];
3583       if (!check_vop3_operands(ctx, instr->operands.size(), op))
3584          continue;
3585 
3586       if (!instr->isVOP3P()) {
3587          bool is_add =
3588             instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
3589          to_mad_mix(ctx, instr);
3590          i += is_add;
3591       }
3592 
3593       if (--ctx.uses[tmp.id()])
3594          ctx.uses[conv->operands[0].tempId()]++;
3595       instr->operands[i].setTemp(conv->operands[0].getTemp());
3596       if (conv->definitions[0].isPrecise())
3597          instr->definitions[0].setPrecise(true);
3598       instr->vop3p().opsel_hi ^= 1u << i;
3599       if (conv->isSDWA() && conv->sdwa().sel[0].offset() == 2)
3600          instr->vop3p().opsel_lo |= 1u << i;
3601       bool neg = (conv->isVOP3() && conv->vop3().neg[0]) || (conv->isSDWA() && conv->sdwa().neg[0]);
3602       bool abs = (conv->isVOP3() && conv->vop3().abs[0]) || (conv->isSDWA() && conv->sdwa().abs[0]);
3603       if (!instr->vop3p().neg_hi[i]) {
3604          instr->vop3p().neg_lo[i] ^= neg;
3605          instr->vop3p().neg_hi[i] = abs;
3606       }
3607    }
3608 }
3609 
3610 // TODO: we could possibly move the whole label_instruction pass to combine_instruction:
3611 // this would mean that we'd have to fix the instruction uses during value propagation
3612 
3613 void
3614 combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3615 {
3616    if (instr->definitions.empty() || is_dead(ctx.uses, instr.get()))
3617       return;
3618 
3619    if (instr->isVALU()) {
3620       /* Apply SDWA. Do this after label_instruction() so it can remove
3621        * label_extract if not all instructions can take SDWA. */
3622       for (unsigned i = 0; i < instr->operands.size(); i++) {
3623          Operand& op = instr->operands[i];
3624          if (!op.isTemp())
3625             continue;
3626          ssa_info& info = ctx.info[op.tempId()];
3627          if (!info.is_extract())
3628             continue;
3629          /* if there are that many uses, there are likely better combinations */
3630          // TODO: delay applying extract to a point where we know better
3631          if (ctx.uses[op.tempId()] > 4) {
3632             info.label &= ~label_extract;
3633             continue;
3634          }
3635          if (info.is_extract() &&
3636              (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
3637               instr->operands[i].getTemp().type() == RegType::sgpr) &&
3638              can_apply_extract(ctx, instr, i, info)) {
3639             /* Increase use count of the extract's operand if the extract still has uses. */
3640             apply_extract(ctx, instr, i, info);
3641             if (--ctx.uses[instr->operands[i].tempId()])
3642                ctx.uses[info.instr->operands[0].tempId()]++;
3643             instr->operands[i].setTemp(info.instr->operands[0].getTemp());
3644          }
3645       }
3646 
3647       if (can_apply_sgprs(ctx, instr))
3648          apply_sgprs(ctx, instr);
3649       combine_mad_mix(ctx, instr);
3650       while (apply_omod_clamp(ctx, instr) | combine_output_conversion(ctx, instr))
3651          ;
3652       apply_insert(ctx, instr);
3653    }
3654 
3655    if (instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
3656        instr->opcode != aco_opcode::v_fma_mixlo_f16)
3657       return combine_vop3p(ctx, instr);
3658 
3659    if (ctx.info[instr->definitions[0].tempId()].is_vcc_hint()) {
3660       instr->definitions[0].setHint(vcc);
3661    }
3662 
3663    if (instr->isSDWA() || instr->isDPP())
3664       return;
3665 
3666    if (instr->opcode == aco_opcode::p_extract) {
3667       ssa_info& info = ctx.info[instr->operands[0].tempId()];
3668       if (info.is_extract() && can_apply_extract(ctx, instr, 0, info)) {
3669          apply_extract(ctx, instr, 0, info);
3670          if (--ctx.uses[instr->operands[0].tempId()])
3671             ctx.uses[info.instr->operands[0].tempId()]++;
3672          instr->operands[0].setTemp(info.instr->operands[0].getTemp());
3673       }
3674 
3675       apply_ds_extract(ctx, instr);
3676    }
3677 
3678    /* TODO: There are still some peephole optimizations that could be done:
3679     * - abs(a - b) -> s_absdiff_i32
3680     * - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
3681     * - patterns for v_alignbit_b32 and v_alignbyte_b32
3682     * These probably aren't too interesting, though.
3683     * There are also patterns for v_cmp_class_f{16,32,64}. This is difficult but
3684     * probably more useful than the previously mentioned optimizations.
3685     * The various comparison optimizations also currently only work with 32-bit
3686     * floats. */
3687 
3688    /* neg(mul(a, b)) -> mul(neg(a), b), abs(mul(a, b)) -> mul(abs(a), abs(b)) */
3689    if ((ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)) &&
3690        ctx.uses[instr->operands[1].tempId()] == 1) {
3691       Temp val = ctx.info[instr->definitions[0].tempId()].temp;
3692 
3693       if (!ctx.info[val.id()].is_mul())
3694          return;
3695 
3696       Instruction* mul_instr = ctx.info[val.id()].instr;
3697 
3698       if (mul_instr->operands[0].isLiteral())
3699          return;
3700       if (mul_instr->isVOP3() && mul_instr->vop3().clamp)
3701          return;
3702       if (mul_instr->isSDWA() || mul_instr->isDPP() || mul_instr->isVOP3P())
3703          return;
3704       if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32 &&
3705           ctx.fp_mode.preserve_signed_zero_inf_nan32)
3706          return;
3707       if (mul_instr->definitions[0].bytes() != instr->definitions[0].bytes())
3708          return;
3709 
3710       /* convert to mul(neg(a), b), mul(abs(a), abs(b)) or mul(neg(abs(a)), abs(b)) */
3711       ctx.uses[mul_instr->definitions[0].tempId()]--;
3712       Definition def = instr->definitions[0];
3713       bool is_neg = ctx.info[instr->definitions[0].tempId()].is_neg();
3714       bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs();
3715       instr.reset(
3716          create_instruction<VOP3_instruction>(mul_instr->opcode, asVOP3(Format::VOP2), 2, 1));
3717       instr->operands[0] = mul_instr->operands[0];
3718       instr->operands[1] = mul_instr->operands[1];
3719       instr->definitions[0] = def;
3720       VOP3_instruction& new_mul = instr->vop3();
3721       if (mul_instr->isVOP3()) {
3722          VOP3_instruction& mul = mul_instr->vop3();
3723          new_mul.neg[0] = mul.neg[0];
3724          new_mul.neg[1] = mul.neg[1];
3725          new_mul.abs[0] = mul.abs[0];
3726          new_mul.abs[1] = mul.abs[1];
3727          new_mul.omod = mul.omod;
3728       }
3729       if (is_abs) {
3730          new_mul.neg[0] = new_mul.neg[1] = false;
3731          new_mul.abs[0] = new_mul.abs[1] = true;
3732       }
3733       new_mul.neg[0] ^= is_neg;
3734       new_mul.clamp = false;
3735 
3736       ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
3737       return;
3738    }
3739 
3740    /* combine mul+add -> mad */
3741    bool is_add_mix =
3742       (instr->opcode == aco_opcode::v_fma_mix_f32 ||
3743        instr->opcode == aco_opcode::v_fma_mixlo_f16) &&
3744       !instr->vop3p().neg_lo[0] &&
3745       ((instr->operands[0].constantEquals(0x3f800000) && (instr->vop3p().opsel_hi & 0x1) == 0) ||
3746        (instr->operands[0].constantEquals(0x3C00) && (instr->vop3p().opsel_hi & 0x1) &&
3747         !(instr->vop3p().opsel_lo & 0x1)));
3748    bool mad32 = instr->opcode == aco_opcode::v_add_f32 || instr->opcode == aco_opcode::v_sub_f32 ||
3749                 instr->opcode == aco_opcode::v_subrev_f32;
3750    bool mad16 = instr->opcode == aco_opcode::v_add_f16 || instr->opcode == aco_opcode::v_sub_f16 ||
3751                 instr->opcode == aco_opcode::v_subrev_f16;
3752    bool mad64 = instr->opcode == aco_opcode::v_add_f64;
3753    if (is_add_mix || mad16 || mad32 || mad64) {
3754       Instruction* mul_instr = nullptr;
3755       unsigned add_op_idx = 0;
3756       uint32_t uses = UINT32_MAX;
3757       bool emit_fma = false;
3758       /* find the 'best' mul instruction to combine with the add */
3759       for (unsigned i = is_add_mix ? 1 : 0; i < instr->operands.size(); i++) {
3760          if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
3761             continue;
3762          ssa_info& info = ctx.info[instr->operands[i].tempId()];
3763 
3764          /* no clamp/omod allowed between mul and add */
3765          if (info.instr->isVOP3() && (info.instr->vop3().clamp || info.instr->vop3().omod))
3766             continue;
3767          if (info.instr->isVOP3P() && info.instr->vop3p().clamp)
3768             continue;
3769          /* v_fma_mix_f32/etc can't do omod */
3770          if (info.instr->isVOP3P() && instr->isVOP3() && instr->vop3().omod)
3771             continue;
3772          /* don't promote fp16 to fp32 or remove fp32->fp16->fp32 conversions */
3773          if (is_add_mix && info.instr->definitions[0].bytes() == 2)
3774             continue;
3775 
3776          if (get_operand_size(instr, i) != info.instr->definitions[0].bytes() * 8)
3777             continue;
3778 
3779          bool legacy = info.instr->opcode == aco_opcode::v_mul_legacy_f32;
3780          bool mad_mix = is_add_mix || info.instr->isVOP3P();
3781 
3782          bool has_fma = mad16 || mad64 || (legacy && ctx.program->chip_class >= GFX10_3) ||
3783                         (mad32 && !legacy && !mad_mix && ctx.program->dev.has_fast_fma32) ||
3784                         (mad_mix && ctx.program->dev.fused_mad_mix);
3785          bool has_mad = mad_mix ? !ctx.program->dev.fused_mad_mix
3786                                 : ((mad32 && ctx.program->chip_class < GFX10_3) ||
3787                                    (mad16 && ctx.program->chip_class <= GFX9));
3788          bool can_use_fma = has_fma && !info.instr->definitions[0].isPrecise() &&
3789                             !instr->definitions[0].isPrecise();
3790          bool can_use_mad =
3791             has_mad && (mad_mix || mad32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == 0;
3792          if (mad_mix && legacy)
3793             continue;
3794          if (!can_use_fma && !can_use_mad)
3795             continue;
3796 
3797          unsigned candidate_add_op_idx = is_add_mix ? (3 - i) : (1 - i);
3798          Operand op[3] = {info.instr->operands[0], info.instr->operands[1],
3799                           instr->operands[candidate_add_op_idx]};
3800          if (info.instr->isSDWA() || info.instr->isDPP() || !check_vop3_operands(ctx, 3, op) ||
3801              ctx.uses[instr->operands[i].tempId()] > uses)
3802             continue;
3803 
3804          if (ctx.uses[instr->operands[i].tempId()] == uses) {
3805             unsigned cur_idx = mul_instr->definitions[0].tempId();
3806             unsigned new_idx = info.instr->definitions[0].tempId();
3807             if (cur_idx > new_idx)
3808                continue;
3809          }
3810 
3811          mul_instr = info.instr;
3812          add_op_idx = candidate_add_op_idx;
3813          uses = ctx.uses[instr->operands[i].tempId()];
3814          emit_fma = !can_use_mad;
3815       }
3816 
3817       if (mul_instr) {
3818          /* turn mul+add into v_mad/v_fma */
3819          Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1],
3820                           instr->operands[add_op_idx]};
3821          ctx.uses[mul_instr->definitions[0].tempId()]--;
3822          if (ctx.uses[mul_instr->definitions[0].tempId()]) {
3823             if (op[0].isTemp())
3824                ctx.uses[op[0].tempId()]++;
3825             if (op[1].isTemp())
3826                ctx.uses[op[1].tempId()]++;
3827          }
3828 
3829          bool neg[3] = {false, false, false};
3830          bool abs[3] = {false, false, false};
3831          unsigned omod = 0;
3832          bool clamp = false;
3833          uint8_t opsel_lo = 0;
3834          uint8_t opsel_hi = 0;
3835 
3836          if (mul_instr->isVOP3()) {
3837             VOP3_instruction& vop3 = mul_instr->vop3();
3838             neg[0] = vop3.neg[0];
3839             neg[1] = vop3.neg[1];
3840             abs[0] = vop3.abs[0];
3841             abs[1] = vop3.abs[1];
3842          } else if (mul_instr->isVOP3P()) {
3843             VOP3P_instruction& vop3p = mul_instr->vop3p();
3844             neg[0] = vop3p.neg_lo[0];
3845             neg[1] = vop3p.neg_lo[1];
3846             abs[0] = vop3p.neg_hi[0];
3847             abs[1] = vop3p.neg_hi[1];
3848             opsel_lo = vop3p.opsel_lo & 0x3;
3849             opsel_hi = vop3p.opsel_hi & 0x3;
3850          }
3851 
3852          if (instr->isVOP3()) {
3853             VOP3_instruction& vop3 = instr->vop3();
3854             neg[2] = vop3.neg[add_op_idx];
3855             abs[2] = vop3.abs[add_op_idx];
3856             omod = vop3.omod;
3857             clamp = vop3.clamp;
3858             /* abs of the multiplication result */
3859             if (vop3.abs[1 - add_op_idx]) {
3860                neg[0] = false;
3861                neg[1] = false;
3862                abs[0] = true;
3863                abs[1] = true;
3864             }
3865             /* neg of the multiplication result */
3866             neg[1] = neg[1] ^ vop3.neg[1 - add_op_idx];
3867          } else if (instr->isVOP3P()) {
3868             VOP3P_instruction& vop3p = instr->vop3p();
3869             neg[2] = vop3p.neg_lo[add_op_idx];
3870             abs[2] = vop3p.neg_hi[add_op_idx];
3871             opsel_lo |= vop3p.opsel_lo & (1 << add_op_idx) ? 0x4 : 0x0;
3872             opsel_hi |= vop3p.opsel_hi & (1 << add_op_idx) ? 0x4 : 0x0;
3873             clamp = vop3p.clamp;
3874             /* abs of the multiplication result */
3875             if (vop3p.neg_hi[3 - add_op_idx]) {
3876                neg[0] = false;
3877                neg[1] = false;
3878                abs[0] = true;
3879                abs[1] = true;
3880             }
3881             /* neg of the multiplication result */
3882             neg[1] = neg[1] ^ vop3p.neg_lo[3 - add_op_idx];
3883          }
3884 
3885          if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == aco_opcode::v_sub_f16)
3886             neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
3887          else if (instr->opcode == aco_opcode::v_subrev_f32 ||
3888                   instr->opcode == aco_opcode::v_subrev_f16)
3889             neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
3890 
3891          aco_ptr<Instruction> add_instr = std::move(instr);
3892          if (add_instr->isVOP3P() || mul_instr->isVOP3P()) {
3893             assert(!omod);
3894 
3895             aco_opcode mad_op = add_instr->definitions[0].bytes() == 2 ? aco_opcode::v_fma_mixlo_f16
3896                                                                        : aco_opcode::v_fma_mix_f32;
3897             aco_ptr<VOP3P_instruction> mad{
3898                create_instruction<VOP3P_instruction>(mad_op, Format::VOP3P, 3, 1)};
3899             for (unsigned i = 0; i < 3; i++) {
3900                mad->operands[i] = op[i];
3901                mad->neg_lo[i] = neg[i];
3902                mad->neg_hi[i] = abs[i];
3903             }
3904             mad->clamp = clamp;
3905             mad->opsel_lo = opsel_lo;
3906             mad->opsel_hi = opsel_hi;
3907 
3908             instr = std::move(mad);
3909          } else {
3910             aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
3911             if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
3912                assert(emit_fma == (ctx.program->chip_class >= GFX10_3));
3913                mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
3914             } else if (mad16) {
3915                mad_op = emit_fma ? (ctx.program->chip_class == GFX8 ? aco_opcode::v_fma_legacy_f16
3916                                                                     : aco_opcode::v_fma_f16)
3917                                  : (ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_f16
3918                                                                     : aco_opcode::v_mad_f16);
3919             } else if (mad64) {
3920                mad_op = aco_opcode::v_fma_f64;
3921             }
3922 
3923             aco_ptr<VOP3_instruction> mad{
3924                create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
3925             for (unsigned i = 0; i < 3; i++) {
3926                mad->operands[i] = op[i];
3927                mad->neg[i] = neg[i];
3928                mad->abs[i] = abs[i];
3929             }
3930             mad->omod = omod;
3931             mad->clamp = clamp;
3932 
3933             instr = std::move(mad);
3934          }
3935          instr->definitions[0] = add_instr->definitions[0];
3936 
3937          /* mark this ssa_def to be re-checked for profitability and literals */
3938          ctx.mad_infos.emplace_back(std::move(add_instr), mul_instr->definitions[0].tempId());
3939          ctx.info[instr->definitions[0].tempId()].set_mad(instr.get(), ctx.mad_infos.size() - 1);
3940          return;
3941       }
3942    }
3943    /* v_mul_f32(v_cndmask_b32(0, 1.0, cond), a) -> v_cndmask_b32(0, a, cond) */
3944    else if (((instr->opcode == aco_opcode::v_mul_f32 &&
3945               !ctx.fp_mode.preserve_signed_zero_inf_nan32) ||
3946              instr->opcode == aco_opcode::v_mul_legacy_f32) &&
3947             !instr->usesModifiers() && !ctx.fp_mode.must_flush_denorms32) {
3948       for (unsigned i = 0; i < 2; i++) {
3949          if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2f() &&
3950              ctx.uses[instr->operands[i].tempId()] == 1 && instr->operands[!i].isTemp() &&
3951              instr->operands[!i].getTemp().type() == RegType::vgpr) {
3952             ctx.uses[instr->operands[i].tempId()]--;
3953             ctx.uses[ctx.info[instr->operands[i].tempId()].temp.id()]++;
3954 
3955             aco_ptr<VOP2_instruction> new_instr{
3956                create_instruction<VOP2_instruction>(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1)};
3957             new_instr->operands[0] = Operand::zero();
3958             new_instr->operands[1] = instr->operands[!i];
3959             new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
3960             new_instr->definitions[0] = instr->definitions[0];
3961             instr = std::move(new_instr);
3962             ctx.info[instr->definitions[0].tempId()].label = 0;
3963             return;
3964          }
3965       }
3966    } else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->chip_class >= GFX9) {
3967       if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012",
3968                                 1 | 2)) {
3969       } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32,
3970                                        "012", 1 | 2)) {
3971       } else if (combine_add_or_then_and_lshl(ctx, instr)) {
3972       }
3973    } else if (instr->opcode == aco_opcode::v_xor_b32 && ctx.program->chip_class >= GFX10) {
3974       if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xor3_b32, "012",
3975                                 1 | 2)) {
3976       } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xor3_b32,
3977                                        "012", 1 | 2)) {
3978       }
3979    } else if (instr->opcode == aco_opcode::v_add_u16) {
3980       combine_three_valu_op(
3981          ctx, instr, aco_opcode::v_mul_lo_u16,
3982          ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_u16 : aco_opcode::v_mad_u16,
3983          "120", 1 | 2);
3984    } else if (instr->opcode == aco_opcode::v_add_u16_e64) {
3985       combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16_e64, aco_opcode::v_mad_u16, "120",
3986                             1 | 2);
3987    } else if (instr->opcode == aco_opcode::v_add_u32) {
3988       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
3989       } else if (combine_add_bcnt(ctx, instr)) {
3990       } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
3991                                        aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
3992       } else if (ctx.program->chip_class >= GFX9 && !instr->usesModifiers()) {
3993          if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xad_u32, "120",
3994                                    1 | 2)) {
3995          } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xad_u32,
3996                                           "120", 1 | 2)) {
3997          } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32,
3998                                           "012", 1 | 2)) {
3999          } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32,
4000                                           "012", 1 | 2)) {
4001          } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32,
4002                                           "012", 1 | 2)) {
4003          } else if (combine_add_or_then_and_lshl(ctx, instr)) {
4004          }
4005       }
4006    } else if (instr->opcode == aco_opcode::v_add_co_u32 ||
4007               instr->opcode == aco_opcode::v_add_co_u32_e64) {
4008       bool carry_out = ctx.uses[instr->definitions[1].tempId()] > 0;
4009       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
4010       } else if (!carry_out && combine_add_bcnt(ctx, instr)) {
4011       } else if (!carry_out && combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
4012                                                      aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
4013       } else if (!carry_out && combine_add_lshl(ctx, instr, false)) {
4014       }
4015    } else if (instr->opcode == aco_opcode::v_sub_u32 || instr->opcode == aco_opcode::v_sub_co_u32 ||
4016               instr->opcode == aco_opcode::v_sub_co_u32_e64) {
4017       bool carry_out =
4018          instr->opcode != aco_opcode::v_sub_u32 && ctx.uses[instr->definitions[1].tempId()] > 0;
4019       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 2)) {
4020       } else if (!carry_out && combine_add_lshl(ctx, instr, true)) {
4021       }
4022    } else if (instr->opcode == aco_opcode::v_subrev_u32 ||
4023               instr->opcode == aco_opcode::v_subrev_co_u32 ||
4024               instr->opcode == aco_opcode::v_subrev_co_u32_e64) {
4025       combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 1);
4026    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && ctx.program->chip_class >= GFX9) {
4027       combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add_lshl_u32, "120",
4028                             2);
4029    } else if ((instr->opcode == aco_opcode::s_add_u32 || instr->opcode == aco_opcode::s_add_i32) &&
4030               ctx.program->chip_class >= GFX9) {
4031       combine_salu_lshl_add(ctx, instr);
4032    } else if (instr->opcode == aco_opcode::s_not_b32 || instr->opcode == aco_opcode::s_not_b64) {
4033       combine_salu_not_bitwise(ctx, instr);
4034    } else if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_or_b32 ||
4035               instr->opcode == aco_opcode::s_and_b64 || instr->opcode == aco_opcode::s_or_b64) {
4036       if (combine_ordering_test(ctx, instr)) {
4037       } else if (combine_comparison_ordering(ctx, instr)) {
4038       } else if (combine_constant_comparison_ordering(ctx, instr)) {
4039       } else if (combine_salu_n2(ctx, instr)) {
4040       }
4041    } else if (instr->opcode == aco_opcode::v_and_b32) {
4042       combine_and_subbrev(ctx, instr);
4043    } else if (instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) {
4044       /* Label the existing v_fma_f32/v_fma_f16 with label_mad so we can create
4045        * v_fmamk_f32/v_fmaak_f32. Since ctx.uses[mad_info::mul_temp_id] is always 0, we don't
4046        * have to worry about select_instruction() using mad_info::add_instr.
4047        */
4048       ctx.mad_infos.emplace_back(nullptr, 0);
4049       ctx.info[instr->definitions[0].tempId()].set_mad(instr.get(), ctx.mad_infos.size() - 1);
4050    } else {
4051       aco_opcode min, max, min3, max3, med3;
4052       bool some_gfx9_only;
4053       if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &some_gfx9_only) &&
4054           (!some_gfx9_only || ctx.program->chip_class >= GFX9)) {
4055          if (combine_minmax(ctx, instr, instr->opcode == min ? max : min,
4056                             instr->opcode == min ? min3 : max3)) {
4057          } else {
4058             combine_clamp(ctx, instr, min, max, med3);
4059          }
4060       }
4061    }
4062 
4063    /* do this after combine_salu_n2() */
4064    if (instr->opcode == aco_opcode::s_andn2_b32 || instr->opcode == aco_opcode::s_andn2_b64)
4065       combine_inverse_comparison(ctx, instr);
4066 }
4067 
4068 bool
4069 to_uniform_bool_instr(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4070 {
4071    /* Check every operand to make sure they are suitable. */
4072    for (Operand& op : instr->operands) {
4073       if (!op.isTemp())
4074          return false;
4075       if (!ctx.info[op.tempId()].is_uniform_bool() && !ctx.info[op.tempId()].is_uniform_bitwise())
4076          return false;
4077    }
4078 
4079    switch (instr->opcode) {
4080    case aco_opcode::s_and_b32:
4081    case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_and_b32; break;
4082    case aco_opcode::s_or_b32:
4083    case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_or_b32; break;
4084    case aco_opcode::s_xor_b32:
4085    case aco_opcode::s_xor_b64: instr->opcode = aco_opcode::s_absdiff_i32; break;
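        /* For 0/1 inputs, s_absdiff_i32 computes the XOR and sets SCC iff the result is non-zero. */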
4086    default:
4087       /* Don't transform other instructions. They are very unlikely to appear here. */
4088       return false;
4089    }
4090 
4091    for (Operand& op : instr->operands) {
4092       ctx.uses[op.tempId()]--;
4093 
4094       if (ctx.info[op.tempId()].is_uniform_bool()) {
4095          /* Just use the uniform boolean temp. */
4096          op.setTemp(ctx.info[op.tempId()].temp);
4097       } else if (ctx.info[op.tempId()].is_uniform_bitwise()) {
4098          /* Use the SCC definition of the predecessor instruction.
4099           * This allows the predecessor to get picked up by the same optimization (if it has no
4100           * divergent users), and it also makes sure that the current instruction will keep working
4101           * even if the predecessor won't be transformed.
4102           */
4103          Instruction* pred_instr = ctx.info[op.tempId()].instr;
4104          assert(pred_instr->definitions.size() >= 2);
4105          assert(pred_instr->definitions[1].isFixed() &&
4106                 pred_instr->definitions[1].physReg() == scc);
4107          op.setTemp(pred_instr->definitions[1].getTemp());
4108       } else {
4109          unreachable("Invalid operand on uniform bitwise instruction.");
4110       }
4111 
4112       ctx.uses[op.tempId()]++;
4113    }
4114 
4115    instr->definitions[0].setTemp(Temp(instr->definitions[0].tempId(), s1));
4116    assert(instr->operands[0].regClass() == s1);
4117    assert(instr->operands[1].regClass() == s1);
4118    return true;
4119 }
4120 
4121 void
4122 select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4123 {
4124    const uint32_t threshold = 4;
4125 
4126    if (is_dead(ctx.uses, instr.get())) {
4127       instr.reset();
4128       return;
4129    }
4130 
4131    /* convert split_vector into a copy or extract_vector if only one definition is ever used */
4132    if (instr->opcode == aco_opcode::p_split_vector) {
4133       unsigned num_used = 0;
4134       unsigned idx = 0;
4135       unsigned split_offset = 0;
4136       for (unsigned i = 0, offset = 0; i < instr->definitions.size();
4137            offset += instr->definitions[i++].bytes()) {
4138          if (ctx.uses[instr->definitions[i].tempId()]) {
4139             num_used++;
4140             idx = i;
4141             split_offset = offset;
4142          }
4143       }
4144       bool done = false;
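           /* If only one definition survives and the source comes from a single-use
            * p_create_vector, forward the matching vector operand directly, turning the split
            * into a copy. */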
4145       if (num_used == 1 && ctx.info[instr->operands[0].tempId()].is_vec() &&
4146           ctx.uses[instr->operands[0].tempId()] == 1) {
4147          Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
4148 
4149          unsigned off = 0;
4150          Operand op;
4151          for (Operand& vec_op : vec->operands) {
4152             if (off == split_offset) {
4153                op = vec_op;
4154                break;
4155             }
4156             off += vec_op.bytes();
4157          }
4158          if (off != instr->operands[0].bytes() && op.bytes() == instr->definitions[idx].bytes()) {
4159             ctx.uses[instr->operands[0].tempId()]--;
4160             for (Operand& vec_op : vec->operands) {
4161                if (vec_op.isTemp())
4162                   ctx.uses[vec_op.tempId()]--;
4163             }
4164             if (op.isTemp())
4165                ctx.uses[op.tempId()]++;
4166 
4167             aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(
4168                aco_opcode::p_create_vector, Format::PSEUDO, 1, 1)};
4169             extract->operands[0] = op;
4170             extract->definitions[0] = instr->definitions[idx];
4171             instr = std::move(extract);
4172 
4173             done = true;
4174          }
4175       }
4176 
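           /* Otherwise, if a single definition survives at an aligned offset, replace the split
            * with p_extract_vector. */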
4177       if (!done && num_used == 1 &&
4178           instr->operands[0].bytes() % instr->definitions[idx].bytes() == 0 &&
4179           split_offset % instr->definitions[idx].bytes() == 0) {
4180          aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(
4181             aco_opcode::p_extract_vector, Format::PSEUDO, 2, 1)};
4182          extract->operands[0] = instr->operands[0];
4183          extract->operands[1] =
4184             Operand::c32((uint32_t)split_offset / instr->definitions[idx].bytes());
4185          extract->definitions[0] = instr->definitions[idx];
4186          instr = std::move(extract);
4187       }
4188    }
4189 
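        /* A definition labelled as a combined MAD/FMA is re-evaluated here: either the combination
         * is undone because the multiply result is still needed elsewhere, or we try to mark it
         * for conversion into the v_madak/v_madmk style literal forms. */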
4190    mad_info* mad_info = NULL;
4191    if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
4192       mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].instr->pass_flags];
4193       /* re-check mad instructions */
4194       if (ctx.uses[mad_info->mul_temp_id] && mad_info->add_instr) {
4195          ctx.uses[mad_info->mul_temp_id]++;
4196          if (instr->operands[0].isTemp())
4197             ctx.uses[instr->operands[0].tempId()]--;
4198          if (instr->operands[1].isTemp())
4199             ctx.uses[instr->operands[1].tempId()]--;
4200          instr.swap(mad_info->add_instr);
4201          mad_info = NULL;
4202       }
4203       /* check literals */
4204       else if (!instr->usesModifiers() && !instr->isVOP3P() &&
4205                instr->opcode != aco_opcode::v_fma_f64 &&
4206                instr->opcode != aco_opcode::v_mad_legacy_f32 &&
4207                instr->opcode != aco_opcode::v_fma_legacy_f32) {
4208          /* FMA can only take literals on GFX10+ */
4209          if ((instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) &&
4210              ctx.program->chip_class < GFX10)
4211             return;
4212          /* There are no v_fmaak_legacy_f16/v_fmamk_legacy_f16, and on the chips where VOP3 can
4213           * take literals (GFX10+), v_fma_legacy_f16 itself doesn't exist.
4214           */
4215          if (instr->opcode == aco_opcode::v_fma_legacy_f16)
4216             return;
4217 
4218          uint32_t literal_idx = 0;
4219          uint32_t literal_uses = UINT32_MAX;
4220 
4221          /* Try using v_madak/v_fmaak */
4222          if (instr->operands[2].isTemp() &&
4223              ctx.info[instr->operands[2].tempId()].is_literal(get_operand_size(instr, 2))) {
4224             bool has_sgpr = false;
4225             bool has_vgpr = false;
4226             for (unsigned i = 0; i < 2; i++) {
4227                if (!instr->operands[i].isTemp())
4228                   continue;
4229                has_sgpr |= instr->operands[i].getTemp().type() == RegType::sgpr;
4230                has_vgpr |= instr->operands[i].getTemp().type() == RegType::vgpr;
4231             }
4232             /* Encoding limitations require a VGPR operand. The constant bus limitations before
4233              * GFX10 disallow SGPRs.
4234              */
4235             if ((!has_sgpr || ctx.program->chip_class >= GFX10) && has_vgpr) {
4236                literal_idx = 2;
4237                literal_uses = ctx.uses[instr->operands[2].tempId()];
4238             }
4239          }
4240 
4241          /* Try using v_madmk/v_fmamk */
4242          /* Encoding limitations require a VGPR operand. */
4243          if (instr->operands[2].isTemp() && instr->operands[2].getTemp().type() == RegType::vgpr) {
4244             for (unsigned i = 0; i < 2; i++) {
4245                if (!instr->operands[i].isTemp())
4246                   continue;
4247 
4248                /* The constant bus limitations before GFX10 disallow SGPRs. */
4249                if (ctx.program->chip_class < GFX10 && instr->operands[!i].isTemp() &&
4250                    instr->operands[!i].getTemp().type() == RegType::sgpr)
4251                   continue;
4252 
4253                if (ctx.info[instr->operands[i].tempId()].is_literal(get_operand_size(instr, i)) &&
4254                    ctx.uses[instr->operands[i].tempId()] < literal_uses) {
4255                   literal_idx = i;
4256                   literal_uses = ctx.uses[instr->operands[i].tempId()];
4257                }
4258             }
4259          }
4260 
4261          /* Limit the number of applied literals so the code size doesn't
4262           * increase too much, but always apply the literal for v_mad->v_madak
4263           * because both instructions are 64-bit and the conversion doesn't
4264           * increase code size.
4265           * TODO: try to apply the literals earlier to lower the number of
4266           * uses below the threshold
4267           */
4268          if (literal_uses < threshold || literal_idx == 2) {
4269             ctx.uses[instr->operands[literal_idx].tempId()]--;
4270             mad_info->check_literal = true;
4271             mad_info->literal_idx = literal_idx;
4272             return;
4273          }
4274       }
4275    }
4276 
4277    /* Mark SCC needed, so the uniform boolean transformation won't swap the definitions
4278     * when it isn't beneficial */
4279    if (instr->isBranch() && instr->operands.size() && instr->operands[0].isTemp() &&
4280        instr->operands[0].isFixed() && instr->operands[0].physReg() == scc) {
4281       ctx.info[instr->operands[0].tempId()].set_scc_needed();
4282       return;
4283    } else if ((instr->opcode == aco_opcode::s_cselect_b64 ||
4284                instr->opcode == aco_opcode::s_cselect_b32) &&
4285               instr->operands[2].isTemp()) {
4286       ctx.info[instr->operands[2].tempId()].set_scc_needed();
4287    } else if (instr->opcode == aco_opcode::p_wqm && instr->operands[0].isTemp() &&
4288               ctx.info[instr->definitions[0].tempId()].is_scc_needed()) {
4289       /* Propagate label so it is correctly detected by the uniform bool transform */
4290       ctx.info[instr->operands[0].tempId()].set_scc_needed();
4291 
4292       /* Fix definition to SCC, this will prevent RA from adding superfluous moves */
4293       instr->definitions[0].setFixed(scc);
4294    }
4295 
4296    /* check for literals */
4297    if (!instr->isSALU() && !instr->isVALU())
4298       return;
4299 
4300    /* Transform uniform bitwise boolean operations to 32-bit when there are no divergent uses. */
4301    if (instr->definitions.size() && ctx.uses[instr->definitions[0].tempId()] == 0 &&
4302        ctx.info[instr->definitions[0].tempId()].is_uniform_bitwise()) {
4303       bool transform_done = to_uniform_bool_instr(ctx, instr);
4304 
4305       if (transform_done && !ctx.info[instr->definitions[1].tempId()].is_scc_needed()) {
4306          /* Swap the two definition IDs in order to avoid overusing the SCC.
4307           * This reduces extra moves generated by RA. */
4308          uint32_t def0_id = instr->definitions[0].getTemp().id();
4309          uint32_t def1_id = instr->definitions[1].getTemp().id();
4310          instr->definitions[0].setTemp(Temp(def1_id, s1));
4311          instr->definitions[1].setTemp(Temp(def0_id, s1));
4312       }
4313 
4314       return;
4315    }
4316 
4317    /* Combine DPP copies into VALU. This should be done after creating MAD/FMA. */
4318    if (instr->isVALU()) {
4319       for (unsigned i = 0; i < instr->operands.size(); i++) {
4320          if (!instr->operands[i].isTemp())
4321             continue;
4322          ssa_info info = ctx.info[instr->operands[i].tempId()];
4323 
4324          aco_opcode swapped_op;
4325          if (info.is_dpp() && info.instr->pass_flags == instr->pass_flags &&
4326              (i == 0 || can_swap_operands(instr, &swapped_op)) &&
4327              can_use_DPP(instr, true, info.is_dpp8()) && !instr->isDPP()) {
4328             bool dpp8 = info.is_dpp8();
4329             convert_to_DPP(instr, dpp8);
4330             if (dpp8) {
4331                DPP8_instruction* dpp = &instr->dpp8();
4332                for (unsigned j = 0; j < 8; ++j)
4333                   dpp->lane_sel[j] = info.instr->dpp8().lane_sel[j];
4334                if (i) {
4335                   instr->opcode = swapped_op;
4336                   std::swap(instr->operands[0], instr->operands[1]);
4337                }
4338             } else {
4339                DPP16_instruction* dpp = &instr->dpp16();
4340                if (i) {
4341                   instr->opcode = swapped_op;
4342                   std::swap(instr->operands[0], instr->operands[1]);
4343                   std::swap(dpp->neg[0], dpp->neg[1]);
4344                   std::swap(dpp->abs[0], dpp->abs[1]);
4345                }
4346                dpp->dpp_ctrl = info.instr->dpp16().dpp_ctrl;
4347                dpp->bound_ctrl = info.instr->dpp16().bound_ctrl;
4348                dpp->neg[0] ^= info.instr->dpp16().neg[0] && !dpp->abs[0];
4349                dpp->abs[0] |= info.instr->dpp16().abs[0];
4350             }
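                /* The DPP mov loses this use; if it remains alive, its source gains an extra use
                 * through the combined instruction, which now reads that source directly. */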
4351             if (--ctx.uses[info.instr->definitions[0].tempId()])
4352                ctx.uses[info.instr->operands[0].tempId()]++;
4353             instr->operands[0].setTemp(info.instr->operands[0].getTemp());
4354             break;
4355          }
4356       }
4357    }
4358 
4359    if (instr->isSDWA() || (instr->isVOP3() && ctx.program->chip_class < GFX10) ||
4360        (instr->isVOP3P() && ctx.program->chip_class < GFX10))
4361       return; /* some encodings can't ever take literals */
4362 
4363    /* we do not apply the literals yet because we don't know whether doing so is profitable */
4364    Operand current_literal(s1);
4365 
4366    unsigned literal_id = 0;
4367    unsigned literal_uses = UINT32_MAX;
4368    Operand literal(s1);
4369    unsigned num_operands = 1;
4370    if (instr->isSALU() ||
4371        (ctx.program->chip_class >= GFX10 && (can_use_VOP3(ctx, instr) || instr->isVOP3P())))
4372       num_operands = instr->operands.size();
4373    /* catch VOP2 with a 3rd SGPR operand (e.g. v_cndmask_b32, v_addc_co_u32) */
4374    else if (instr->isVALU() && instr->operands.size() >= 3)
4375       return;
4376 
4377    unsigned sgpr_ids[2] = {0, 0};
4378    bool is_literal_sgpr = false;
4379    uint32_t mask = 0;
4380 
4381    /* choose a literal to apply */
4382    for (unsigned i = 0; i < num_operands; i++) {
4383       Operand op = instr->operands[i];
4384       unsigned bits = get_operand_size(instr, i);
4385 
4386       if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
4387           op.tempId() != sgpr_ids[0])
4388          sgpr_ids[!!sgpr_ids[0]] = op.tempId();
4389 
4390       if (op.isLiteral()) {
4391          current_literal = op;
4392          continue;
4393       } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal(bits)) {
4394          continue;
4395       }
4396 
4397       if (!alu_can_accept_constant(instr->opcode, i))
4398          continue;
4399 
4400       if (ctx.uses[op.tempId()] < literal_uses) {
4401          is_literal_sgpr = op.getTemp().type() == RegType::sgpr;
4402          mask = 0;
4403          literal = Operand::c32(ctx.info[op.tempId()].val);
4404          literal_uses = ctx.uses[op.tempId()];
4405          literal_id = op.tempId();
4406       }
4407 
4408       mask |= (op.tempId() == literal_id) << i;
4409    }
4410 
4411    /* don't go over the constant bus limit */
4412    bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
4413                      instr->opcode == aco_opcode::v_lshrrev_b64 ||
4414                      instr->opcode == aco_opcode::v_ashrrev_i64;
4415    unsigned const_bus_limit = instr->isVALU() ? 1 : UINT32_MAX;
4416    if (ctx.program->chip_class >= GFX10 && !is_shift64)
4417       const_bus_limit = 2;
4418 
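        /* SGPR operands and the literal share the constant bus: if SGPRs already occupy every
         * slot, the literal can only be applied if it replaces one of those SGPRs. */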
4419    unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
4420    if (num_sgprs == const_bus_limit && !is_literal_sgpr)
4421       return;
4422 
4423    if (literal_id && literal_uses < threshold &&
4424        (current_literal.isUndefined() ||
4425         (current_literal.size() == literal.size() &&
4426          current_literal.constantValue() == literal.constantValue()))) {
4427       /* mark the literal to be applied */
4428       while (mask) {
4429          unsigned i = u_bit_scan(&mask);
4430          if (instr->operands[i].isTemp() && instr->operands[i].tempId() == literal_id)
4431             ctx.uses[instr->operands[i].tempId()]--;
4432       }
4433    }
4434 }
4435 
4436 void
4437 apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4438 {
4439    /* Clean up dead instructions */
4440    if (!instr)
4441       return;
4442 
4443    /* apply literals on MAD */
4444    if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
4445       mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].instr->pass_flags];
4446       if (info->check_literal &&
4447           (ctx.uses[instr->operands[info->literal_idx].tempId()] == 0 || info->literal_idx == 2)) {
4448          aco_ptr<Instruction> new_mad;
4449 
4450          aco_opcode new_op =
4451             info->literal_idx == 2 ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32;
4452          if (instr->opcode == aco_opcode::v_fma_f32)
4453             new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32;
4454          else if (instr->opcode == aco_opcode::v_mad_f16 ||
4455                   instr->opcode == aco_opcode::v_mad_legacy_f16)
4456             new_op = info->literal_idx == 2 ? aco_opcode::v_madak_f16 : aco_opcode::v_madmk_f16;
4457          else if (instr->opcode == aco_opcode::v_fma_f16)
4458             new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f16 : aco_opcode::v_fmamk_f16;
4459 
4460          new_mad.reset(create_instruction<VOP2_instruction>(new_op, Format::VOP2, 3, 1));
4461          if (info->literal_idx == 2) { /* add literal -> madak */
4462             new_mad->operands[0] = instr->operands[0];
4463             new_mad->operands[1] = instr->operands[1];
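                /* VOP2 reads src1 from a VGPR, so make sure the VGPR multiplicand ends up in
                 * operand 1. */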
4464             if (!new_mad->operands[1].isTemp() ||
4465                 new_mad->operands[1].getTemp().type() == RegType::sgpr)
4466                std::swap(new_mad->operands[0], new_mad->operands[1]);
4467          } else { /* mul literal -> madmk */
4468             new_mad->operands[0] = instr->operands[1 - info->literal_idx];
4469             new_mad->operands[1] = instr->operands[2];
4470          }
4471          new_mad->operands[2] =
4472             Operand::c32(ctx.info[instr->operands[info->literal_idx].tempId()].val);
4473          new_mad->definitions[0] = instr->definitions[0];
4474          ctx.instructions.emplace_back(std::move(new_mad));
4475          return;
4476       }
4477    }
4478 
4479    /* apply literals on other SALU/VALU */
4480    if (instr->isSALU() || instr->isVALU()) {
4481       for (unsigned i = 0; i < instr->operands.size(); i++) {
4482          Operand op = instr->operands[i];
4483          unsigned bits = get_operand_size(instr, i);
4484          if (op.isTemp() && ctx.info[op.tempId()].is_literal(bits) && ctx.uses[op.tempId()] == 0) {
4485             Operand literal = Operand::c32(ctx.info[op.tempId()].val);
4486             instr->format = withoutDPP(instr->format);
4487             if (instr->isVALU() && i > 0 && instr->format != Format::VOP3P)
4488                to_VOP3(ctx, instr);
4489             instr->operands[i] = literal;
4490          }
4491       }
4492    }
4493 
4494    ctx.instructions.emplace_back(std::move(instr));
4495 }
4496 
4497 void
4498 optimize(Program* program)
4499 {
4500    opt_ctx ctx;
4501    ctx.program = program;
4502    std::vector<ssa_info> info(program->peekAllocationId());
4503    ctx.info = info.data();
4504 
4505    /* 1. Bottom-Up DAG pass (forward) to label all ssa-defs */
4506    for (Block& block : program->blocks) {
4507       ctx.fp_mode = block.fp_mode;
4508       for (aco_ptr<Instruction>& instr : block.instructions)
4509          label_instruction(ctx, instr);
4510    }
4511 
4512    ctx.uses = dead_code_analysis(program);
4513 
4514    /* 2. Combine v_mad, omod, clamp and propagate sgpr on VALU instructions */
4515    for (Block& block : program->blocks) {
4516       ctx.fp_mode = block.fp_mode;
4517       for (aco_ptr<Instruction>& instr : block.instructions)
4518          combine_instruction(ctx, instr);
4519    }
4520 
4521    /* 3. Top-Down DAG pass (backward) to select instructions (includes DCE) */
4522    for (auto block_rit = program->blocks.rbegin(); block_rit != program->blocks.rend();
4523         ++block_rit) {
4524       Block* block = &(*block_rit);
4525       ctx.fp_mode = block->fp_mode;
4526       for (auto instr_rit = block->instructions.rbegin(); instr_rit != block->instructions.rend();
4527            ++instr_rit)
4528          select_instruction(ctx, *instr_rit);
4529    }
4530 
4531    /* 4. Add literals to instructions */
4532    for (Block& block : program->blocks) {
4533       ctx.instructions.clear();
4534       ctx.fp_mode = block.fp_mode;
4535       for (aco_ptr<Instruction>& instr : block.instructions)
4536          apply_literals(ctx, instr);
4537       block.instructions.swap(ctx.instructions);
4538    }
4539 }
4540 
4541 } // namespace aco
4542