1 /*
2 * Copyright © 2018 Valve Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 */
24
25 #include "aco_builder.h"
26 #include "aco_ir.h"
27
28 #include "util/half_float.h"
29 #include "util/memstream.h"
30
31 #include <algorithm>
32 #include <array>
33 #include <vector>
34
35 namespace aco {
36
37 #ifndef NDEBUG
void
perfwarn(Program* program, bool cond, const char* msg, Instruction* instr)
{
   /* If 'cond' holds, emit a performance warning consisting of 'msg' followed
    * by a textual dump of the offending instruction. If the DEBUG_PERFWARN
    * flag is set, the warning is fatal so it cannot be missed. */
   if (cond) {
      char* out;
      size_t outsize;
      struct u_memstream mem;
      /* format the message into a malloc'ed in-memory buffer */
      u_memstream_open(&mem, &out, &outsize);
      FILE* const memf = u_memstream_get(&mem);

      fprintf(memf, "%s: ", msg);
      aco_print_instr(instr, memf);
      u_memstream_close(&mem);

      aco_perfwarn(program, out);
      free(out);

      if (debug_flags & DEBUG_PERFWARN)
         exit(1);
   }
}
59 #endif
60
/**
 * The optimizer works in 4 phases:
 * (1) The first pass collects information for each ssa-def,
 *     propagates reg->reg operands of the same type, inlines constants
 *     and neg/abs input modifiers.
 * (2) The second pass combines instructions like mad, omod, clamp and
 *     propagates sgprs on VALU instructions.
 *     This pass depends on information collected in the first pass.
 * (3) The third pass goes backwards and selects instructions,
 *     i.e. decides whether a mad instruction is profitable and eliminates dead code.
 * (4) The fourth pass cleans up the sequence: literals get applied and dead
 *     instructions are removed from the sequence.
 */
74
/* Bookkeeping for a mul+add pair that was combined into a mad/fma, kept so a
 * later pass can decide whether the combination is profitable or whether the
 * original add instruction should be restored. */
struct mad_info {
   aco_ptr<Instruction> add_instr; /* the original add instruction */
   uint32_t mul_temp_id;           /* SSA id of the multiplication result */
   uint16_t literal_idx;           /* operand index to check when applying a literal (see check_literal) */
   bool check_literal;             /* whether applying a literal at literal_idx must be re-validated */

   mad_info(aco_ptr<Instruction> instr, uint32_t id)
       : add_instr(std::move(instr)), mul_temp_id(id), literal_idx(0), check_literal(false)
   {}
};
85
/* Labels classify what the optimizer knows about an SSA definition. Labels
 * from bit 31 upwards need the 1ull suffix to avoid sign extension when
 * combined into the 64-bit ssa_info::label mask. Bits 11 and 13 are
 * currently unused. */
enum Label {
   label_vec = 1 << 0,
   label_constant_32bit = 1 << 1,
   /* label_{abs,neg,mul,omod2,omod4,omod5,clamp} are used for both 16 and
    * 32-bit operations but this shouldn't cause any issues because we don't
    * look through any conversions */
   label_abs = 1 << 2,
   label_neg = 1 << 3,
   label_mul = 1 << 4,
   label_temp = 1 << 5,
   label_literal = 1 << 6,
   label_mad = 1 << 7,
   label_omod2 = 1 << 8,
   label_omod4 = 1 << 9,
   label_omod5 = 1 << 10,
   label_clamp = 1 << 12,
   label_undefined = 1 << 14,
   label_vcc = 1 << 15,
   label_b2f = 1 << 16,
   label_add_sub = 1 << 17,
   label_bitwise = 1 << 18,
   label_minmax = 1 << 19,
   label_vopc = 1 << 20,
   label_uniform_bool = 1 << 21,
   label_constant_64bit = 1 << 22,
   label_uniform_bitwise = 1 << 23,
   label_scc_invert = 1 << 24,
   label_vcc_hint = 1 << 25,
   label_scc_needed = 1 << 26,
   label_b2i = 1 << 27,
   label_fcanonicalize = 1 << 28,
   label_constant_16bit = 1 << 29,
   label_usedef = 1 << 30, /* generic label */
   label_vop3p = 1ull << 31, /* 1ull to prevent sign extension */
   label_canonicalized = 1ull << 32,
   label_extract = 1ull << 33,
   label_insert = 1ull << 34,
   label_dpp16 = 1ull << 35,
   label_dpp8 = 1ull << 36,
   label_f2f32 = 1ull << 37,
   label_f2f16 = 1ull << 38,
};
128
/* Labels whose payload is the defining instruction (ssa_info::instr). */
static constexpr uint64_t instr_usedef_labels =
   label_vec | label_mul | label_mad | label_add_sub | label_vop3p | label_bitwise |
   label_uniform_bitwise | label_minmax | label_vopc | label_usedef | label_extract | label_dpp16 |
   label_dpp8 | label_f2f32;
/* Labels whose payload is an instruction applying a modifier (omod/clamp/insert/conversion). */
static constexpr uint64_t instr_mod_labels =
   label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert | label_f2f16;

static constexpr uint64_t instr_labels = instr_usedef_labels | instr_mod_labels;
/* Labels whose payload is a Temp (ssa_info::temp). */
static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f |
                                        label_uniform_bool | label_scc_invert | label_b2i |
                                        label_fcanonicalize;
/* Labels whose payload is a 32-bit value (ssa_info::val). */
static constexpr uint32_t val_labels =
   label_constant_32bit | label_constant_64bit | label_constant_16bit | label_literal;

/* The payload union members alias each other, so a label from one payload
 * group must never be set together with a label from another group. */
static_assert((instr_labels & temp_labels) == 0, "labels cannot intersect");
static_assert((instr_labels & val_labels) == 0, "labels cannot intersect");
static_assert((temp_labels & val_labels) == 0, "labels cannot intersect");
146
147 struct ssa_info {
148 uint64_t label;
149 union {
150 uint32_t val;
151 Temp temp;
152 Instruction* instr;
153 };
154
ssa_infoaco::ssa_info155 ssa_info() : label(0) {}
156
add_labelaco::ssa_info157 void add_label(Label new_label)
158 {
159 /* Since all the instr_usedef_labels use instr for the same thing
160 * (indicating the defining instruction), there is usually no need to
161 * clear any other instr labels. */
162 if (new_label & instr_usedef_labels)
163 label &= ~(instr_mod_labels | temp_labels | val_labels); /* instr, temp and val alias */
164
165 if (new_label & instr_mod_labels) {
166 label &= ~instr_labels;
167 label &= ~(temp_labels | val_labels); /* instr, temp and val alias */
168 }
169
170 if (new_label & temp_labels) {
171 label &= ~temp_labels;
172 label &= ~(instr_labels | val_labels); /* instr, temp and val alias */
173 }
174
175 uint32_t const_labels =
176 label_literal | label_constant_32bit | label_constant_64bit | label_constant_16bit;
177 if (new_label & const_labels) {
178 label &= ~val_labels | const_labels;
179 label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
180 } else if (new_label & val_labels) {
181 label &= ~val_labels;
182 label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
183 }
184
185 label |= new_label;
186 }
187
set_vecaco::ssa_info188 void set_vec(Instruction* vec)
189 {
190 add_label(label_vec);
191 instr = vec;
192 }
193
is_vecaco::ssa_info194 bool is_vec() { return label & label_vec; }
195
set_constantaco::ssa_info196 void set_constant(chip_class chip, uint64_t constant)
197 {
198 Operand op16 = Operand::c16(constant);
199 Operand op32 = Operand::get_const(chip, constant, 4);
200 add_label(label_literal);
201 val = constant;
202
203 /* check that no upper bits are lost in case of packed 16bit constants */
204 if (chip >= GFX8 && !op16.isLiteral() && op16.constantValue64() == constant)
205 add_label(label_constant_16bit);
206
207 if (!op32.isLiteral())
208 add_label(label_constant_32bit);
209
210 if (Operand::is_constant_representable(constant, 8))
211 add_label(label_constant_64bit);
212
213 if (label & label_constant_64bit) {
214 val = Operand::c64(constant).constantValue();
215 if (val != constant)
216 label &= ~(label_literal | label_constant_16bit | label_constant_32bit);
217 }
218 }
219
is_constantaco::ssa_info220 bool is_constant(unsigned bits)
221 {
222 switch (bits) {
223 case 8: return label & label_literal;
224 case 16: return label & label_constant_16bit;
225 case 32: return label & label_constant_32bit;
226 case 64: return label & label_constant_64bit;
227 }
228 return false;
229 }
230
is_literalaco::ssa_info231 bool is_literal(unsigned bits)
232 {
233 bool is_lit = label & label_literal;
234 switch (bits) {
235 case 8: return false;
236 case 16: return is_lit && ~(label & label_constant_16bit);
237 case 32: return is_lit && ~(label & label_constant_32bit);
238 case 64: return false;
239 }
240 return false;
241 }
242
is_constant_or_literalaco::ssa_info243 bool is_constant_or_literal(unsigned bits)
244 {
245 if (bits == 64)
246 return label & label_constant_64bit;
247 else
248 return label & label_literal;
249 }
250
set_absaco::ssa_info251 void set_abs(Temp abs_temp)
252 {
253 add_label(label_abs);
254 temp = abs_temp;
255 }
256
is_absaco::ssa_info257 bool is_abs() { return label & label_abs; }
258
set_negaco::ssa_info259 void set_neg(Temp neg_temp)
260 {
261 add_label(label_neg);
262 temp = neg_temp;
263 }
264
is_negaco::ssa_info265 bool is_neg() { return label & label_neg; }
266
set_neg_absaco::ssa_info267 void set_neg_abs(Temp neg_abs_temp)
268 {
269 add_label((Label)((uint32_t)label_abs | (uint32_t)label_neg));
270 temp = neg_abs_temp;
271 }
272
set_mulaco::ssa_info273 void set_mul(Instruction* mul)
274 {
275 add_label(label_mul);
276 instr = mul;
277 }
278
is_mulaco::ssa_info279 bool is_mul() { return label & label_mul; }
280
set_tempaco::ssa_info281 void set_temp(Temp tmp)
282 {
283 add_label(label_temp);
284 temp = tmp;
285 }
286
is_tempaco::ssa_info287 bool is_temp() { return label & label_temp; }
288
set_madaco::ssa_info289 void set_mad(Instruction* mad, uint32_t mad_info_idx)
290 {
291 add_label(label_mad);
292 mad->pass_flags = mad_info_idx;
293 instr = mad;
294 }
295
is_madaco::ssa_info296 bool is_mad() { return label & label_mad; }
297
set_omod2aco::ssa_info298 void set_omod2(Instruction* mul)
299 {
300 add_label(label_omod2);
301 instr = mul;
302 }
303
is_omod2aco::ssa_info304 bool is_omod2() { return label & label_omod2; }
305
set_omod4aco::ssa_info306 void set_omod4(Instruction* mul)
307 {
308 add_label(label_omod4);
309 instr = mul;
310 }
311
is_omod4aco::ssa_info312 bool is_omod4() { return label & label_omod4; }
313
set_omod5aco::ssa_info314 void set_omod5(Instruction* mul)
315 {
316 add_label(label_omod5);
317 instr = mul;
318 }
319
is_omod5aco::ssa_info320 bool is_omod5() { return label & label_omod5; }
321
set_clampaco::ssa_info322 void set_clamp(Instruction* med3)
323 {
324 add_label(label_clamp);
325 instr = med3;
326 }
327
is_clampaco::ssa_info328 bool is_clamp() { return label & label_clamp; }
329
set_f2f16aco::ssa_info330 void set_f2f16(Instruction* conv)
331 {
332 add_label(label_f2f16);
333 instr = conv;
334 }
335
is_f2f16aco::ssa_info336 bool is_f2f16() { return label & label_f2f16; }
337
set_undefinedaco::ssa_info338 void set_undefined() { add_label(label_undefined); }
339
is_undefinedaco::ssa_info340 bool is_undefined() { return label & label_undefined; }
341
set_vccaco::ssa_info342 void set_vcc(Temp vcc_val)
343 {
344 add_label(label_vcc);
345 temp = vcc_val;
346 }
347
is_vccaco::ssa_info348 bool is_vcc() { return label & label_vcc; }
349
set_b2faco::ssa_info350 void set_b2f(Temp b2f_val)
351 {
352 add_label(label_b2f);
353 temp = b2f_val;
354 }
355
is_b2faco::ssa_info356 bool is_b2f() { return label & label_b2f; }
357
set_add_subaco::ssa_info358 void set_add_sub(Instruction* add_sub_instr)
359 {
360 add_label(label_add_sub);
361 instr = add_sub_instr;
362 }
363
is_add_subaco::ssa_info364 bool is_add_sub() { return label & label_add_sub; }
365
set_bitwiseaco::ssa_info366 void set_bitwise(Instruction* bitwise_instr)
367 {
368 add_label(label_bitwise);
369 instr = bitwise_instr;
370 }
371
is_bitwiseaco::ssa_info372 bool is_bitwise() { return label & label_bitwise; }
373
set_uniform_bitwiseaco::ssa_info374 void set_uniform_bitwise() { add_label(label_uniform_bitwise); }
375
is_uniform_bitwiseaco::ssa_info376 bool is_uniform_bitwise() { return label & label_uniform_bitwise; }
377
set_minmaxaco::ssa_info378 void set_minmax(Instruction* minmax_instr)
379 {
380 add_label(label_minmax);
381 instr = minmax_instr;
382 }
383
is_minmaxaco::ssa_info384 bool is_minmax() { return label & label_minmax; }
385
set_vopcaco::ssa_info386 void set_vopc(Instruction* vopc_instr)
387 {
388 add_label(label_vopc);
389 instr = vopc_instr;
390 }
391
is_vopcaco::ssa_info392 bool is_vopc() { return label & label_vopc; }
393
set_scc_neededaco::ssa_info394 void set_scc_needed() { add_label(label_scc_needed); }
395
is_scc_neededaco::ssa_info396 bool is_scc_needed() { return label & label_scc_needed; }
397
set_scc_invertaco::ssa_info398 void set_scc_invert(Temp scc_inv)
399 {
400 add_label(label_scc_invert);
401 temp = scc_inv;
402 }
403
is_scc_invertaco::ssa_info404 bool is_scc_invert() { return label & label_scc_invert; }
405
set_uniform_boolaco::ssa_info406 void set_uniform_bool(Temp uniform_bool)
407 {
408 add_label(label_uniform_bool);
409 temp = uniform_bool;
410 }
411
is_uniform_boolaco::ssa_info412 bool is_uniform_bool() { return label & label_uniform_bool; }
413
set_vcc_hintaco::ssa_info414 void set_vcc_hint() { add_label(label_vcc_hint); }
415
is_vcc_hintaco::ssa_info416 bool is_vcc_hint() { return label & label_vcc_hint; }
417
set_b2iaco::ssa_info418 void set_b2i(Temp b2i_val)
419 {
420 add_label(label_b2i);
421 temp = b2i_val;
422 }
423
is_b2iaco::ssa_info424 bool is_b2i() { return label & label_b2i; }
425
set_usedefaco::ssa_info426 void set_usedef(Instruction* label_instr)
427 {
428 add_label(label_usedef);
429 instr = label_instr;
430 }
431
is_usedefaco::ssa_info432 bool is_usedef() { return label & label_usedef; }
433
set_vop3paco::ssa_info434 void set_vop3p(Instruction* vop3p_instr)
435 {
436 add_label(label_vop3p);
437 instr = vop3p_instr;
438 }
439
is_vop3paco::ssa_info440 bool is_vop3p() { return label & label_vop3p; }
441
set_fcanonicalizeaco::ssa_info442 void set_fcanonicalize(Temp tmp)
443 {
444 add_label(label_fcanonicalize);
445 temp = tmp;
446 }
447
is_fcanonicalizeaco::ssa_info448 bool is_fcanonicalize() { return label & label_fcanonicalize; }
449
set_canonicalizedaco::ssa_info450 void set_canonicalized() { add_label(label_canonicalized); }
451
is_canonicalizedaco::ssa_info452 bool is_canonicalized() { return label & label_canonicalized; }
453
set_f2f32aco::ssa_info454 void set_f2f32(Instruction* cvt)
455 {
456 add_label(label_f2f32);
457 instr = cvt;
458 }
459
is_f2f32aco::ssa_info460 bool is_f2f32() { return label & label_f2f32; }
461
set_extractaco::ssa_info462 void set_extract(Instruction* extract)
463 {
464 add_label(label_extract);
465 instr = extract;
466 }
467
is_extractaco::ssa_info468 bool is_extract() { return label & label_extract; }
469
set_insertaco::ssa_info470 void set_insert(Instruction* insert)
471 {
472 add_label(label_insert);
473 instr = insert;
474 }
475
is_insertaco::ssa_info476 bool is_insert() { return label & label_insert; }
477
set_dpp16aco::ssa_info478 void set_dpp16(Instruction* mov)
479 {
480 add_label(label_dpp16);
481 instr = mov;
482 }
483
set_dpp8aco::ssa_info484 void set_dpp8(Instruction* mov)
485 {
486 add_label(label_dpp8);
487 instr = mov;
488 }
489
is_dppaco::ssa_info490 bool is_dpp() { return label & (label_dpp16 | label_dpp8); }
is_dpp16aco::ssa_info491 bool is_dpp16() { return label & label_dpp16; }
is_dpp8aco::ssa_info492 bool is_dpp8() { return label & label_dpp8; }
493 };
494
/* Shared state threaded through all optimizer passes. */
struct opt_ctx {
   Program* program;
   float_mode fp_mode;
   std::vector<aco_ptr<Instruction>> instructions;
   ssa_info* info; /* per-SSA-definition info, indexed by temp id */
   /* NOTE(review): presumably {instruction index, literal temp} of the last
    * seen literal — confirm at the use sites */
   std::pair<uint32_t, Temp> last_literal;
   std::vector<mad_info> mad_infos; /* indexed via Instruction::pass_flags (see ssa_info::set_mad) */
   std::vector<uint16_t> uses; /* use count per SSA id */
};
504
505 bool
can_use_VOP3(opt_ctx & ctx,const aco_ptr<Instruction> & instr)506 can_use_VOP3(opt_ctx& ctx, const aco_ptr<Instruction>& instr)
507 {
508 if (instr->isVOP3())
509 return true;
510
511 if (instr->isVOP3P())
512 return false;
513
514 if (instr->operands.size() && instr->operands[0].isLiteral() && ctx.program->chip_class < GFX10)
515 return false;
516
517 if (instr->isDPP() || instr->isSDWA())
518 return false;
519
520 return instr->opcode != aco_opcode::v_madmk_f32 && instr->opcode != aco_opcode::v_madak_f32 &&
521 instr->opcode != aco_opcode::v_madmk_f16 && instr->opcode != aco_opcode::v_madak_f16 &&
522 instr->opcode != aco_opcode::v_fmamk_f32 && instr->opcode != aco_opcode::v_fmaak_f32 &&
523 instr->opcode != aco_opcode::v_fmamk_f16 && instr->opcode != aco_opcode::v_fmaak_f16 &&
524 instr->opcode != aco_opcode::v_readlane_b32 &&
525 instr->opcode != aco_opcode::v_writelane_b32 &&
526 instr->opcode != aco_opcode::v_readfirstlane_b32;
527 }
528
/* Try to replace operand 'index' of the pseudo instruction with 'temp'.
 * Returns false when the replacement would be invalid (e.g. propagating a
 * VGPR into an SGPR-producing instruction or mismatched operand sizes). */
bool
pseudo_propagate_temp(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp temp, unsigned index)
{
   if (instr->definitions.empty())
      return false;

   /* whether the instruction only produces VGPR results (p_as_uniform reads
    * VGPRs by definition) */
   const bool vgpr =
      instr->opcode == aco_opcode::p_as_uniform ||
      std::all_of(instr->definitions.begin(), instr->definitions.end(),
                  [](const Definition& def) { return def.regClass().type() == RegType::vgpr; });

   /* don't propagate VGPRs into SGPR instructions */
   if (temp.type() == RegType::vgpr && !vgpr)
      return false;

   /* pre-GFX9, subdword definitions can't take SGPR sources */
   bool can_accept_sgpr =
      ctx.program->chip_class >= GFX9 ||
      std::none_of(instr->definitions.begin(), instr->definitions.end(),
                   [](const Definition& def) { return def.regClass().is_subdword(); });

   switch (instr->opcode) {
   case aco_opcode::p_phi:
   case aco_opcode::p_linear_phi:
   case aco_opcode::p_parallelcopy:
   case aco_opcode::p_create_vector:
      if (temp.bytes() != instr->operands[index].bytes())
         return false;
      break;
   case aco_opcode::p_extract_vector:
   case aco_opcode::p_extract:
      if (temp.type() == RegType::sgpr && !can_accept_sgpr)
         return false;
      break;
   case aco_opcode::p_split_vector: {
      if (temp.type() == RegType::sgpr && !can_accept_sgpr)
         return false;
      /* don't increase the vector size */
      if (temp.bytes() > instr->operands[index].bytes())
         return false;
      /* We can decrease the vector size as smaller temporaries are only
       * propagated by p_as_uniform instructions.
       * If this propagation leads to invalid IR or hits the assertion below,
       * it means that some undefined bytes within a dword are being accessed
       * and a bug in instruction_selection is likely. */
      int decrease = instr->operands[index].bytes() - temp.bytes();
      /* drop trailing definitions to match the smaller source */
      while (decrease > 0) {
         decrease -= instr->definitions.back().bytes();
         instr->definitions.pop_back();
      }
      assert(decrease == 0);
      break;
   }
   case aco_opcode::p_as_uniform:
      /* same regclass on both sides makes this a plain copy */
      if (temp.regClass() == instr->definitions[0].regClass())
         instr->opcode = aco_opcode::p_parallelcopy;
      break;
   default: return false;
   }

   instr->operands[index].setTemp(temp);
   return true;
}
591
592 /* This expects the DPP modifier to be removed. */
593 bool
can_apply_sgprs(opt_ctx & ctx,aco_ptr<Instruction> & instr)594 can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
595 {
596 if (instr->isSDWA() && ctx.program->chip_class < GFX9)
597 return false;
598 return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
599 instr->opcode != aco_opcode::v_readlane_b32 &&
600 instr->opcode != aco_opcode::v_readlane_b32_e64 &&
601 instr->opcode != aco_opcode::v_writelane_b32 &&
602 instr->opcode != aco_opcode::v_writelane_b32_e64 &&
603 instr->opcode != aco_opcode::v_permlane16_b32 &&
604 instr->opcode != aco_opcode::v_permlanex16_b32;
605 }
606
/* Re-create 'instr' in the VOP3 encoding of the same opcode, preserving
 * operands, definitions and pass_flags, and keep ssa_info entries pointing
 * at the new instruction. */
void
to_VOP3(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   if (instr->isVOP3())
      return;

   aco_ptr<Instruction> tmp = std::move(instr);
   Format format = asVOP3(tmp->format);
   instr.reset(create_instruction<VOP3_instruction>(tmp->opcode, format, tmp->operands.size(),
                                                    tmp->definitions.size()));
   std::copy(tmp->operands.cbegin(), tmp->operands.cend(), instr->operands.begin());
   for (unsigned i = 0; i < instr->definitions.size(); i++) {
      instr->definitions[i] = tmp->definitions[i];
      if (instr->definitions[i].isTemp()) {
         /* fix up dangling ssa_info pointers to the old instruction */
         ssa_info& info = ctx.info[instr->definitions[i].tempId()];
         if (info.label & instr_usedef_labels && info.instr == tmp.get())
            info.instr = instr.get();
      }
   }
   /* we don't need to update any instr_mod_labels because they either haven't
    * been applied yet or this instruction isn't dead and so they've been ignored */

   instr->pass_flags = tmp->pass_flags;
}
631
632 bool
is_operand_vgpr(Operand op)633 is_operand_vgpr(Operand op)
634 {
635 return op.isTemp() && op.getTemp().type() == RegType::vgpr;
636 }
637
638 void
to_SDWA(opt_ctx & ctx,aco_ptr<Instruction> & instr)639 to_SDWA(opt_ctx& ctx, aco_ptr<Instruction>& instr)
640 {
641 aco_ptr<Instruction> tmp = convert_to_SDWA(ctx.program->chip_class, instr);
642 if (!tmp)
643 return;
644
645 for (unsigned i = 0; i < instr->definitions.size(); i++) {
646 ssa_info& info = ctx.info[instr->definitions[i].tempId()];
647 if (info.label & instr_labels && info.instr == tmp.get())
648 info.instr = instr.get();
649 }
650 }
651
652 /* only covers special cases */
653 bool
alu_can_accept_constant(aco_opcode opcode,unsigned operand)654 alu_can_accept_constant(aco_opcode opcode, unsigned operand)
655 {
656 switch (opcode) {
657 case aco_opcode::v_interp_p2_f32:
658 case aco_opcode::v_mac_f32:
659 case aco_opcode::v_writelane_b32:
660 case aco_opcode::v_writelane_b32_e64:
661 case aco_opcode::v_cndmask_b32: return operand != 2;
662 case aco_opcode::s_addk_i32:
663 case aco_opcode::s_mulk_i32:
664 case aco_opcode::p_wqm:
665 case aco_opcode::p_extract_vector:
666 case aco_opcode::p_split_vector:
667 case aco_opcode::v_readlane_b32:
668 case aco_opcode::v_readlane_b32_e64:
669 case aco_opcode::v_readfirstlane_b32:
670 case aco_opcode::p_extract:
671 case aco_opcode::p_insert: return operand != 0;
672 default: return true;
673 }
674 }
675
676 bool
valu_can_accept_vgpr(aco_ptr<Instruction> & instr,unsigned operand)677 valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
678 {
679 if (instr->opcode == aco_opcode::v_readlane_b32 ||
680 instr->opcode == aco_opcode::v_readlane_b32_e64 ||
681 instr->opcode == aco_opcode::v_writelane_b32 ||
682 instr->opcode == aco_opcode::v_writelane_b32_e64)
683 return operand != 1;
684 if (instr->opcode == aco_opcode::v_permlane16_b32 ||
685 instr->opcode == aco_opcode::v_permlanex16_b32)
686 return operand == 0;
687 return true;
688 }
689
/* check constant bus and literal limitations */
bool
check_vop3_operands(opt_ctx& ctx, unsigned num_operands, Operand* operands)
{
   /* GFX10+ allows two constant-bus uses per instruction, older chips one */
   int limit = ctx.program->chip_class >= GFX10 ? 2 : 1;
   Operand literal32(s1); /* first 32-bit literal seen; undefined until then */
   Operand literal64(s2); /* first 64-bit literal seen; undefined until then */
   unsigned num_sgprs = 0;
   unsigned sgpr[] = {0, 0}; /* temp ids of up to two distinct SGPR reads */

   for (unsigned i = 0; i < num_operands; i++) {
      Operand op = operands[i];

      if (op.hasRegClass() && op.regClass().type() == RegType::sgpr) {
         /* two reads of the same SGPR count as 1 to the limit */
         if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
            if (num_sgprs < 2)
               sgpr[num_sgprs++] = op.tempId();
            limit--;
            if (limit < 0)
               return false;
         }
      } else if (op.isLiteral()) {
         /* literals in these positions are only possible on GFX10+ */
         if (ctx.program->chip_class < GFX10)
            return false;

         /* all literals must share the same value */
         if (!literal32.isUndefined() && literal32.constantValue() != op.constantValue())
            return false;
         if (!literal64.isUndefined() && literal64.constantValue() != op.constantValue())
            return false;

         /* Any number of 32-bit literals counts as only 1 to the limit. Same
          * (but separately) for 64-bit literals. */
         if (op.size() == 1 && literal32.isUndefined()) {
            limit--;
            literal32 = op;
         } else if (op.size() == 2 && literal64.isUndefined()) {
            limit--;
            literal64 = op;
         }

         if (limit < 0)
            return false;
      }
   }

   return true;
}
738
/* Try to decompose operand 'op_index' of 'instr' into a temporary base plus a
 * constant offset by looking through add instructions. On success, writes
 * *base and *offset and returns true. With 'prevent_overflow', only additions
 * marked no-unsigned-wrap (NUW) are looked through. */
bool
parse_base_offset(opt_ctx& ctx, Instruction* instr, unsigned op_index, Temp* base, uint32_t* offset,
                  bool prevent_overflow)
{
   Operand op = instr->operands[op_index];

   if (!op.isTemp())
      return false;
   Temp tmp = op.getTemp();
   if (!ctx.info[tmp.id()].is_add_sub())
      return false;

   Instruction* add_instr = ctx.info[tmp.id()].instr;

   /* only plain additions can be decomposed */
   switch (add_instr->opcode) {
   case aco_opcode::v_add_u32:
   case aco_opcode::v_add_co_u32:
   case aco_opcode::v_add_co_u32_e64:
   case aco_opcode::s_add_i32:
   case aco_opcode::s_add_u32: break;
   default: return false;
   }
   if (prevent_overflow && !add_instr->definitions[0].isNUW())
      return false;

   if (add_instr->usesModifiers())
      return false;

   /* look for a constant operand; the other operand becomes the base */
   for (unsigned i = 0; i < 2; i++) {
      if (add_instr->operands[i].isConstant()) {
         *offset = add_instr->operands[i].constantValue();
      } else if (add_instr->operands[i].isTemp() &&
                 ctx.info[add_instr->operands[i].tempId()].is_constant_or_literal(32)) {
         *offset = ctx.info[add_instr->operands[i].tempId()].val;
      } else {
         continue;
      }
      if (!add_instr->operands[!i].isTemp())
         continue;

      /* recurse in case the base itself is an addition with a constant */
      uint32_t offset2 = 0;
      if (parse_base_offset(ctx, add_instr, !i, base, &offset2, prevent_overflow)) {
         *offset += offset2;
      } else {
         *base = add_instr->operands[!i].getTemp();
      }
      return true;
   }

   return false;
}
790
/* If the SMEM offset operand is the result of 's_and_b32 x, -4', use x
 * directly, since the address calculation seems to align the offset down
 * anyway (see the note below). */
void
skip_smem_offset_align(opt_ctx& ctx, SMEM_instruction* smem)
{
   /* whether a SOE (scalar offset) operand is present; it is the last operand */
   bool soe = smem->operands.size() >= (!smem->definitions.empty() ? 3 : 4);
   if (soe && !smem->operands[1].isConstant())
      return;
   /* We don't need to check the constant offset because the address seems to be calculated with
    * (offset&-4 + const_offset&-4), not (offset+const_offset)&-4.
    */

   Operand& op = smem->operands[soe ? smem->operands.size() - 1 : 1];
   if (!op.isTemp() || !ctx.info[op.tempId()].is_bitwise())
      return;

   Instruction* bitwise_instr = ctx.info[op.tempId()].instr;
   if (bitwise_instr->opcode != aco_opcode::s_and_b32)
      return;

   /* the -4 mask may be either operand of the s_and_b32 */
   if (bitwise_instr->operands[0].constantEquals(-4) &&
       bitwise_instr->operands[1].isOfType(op.regClass().type()))
      op.setTemp(bitwise_instr->operands[1].getTemp());
   else if (bitwise_instr->operands[1].constantEquals(-4) &&
            bitwise_instr->operands[0].isOfType(op.regClass().type()))
      op.setTemp(bitwise_instr->operands[0].getTemp());
}
816
/* Fold address arithmetic into an SMEM instruction: skip needless &-4
 * alignment, inline constant offsets and turn add-based addresses into
 * base + SOE offset where the hardware supports it. */
void
smem_combine(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* skip &-4 before offset additions: load((a + 16) & -4, 0) */
   if (!instr->operands.empty())
      skip_smem_offset_align(ctx, &instr->smem());

   /* propagate constants and combine additions */
   if (!instr->operands.empty() && instr->operands[1].isTemp()) {
      SMEM_instruction& smem = instr->smem();
      ssa_info info = ctx.info[instr->operands[1].tempId()];

      Temp base;
      uint32_t offset;
      /* wide resources (or flagged loads) must not wrap the offset */
      bool prevent_overflow = smem.operands[0].size() > 2 || smem.prevent_overflow;
      if (info.is_constant_or_literal(32) &&
          /* per-generation limits on the immediate offset range */
          ((ctx.program->chip_class == GFX6 && info.val <= 0x3FF) ||
           (ctx.program->chip_class == GFX7 && info.val <= 0xFFFFFFFF) ||
           (ctx.program->chip_class >= GFX8 && info.val <= 0xFFFFF))) {
         instr->operands[1] = Operand::c32(info.val);
      } else if (parse_base_offset(ctx, instr.get(), 1, &base, &offset, prevent_overflow) &&
                 base.regClass() == s1 && offset <= 0xFFFFF && ctx.program->chip_class >= GFX9 &&
                 offset % 4u == 0) {
         /* whether a SOE operand already exists (it is the last operand) */
         bool soe = smem.operands.size() >= (!smem.definitions.empty() ? 3 : 4);
         if (soe) {
            /* only replace a SOE operand that is known to be zero */
            if (ctx.info[smem.operands.back().tempId()].is_constant_or_literal(32) &&
                ctx.info[smem.operands.back().tempId()].val == 0) {
               smem.operands[1] = Operand::c32(offset);
               smem.operands.back() = Operand(base);
            }
         } else {
            /* rebuild the instruction with an extra SOE operand holding the base */
            SMEM_instruction* new_instr = create_instruction<SMEM_instruction>(
               smem.opcode, Format::SMEM, smem.operands.size() + 1, smem.definitions.size());
            new_instr->operands[0] = smem.operands[0];
            new_instr->operands[1] = Operand::c32(offset);
            if (smem.definitions.empty())
               new_instr->operands[2] = smem.operands[2];
            new_instr->operands.back() = Operand(base);
            if (!smem.definitions.empty())
               new_instr->definitions[0] = smem.definitions[0];
            new_instr->sync = smem.sync;
            new_instr->glc = smem.glc;
            new_instr->dlc = smem.dlc;
            new_instr->nv = smem.nv;
            new_instr->disable_wqm = smem.disable_wqm;
            instr.reset(new_instr);
         }
      }
   }

   /* skip &-4 after offset additions: load(a & -4, 16) */
   if (!instr->operands.empty())
      skip_smem_offset_align(ctx, &instr->smem());
}
871
872 unsigned
get_operand_size(aco_ptr<Instruction> & instr,unsigned index)873 get_operand_size(aco_ptr<Instruction>& instr, unsigned index)
874 {
875 if (instr->isPseudo())
876 return instr->operands[index].bytes() * 8u;
877 else if (instr->opcode == aco_opcode::v_mad_u64_u32 ||
878 instr->opcode == aco_opcode::v_mad_i64_i32)
879 return index == 2 ? 64 : 32;
880 else if (instr->opcode == aco_opcode::v_fma_mix_f32 ||
881 instr->opcode == aco_opcode::v_fma_mixlo_f16)
882 return instr->vop3p().opsel_hi & (1u << index) ? 16 : 32;
883 else if (instr->isVALU() || instr->isSALU())
884 return instr_info.operand_size[(int)instr->opcode];
885 else
886 return 0;
887 }
888
889 Operand
get_constant_op(opt_ctx & ctx,ssa_info info,uint32_t bits)890 get_constant_op(opt_ctx& ctx, ssa_info info, uint32_t bits)
891 {
892 if (bits == 64)
893 return Operand::c32_or_c64(info.val, true);
894 return Operand::get_const(ctx.program->chip_class, info.val, bits / 8u);
895 }
896
/* Replace VOP3P operand 'i' (whose value is known via 'info') with an inline
 * constant where possible, adjusting opsel/neg bits to keep the selected
 * halves identical. */
void
propagate_constants_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned i)
{
   if (!info.is_constant_or_literal(32))
      return;

   assert(instr->operands[i].isTemp());
   unsigned bits = get_operand_size(instr, i);
   if (info.is_constant(bits)) {
      instr->operands[i] = get_constant_op(ctx, info, bits);
      return;
   }

   /* try to fold inline constants */
   VOP3P_instruction* vop3p = &instr->vop3p();
   Operand const_lo = Operand::c16(info.val);       /* low 16-bit half */
   Operand const_hi = Operand::c16(info.val >> 16); /* high 16-bit half */
   bool opsel_lo = (vop3p->opsel_lo >> i) & 1;
   bool opsel_hi = (vop3p->opsel_hi >> i) & 1;

   /* bail out if a selected half would have to be encoded as a literal */
   if (const_hi.isLiteral() && (opsel_lo || opsel_hi))
      return;
   if (const_lo.isLiteral() && !(opsel_lo && opsel_hi))
      return;

   if (opsel_lo == opsel_hi) {
      /* use the single 16bit value */
      instr->operands[i] = opsel_lo ? const_hi : const_lo;

      /* opsel must point to lo for both halves */
      vop3p->opsel_lo &= ~(1 << i);
      vop3p->opsel_hi &= ~(1 << i);
   } else if (const_lo == const_hi) {
      /* both constants are the same */
      instr->operands[i] = const_lo;

      /* opsel must point to lo for both halves */
      vop3p->opsel_lo &= ~(1 << i);
      vop3p->opsel_hi &= ~(1 << i);
   } else if (const_lo == Operand::c16(0)) {
      /* don't inline FP constants into integer instructions */
      // TODO: check if negative integers are zero- or sign-extended
      if (bits == 32 && const_hi.constantValue() > 64u)
         return;

      instr->operands[i] = const_hi;

      /* redirect opsel selection */
      vop3p->opsel_lo ^= (1 << i);
      vop3p->opsel_hi ^= (1 << i);
   } else if (bits == 16 && const_lo.constantValue() == (const_hi.constantValue() ^ (1 << 15))) {
      /* const_lo == -const_hi: encode the magnitude and fold the sign into neg bits */
      if (!instr_info.can_use_input_modifiers[(int)instr->opcode])
         return;

      instr->operands[i] = Operand::c16(const_lo.constantValue() & 0x7FFF);
      bool neg_lo = const_lo.constantValue() & (1 << 15);
      vop3p->neg_lo[i] ^= opsel_lo ^ neg_lo;
      vop3p->neg_hi[i] ^= opsel_hi ^ neg_lo;

      /* opsel must point to lo for both operands */
      vop3p->opsel_lo &= ~(1 << i);
      vop3p->opsel_hi &= ~(1 << i);
   }
}
962
963 bool
fixed_to_exec(Operand op)964 fixed_to_exec(Operand op)
965 {
966 return op.isFixed() && op.physReg() == exec;
967 }
968
969 SubdwordSel
parse_extract(Instruction * instr)970 parse_extract(Instruction* instr)
971 {
972 if (instr->opcode == aco_opcode::p_extract) {
973 unsigned size = instr->operands[2].constantValue() / 8;
974 unsigned offset = instr->operands[1].constantValue() * size;
975 bool sext = instr->operands[3].constantEquals(1);
976 return SubdwordSel(size, offset, sext);
977 } else if (instr->opcode == aco_opcode::p_insert && instr->operands[1].constantEquals(0)) {
978 return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
979 } else if (instr->opcode == aco_opcode::p_extract_vector) {
980 unsigned size = instr->definitions[0].bytes();
981 unsigned offset = instr->operands[1].constantValue() * size;
982 if (size <= 2)
983 return SubdwordSel(size, offset, false);
984 } else if (instr->opcode == aco_opcode::p_split_vector) {
985 assert(instr->operands[0].bytes() == 4 && instr->definitions[1].bytes() == 2);
986 return SubdwordSel(2, 2, false);
987 }
988
989 return SubdwordSel();
990 }
991
992 SubdwordSel
parse_insert(Instruction * instr)993 parse_insert(Instruction* instr)
994 {
995 if (instr->opcode == aco_opcode::p_extract && instr->operands[3].constantEquals(0) &&
996 instr->operands[1].constantEquals(0)) {
997 return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
998 } else if (instr->opcode == aco_opcode::p_insert) {
999 unsigned size = instr->operands[2].constantValue() / 8;
1000 unsigned offset = instr->operands[1].constantValue() * size;
1001 return SubdwordSel(size, offset, false);
1002 } else {
1003 return SubdwordSel();
1004 }
1005 }
1006
1007 bool
can_apply_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr,unsigned idx,ssa_info & info)1008 can_apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1009 {
1010 if (idx >= 2)
1011 return false;
1012
1013 Temp tmp = info.instr->operands[0].getTemp();
1014 SubdwordSel sel = parse_extract(info.instr);
1015
1016 if (!sel) {
1017 return false;
1018 } else if (sel.size() == 4) {
1019 return true;
1020 } else if (instr->opcode == aco_opcode::v_cvt_f32_u32 && sel.size() == 1 && !sel.sign_extend()) {
1021 return true;
1022 } else if (can_use_SDWA(ctx.program->chip_class, instr, true) &&
1023 (tmp.type() == RegType::vgpr || ctx.program->chip_class >= GFX9)) {
1024 if (instr->isSDWA() && instr->sdwa().sel[idx] != SubdwordSel::dword)
1025 return false;
1026 return true;
1027 } else if (instr->isVOP3() && sel.size() == 2 &&
1028 can_use_opsel(ctx.program->chip_class, instr->opcode, idx) &&
1029 !(instr->vop3().opsel & (1 << idx))) {
1030 return true;
1031 } else if (instr->opcode == aco_opcode::p_extract) {
1032 SubdwordSel instrSel = parse_extract(instr.get());
1033
1034 /* the outer offset must be within extracted range */
1035 if (instrSel.offset() >= sel.size())
1036 return false;
1037
1038 /* don't remove the sign-extension when increasing the size further */
1039 if (instrSel.size() > sel.size() && !instrSel.sign_extend() && sel.sign_extend())
1040 return false;
1041
1042 return true;
1043 }
1044
1045 return false;
1046 }
1047
/* Combine an p_extract (or p_insert, in some cases) instruction with instr.
 * instr(p_extract(...)) -> instr()
 *
 * The selection performed by info.instr is re-expressed through instr's own
 * encoding: a dedicated opcode, an SDWA select, a VOP3 opsel bit, or merged
 * p_extract operands. can_apply_extract() must have returned true for the
 * same arguments.
 */
void
apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
{
   Temp tmp = info.instr->operands[0].getTemp();
   SubdwordSel sel = parse_extract(info.instr);
   assert(sel);

   /* Drop the 16/24bit value-range hints: they described the extract's
    * result, not the unextracted source the operand is being rewired to
    * (NOTE(review): the operand rewrite to tmp appears to happen in the
    * caller — confirm). */
   instr->operands[idx].set16bit(false);
   instr->operands[idx].set24bit(false);

   /* tmp is now consumed as a subdword selection, so stop treating it as an
    * insert source. */
   ctx.info[tmp.id()].label &= ~label_insert;

   if (sel.size() == 4) {
      /* full dword selection */
   } else if (instr->opcode == aco_opcode::v_cvt_f32_u32 && sel.size() == 1 && !sel.sign_extend()) {
      /* Unsigned byte select folds into the dedicated byte-convert opcodes. */
      switch (sel.offset()) {
      case 0: instr->opcode = aco_opcode::v_cvt_f32_ubyte0; break;
      case 1: instr->opcode = aco_opcode::v_cvt_f32_ubyte1; break;
      case 2: instr->opcode = aco_opcode::v_cvt_f32_ubyte2; break;
      case 3: instr->opcode = aco_opcode::v_cvt_f32_ubyte3; break;
      }
   } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
              sel.offset() == 0 &&
              ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
               (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
      /* The undesireable upper bits are already shifted out. */
      return;
   } else if (can_use_SDWA(ctx.program->chip_class, instr, true) &&
              (tmp.type() == RegType::vgpr || ctx.program->chip_class >= GFX9)) {
      /* Express the selection through an SDWA operand select. */
      to_SDWA(ctx, instr);
      static_cast<SDWA_instruction*>(instr.get())->sel[idx] = sel;
   } else if (instr->isVOP3()) {
      /* A nonzero offset means the high half: set the opsel bit. */
      if (sel.offset())
         instr->vop3().opsel |= 1 << idx;
   } else if (instr->opcode == aco_opcode::p_extract) {
      SubdwordSel instrSel = parse_extract(instr.get());

      /* Compose the two selections: the narrower size wins, offsets add, and
       * the result is sign-extended only if the outer extract sign-extends
       * and the inner one doesn't zero-extend a narrower value first. */
      unsigned size = std::min(sel.size(), instrSel.size());
      unsigned offset = sel.offset() + instrSel.offset();
      unsigned sign_extend =
         instrSel.sign_extend() && (sel.sign_extend() || instrSel.size() <= sel.size());

      /* p_extract operands: index (in units of size), bit size, signedness —
       * matching what parse_extract() reads back. */
      instr->operands[1] = Operand::c32(offset / size);
      instr->operands[2] = Operand::c32(size * 8u);
      instr->operands[3] = Operand::c32(sign_extend);
      return;
   }

   /* Output modifier, label_vopc and label_f2f32 seem to be the only one worth keeping at the
    * moment
    */
   for (Definition& def : instr->definitions)
      ctx.info[def.tempId()].label &= (label_vopc | label_f2f32 | instr_mod_labels);
}
1105
1106 void
check_sdwa_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr)1107 check_sdwa_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1108 {
1109 for (unsigned i = 0; i < instr->operands.size(); i++) {
1110 Operand op = instr->operands[i];
1111 if (!op.isTemp())
1112 continue;
1113 ssa_info& info = ctx.info[op.tempId()];
1114 if (info.is_extract() && (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
1115 op.getTemp().type() == RegType::sgpr)) {
1116 if (!can_apply_extract(ctx, instr, i, info))
1117 info.label &= ~label_extract;
1118 }
1119 }
1120 }
1121
1122 bool
does_fp_op_flush_denorms(opt_ctx & ctx,aco_opcode op)1123 does_fp_op_flush_denorms(opt_ctx& ctx, aco_opcode op)
1124 {
1125 if (ctx.program->chip_class <= GFX8) {
1126 switch (op) {
1127 case aco_opcode::v_min_f32:
1128 case aco_opcode::v_max_f32:
1129 case aco_opcode::v_med3_f32:
1130 case aco_opcode::v_min3_f32:
1131 case aco_opcode::v_max3_f32:
1132 case aco_opcode::v_min_f16:
1133 case aco_opcode::v_max_f16: return false;
1134 default: break;
1135 }
1136 }
1137 return op != aco_opcode::v_cndmask_b32;
1138 }
1139
1140 bool
can_eliminate_fcanonicalize(opt_ctx & ctx,aco_ptr<Instruction> & instr,Temp tmp)1141 can_eliminate_fcanonicalize(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp tmp)
1142 {
1143 float_mode* fp = &ctx.fp_mode;
1144 if (ctx.info[tmp.id()].is_canonicalized() ||
1145 (tmp.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1146 return true;
1147
1148 aco_opcode op = instr->opcode;
1149 return instr_info.can_use_input_modifiers[(int)op] && does_fp_op_flush_denorms(ctx, op);
1150 }
1151
1152 bool
is_copy_label(opt_ctx & ctx,aco_ptr<Instruction> & instr,ssa_info & info)1153 is_copy_label(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info)
1154 {
1155 return info.is_temp() ||
1156 (info.is_fcanonicalize() && can_eliminate_fcanonicalize(ctx, instr, info.temp));
1157 }
1158
1159 bool
is_op_canonicalized(opt_ctx & ctx,Operand op)1160 is_op_canonicalized(opt_ctx& ctx, Operand op)
1161 {
1162 float_mode* fp = &ctx.fp_mode;
1163 if ((op.isTemp() && ctx.info[op.tempId()].is_canonicalized()) ||
1164 (op.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1165 return true;
1166
1167 if (op.isConstant() || (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(32))) {
1168 uint32_t val = op.isTemp() ? ctx.info[op.tempId()].val : op.constantValue();
1169 if (op.bytes() == 2)
1170 return (val & 0x7fff) == 0 || (val & 0x7fff) > 0x3ff;
1171 else if (op.bytes() == 4)
1172 return (val & 0x7fffffff) == 0 || (val & 0x7fffffff) > 0x7fffff;
1173 }
1174 return false;
1175 }
1176
1177 void
label_instruction(opt_ctx & ctx,aco_ptr<Instruction> & instr)1178 label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1179 {
1180 if (instr->isSALU() || instr->isVALU() || instr->isPseudo()) {
1181 ASSERTED bool all_const = false;
1182 for (Operand& op : instr->operands)
1183 all_const =
1184 all_const && (!op.isTemp() || ctx.info[op.tempId()].is_constant_or_literal(32));
1185 perfwarn(ctx.program, all_const, "All instruction operands are constant", instr.get());
1186
1187 ASSERTED bool is_copy = instr->opcode == aco_opcode::s_mov_b32 ||
1188 instr->opcode == aco_opcode::s_mov_b64 ||
1189 instr->opcode == aco_opcode::v_mov_b32;
1190 perfwarn(ctx.program, is_copy && !instr->usesModifiers(), "Use p_parallelcopy instead",
1191 instr.get());
1192 }
1193
1194 if (instr->isSMEM())
1195 smem_combine(ctx, instr);
1196
1197 for (unsigned i = 0; i < instr->operands.size(); i++) {
1198 if (!instr->operands[i].isTemp())
1199 continue;
1200
1201 ssa_info info = ctx.info[instr->operands[i].tempId()];
1202 /* propagate undef */
1203 if (info.is_undefined() && is_phi(instr))
1204 instr->operands[i] = Operand(instr->operands[i].regClass());
1205 /* propagate reg->reg of same type */
1206 while (info.is_temp() && info.temp.regClass() == instr->operands[i].getTemp().regClass()) {
1207 instr->operands[i].setTemp(ctx.info[instr->operands[i].tempId()].temp);
1208 info = ctx.info[info.temp.id()];
1209 }
1210
1211 /* PSEUDO: propagate temporaries */
1212 if (instr->isPseudo()) {
1213 while (info.is_temp()) {
1214 pseudo_propagate_temp(ctx, instr, info.temp, i);
1215 info = ctx.info[info.temp.id()];
1216 }
1217 }
1218
1219 /* SALU / PSEUDO: propagate inline constants */
1220 if (instr->isSALU() || instr->isPseudo()) {
1221 unsigned bits = get_operand_size(instr, i);
1222 if ((info.is_constant(bits) || (info.is_literal(bits) && instr->isPseudo())) &&
1223 !instr->operands[i].isFixed() && alu_can_accept_constant(instr->opcode, i)) {
1224 instr->operands[i] = get_constant_op(ctx, info, bits);
1225 continue;
1226 }
1227 }
1228
1229 /* VALU: propagate neg, abs & inline constants */
1230 else if (instr->isVALU()) {
1231 if (is_copy_label(ctx, instr, info) && info.temp.type() == RegType::vgpr &&
1232 valu_can_accept_vgpr(instr, i)) {
1233 instr->operands[i].setTemp(info.temp);
1234 info = ctx.info[info.temp.id()];
1235 }
1236 /* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */
1237 if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(ctx, instr) &&
1238 instr->operands.size() == 1) {
1239 instr->format = withoutDPP(instr->format);
1240 instr->operands[i].setTemp(info.temp);
1241 info = ctx.info[info.temp.id()];
1242 }
1243
1244 /* for instructions other than v_cndmask_b32, the size of the instruction should match the
1245 * operand size */
1246 unsigned can_use_mod =
1247 instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4;
1248 can_use_mod = can_use_mod && instr_info.can_use_input_modifiers[(int)instr->opcode];
1249
1250 if (instr->isSDWA())
1251 can_use_mod = can_use_mod && instr->sdwa().sel[i].size() == 4;
1252 else
1253 can_use_mod = can_use_mod && (instr->isDPP16() || can_use_VOP3(ctx, instr));
1254
1255 unsigned bits = get_operand_size(instr, i);
1256 bool mod_bitsize_compat = instr->operands[i].bytes() * 8 == bits;
1257
1258 if (info.is_neg() && instr->opcode == aco_opcode::v_add_f32 && mod_bitsize_compat) {
1259 instr->opcode = i ? aco_opcode::v_sub_f32 : aco_opcode::v_subrev_f32;
1260 instr->operands[i].setTemp(info.temp);
1261 } else if (info.is_neg() && instr->opcode == aco_opcode::v_add_f16 && mod_bitsize_compat) {
1262 instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16;
1263 instr->operands[i].setTemp(info.temp);
1264 } else if (info.is_neg() && can_use_mod && mod_bitsize_compat &&
1265 can_eliminate_fcanonicalize(ctx, instr, info.temp)) {
1266 if (!instr->isDPP() && !instr->isSDWA())
1267 to_VOP3(ctx, instr);
1268 instr->operands[i].setTemp(info.temp);
1269 if (instr->isDPP16() && !instr->dpp16().abs[i])
1270 instr->dpp16().neg[i] = true;
1271 else if (instr->isSDWA() && !instr->sdwa().abs[i])
1272 instr->sdwa().neg[i] = true;
1273 else if (instr->isVOP3() && !instr->vop3().abs[i])
1274 instr->vop3().neg[i] = true;
1275 }
1276 if (info.is_abs() && can_use_mod && mod_bitsize_compat &&
1277 can_eliminate_fcanonicalize(ctx, instr, info.temp)) {
1278 if (!instr->isDPP() && !instr->isSDWA())
1279 to_VOP3(ctx, instr);
1280 instr->operands[i] = Operand(info.temp);
1281 if (instr->isDPP16())
1282 instr->dpp16().abs[i] = true;
1283 else if (instr->isSDWA())
1284 instr->sdwa().abs[i] = true;
1285 else
1286 instr->vop3().abs[i] = true;
1287 continue;
1288 }
1289
1290 if (instr->isVOP3P()) {
1291 propagate_constants_vop3p(ctx, instr, info, i);
1292 continue;
1293 }
1294
1295 if (info.is_constant(bits) && alu_can_accept_constant(instr->opcode, i) &&
1296 (!instr->isSDWA() || ctx.program->chip_class >= GFX9)) {
1297 Operand op = get_constant_op(ctx, info, bits);
1298 perfwarn(ctx.program, instr->opcode == aco_opcode::v_cndmask_b32 && i == 2,
1299 "v_cndmask_b32 with a constant selector", instr.get());
1300 if (i == 0 || instr->isSDWA() || instr->opcode == aco_opcode::v_readlane_b32 ||
1301 instr->opcode == aco_opcode::v_writelane_b32) {
1302 instr->format = withoutDPP(instr->format);
1303 instr->operands[i] = op;
1304 continue;
1305 } else if (!instr->isVOP3() && can_swap_operands(instr, &instr->opcode)) {
1306 instr->operands[i] = instr->operands[0];
1307 instr->operands[0] = op;
1308 continue;
1309 } else if (can_use_VOP3(ctx, instr)) {
1310 to_VOP3(ctx, instr);
1311 instr->operands[i] = op;
1312 continue;
1313 }
1314 }
1315 }
1316
1317 /* MUBUF: propagate constants and combine additions */
1318 else if (instr->isMUBUF()) {
1319 MUBUF_instruction& mubuf = instr->mubuf();
1320 Temp base;
1321 uint32_t offset;
1322 while (info.is_temp())
1323 info = ctx.info[info.temp.id()];
1324
1325 /* According to AMDGPUDAGToDAGISel::SelectMUBUFScratchOffen(), vaddr
1326 * overflow for scratch accesses works only on GFX9+ and saddr overflow
1327 * never works. Since swizzling is the only thing that separates
1328 * scratch accesses and other accesses and swizzling changing how
1329 * addressing works significantly, this probably applies to swizzled
1330 * MUBUF accesses. */
1331 bool vaddr_prevent_overflow = mubuf.swizzled && ctx.program->chip_class < GFX9;
1332 bool saddr_prevent_overflow = mubuf.swizzled;
1333
1334 if (mubuf.offen && i == 1 && info.is_constant_or_literal(32) &&
1335 mubuf.offset + info.val < 4096) {
1336 assert(!mubuf.idxen);
1337 instr->operands[1] = Operand(v1);
1338 mubuf.offset += info.val;
1339 mubuf.offen = false;
1340 continue;
1341 } else if (i == 2 && info.is_constant_or_literal(32) && mubuf.offset + info.val < 4096) {
1342 instr->operands[2] = Operand::c32(0);
1343 mubuf.offset += info.val;
1344 continue;
1345 } else if (mubuf.offen && i == 1 &&
1346 parse_base_offset(ctx, instr.get(), i, &base, &offset,
1347 vaddr_prevent_overflow) &&
1348 base.regClass() == v1 && mubuf.offset + offset < 4096) {
1349 assert(!mubuf.idxen);
1350 instr->operands[1].setTemp(base);
1351 mubuf.offset += offset;
1352 continue;
1353 } else if (i == 2 &&
1354 parse_base_offset(ctx, instr.get(), i, &base, &offset,
1355 saddr_prevent_overflow) &&
1356 base.regClass() == s1 && mubuf.offset + offset < 4096) {
1357 instr->operands[i].setTemp(base);
1358 mubuf.offset += offset;
1359 continue;
1360 }
1361 }
1362
1363 /* DS: combine additions */
1364 else if (instr->isDS()) {
1365
1366 DS_instruction& ds = instr->ds();
1367 Temp base;
1368 uint32_t offset;
1369 bool has_usable_ds_offset = ctx.program->chip_class >= GFX7;
1370 if (has_usable_ds_offset && i == 0 &&
1371 parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1372 base.regClass() == instr->operands[i].regClass() &&
1373 instr->opcode != aco_opcode::ds_swizzle_b32) {
1374 if (instr->opcode == aco_opcode::ds_write2_b32 ||
1375 instr->opcode == aco_opcode::ds_read2_b32 ||
1376 instr->opcode == aco_opcode::ds_write2_b64 ||
1377 instr->opcode == aco_opcode::ds_read2_b64) {
1378 unsigned mask = (instr->opcode == aco_opcode::ds_write2_b64 ||
1379 instr->opcode == aco_opcode::ds_read2_b64)
1380 ? 0x7
1381 : 0x3;
1382 unsigned shifts = (instr->opcode == aco_opcode::ds_write2_b64 ||
1383 instr->opcode == aco_opcode::ds_read2_b64)
1384 ? 3
1385 : 2;
1386
1387 if ((offset & mask) == 0 && ds.offset0 + (offset >> shifts) <= 255 &&
1388 ds.offset1 + (offset >> shifts) <= 255) {
1389 instr->operands[i].setTemp(base);
1390 ds.offset0 += offset >> shifts;
1391 ds.offset1 += offset >> shifts;
1392 }
1393 } else {
1394 if (ds.offset0 + offset <= 65535) {
1395 instr->operands[i].setTemp(base);
1396 ds.offset0 += offset;
1397 }
1398 }
1399 }
1400 }
1401
1402 else if (instr->isBranch()) {
1403 if (ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
1404 /* Flip the branch instruction to get rid of the scc_invert instruction */
1405 instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz
1406 : aco_opcode::p_cbranch_z;
1407 instr->operands[0].setTemp(ctx.info[instr->operands[0].tempId()].temp);
1408 }
1409 }
1410 }
1411
1412 /* if this instruction doesn't define anything, return */
1413 if (instr->definitions.empty()) {
1414 check_sdwa_extract(ctx, instr);
1415 return;
1416 }
1417
1418 if (instr->isVALU() || instr->isVINTRP()) {
1419 if (instr_info.can_use_output_modifiers[(int)instr->opcode] || instr->isVINTRP() ||
1420 instr->opcode == aco_opcode::v_cndmask_b32) {
1421 bool canonicalized = true;
1422 if (!does_fp_op_flush_denorms(ctx, instr->opcode)) {
1423 unsigned ops = instr->opcode == aco_opcode::v_cndmask_b32 ? 2 : instr->operands.size();
1424 for (unsigned i = 0; canonicalized && (i < ops); i++)
1425 canonicalized = is_op_canonicalized(ctx, instr->operands[i]);
1426 }
1427 if (canonicalized)
1428 ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1429 }
1430
1431 if (instr->isVOPC()) {
1432 ctx.info[instr->definitions[0].tempId()].set_vopc(instr.get());
1433 check_sdwa_extract(ctx, instr);
1434 return;
1435 }
1436 if (instr->isVOP3P()) {
1437 ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
1438 return;
1439 }
1440 }
1441
1442 switch (instr->opcode) {
1443 case aco_opcode::p_create_vector: {
1444 bool copy_prop = instr->operands.size() == 1 && instr->operands[0].isTemp() &&
1445 instr->operands[0].regClass() == instr->definitions[0].regClass();
1446 if (copy_prop) {
1447 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1448 break;
1449 }
1450
1451 /* expand vector operands */
1452 std::vector<Operand> ops;
1453 unsigned offset = 0;
1454 for (const Operand& op : instr->operands) {
1455 /* ensure that any expanded operands are properly aligned */
1456 bool aligned = offset % 4 == 0 || op.bytes() < 4;
1457 offset += op.bytes();
1458 if (aligned && op.isTemp() && ctx.info[op.tempId()].is_vec()) {
1459 Instruction* vec = ctx.info[op.tempId()].instr;
1460 for (const Operand& vec_op : vec->operands)
1461 ops.emplace_back(vec_op);
1462 } else {
1463 ops.emplace_back(op);
1464 }
1465 }
1466
1467 /* combine expanded operands to new vector */
1468 if (ops.size() != instr->operands.size()) {
1469 assert(ops.size() > instr->operands.size());
1470 Definition def = instr->definitions[0];
1471 instr.reset(create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector,
1472 Format::PSEUDO, ops.size(), 1));
1473 for (unsigned i = 0; i < ops.size(); i++) {
1474 if (ops[i].isTemp() && ctx.info[ops[i].tempId()].is_temp() &&
1475 ops[i].regClass() == ctx.info[ops[i].tempId()].temp.regClass())
1476 ops[i].setTemp(ctx.info[ops[i].tempId()].temp);
1477 instr->operands[i] = ops[i];
1478 }
1479 instr->definitions[0] = def;
1480 } else {
1481 for (unsigned i = 0; i < ops.size(); i++) {
1482 assert(instr->operands[i] == ops[i]);
1483 }
1484 }
1485 ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1486 break;
1487 }
1488 case aco_opcode::p_split_vector: {
1489 ssa_info& info = ctx.info[instr->operands[0].tempId()];
1490
1491 if (info.is_constant_or_literal(32)) {
1492 uint64_t val = info.val;
1493 for (Definition def : instr->definitions) {
1494 uint32_t mask = u_bit_consecutive(0, def.bytes() * 8u);
1495 ctx.info[def.tempId()].set_constant(ctx.program->chip_class, val & mask);
1496 val >>= def.bytes() * 8u;
1497 }
1498 break;
1499 } else if (!info.is_vec()) {
1500 if (instr->operands[0].bytes() == 4 && instr->definitions.size() == 2) {
1501 /* D16 subdword split */
1502 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1503 if (instr->definitions[1].bytes() == 2)
1504 ctx.info[instr->definitions[1].tempId()].set_extract(instr.get());
1505 }
1506 break;
1507 }
1508
1509 Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1510 unsigned split_offset = 0;
1511 unsigned vec_offset = 0;
1512 unsigned vec_index = 0;
1513 for (unsigned i = 0; i < instr->definitions.size();
1514 split_offset += instr->definitions[i++].bytes()) {
1515 while (vec_offset < split_offset && vec_index < vec->operands.size())
1516 vec_offset += vec->operands[vec_index++].bytes();
1517
1518 if (vec_offset != split_offset ||
1519 vec->operands[vec_index].bytes() != instr->definitions[i].bytes())
1520 continue;
1521
1522 Operand vec_op = vec->operands[vec_index];
1523 if (vec_op.isConstant()) {
1524 ctx.info[instr->definitions[i].tempId()].set_constant(ctx.program->chip_class,
1525 vec_op.constantValue64());
1526 } else if (vec_op.isUndefined()) {
1527 ctx.info[instr->definitions[i].tempId()].set_undefined();
1528 } else {
1529 assert(vec_op.isTemp());
1530 ctx.info[instr->definitions[i].tempId()].set_temp(vec_op.getTemp());
1531 }
1532 }
1533 break;
1534 }
1535 case aco_opcode::p_extract_vector: { /* mov */
1536 ssa_info& info = ctx.info[instr->operands[0].tempId()];
1537 const unsigned index = instr->operands[1].constantValue();
1538 const unsigned dst_offset = index * instr->definitions[0].bytes();
1539
1540 if (info.is_vec()) {
1541 /* check if we index directly into a vector element */
1542 Instruction* vec = info.instr;
1543 unsigned offset = 0;
1544
1545 for (const Operand& op : vec->operands) {
1546 if (offset < dst_offset) {
1547 offset += op.bytes();
1548 continue;
1549 } else if (offset != dst_offset || op.bytes() != instr->definitions[0].bytes()) {
1550 break;
1551 }
1552 instr->operands[0] = op;
1553 break;
1554 }
1555 } else if (info.is_constant_or_literal(32)) {
1556 /* propagate constants */
1557 uint32_t mask = u_bit_consecutive(0, instr->definitions[0].bytes() * 8u);
1558 uint32_t val = (info.val >> (dst_offset * 8u)) & mask;
1559 instr->operands[0] =
1560 Operand::get_const(ctx.program->chip_class, val, instr->definitions[0].bytes());
1561 ;
1562 }
1563
1564 if (instr->operands[0].bytes() != instr->definitions[0].bytes()) {
1565 if (instr->operands[0].size() != 1)
1566 break;
1567
1568 if (index == 0)
1569 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1570 else
1571 ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1572 break;
1573 }
1574
1575 /* convert this extract into a copy instruction */
1576 instr->opcode = aco_opcode::p_parallelcopy;
1577 instr->operands.pop_back();
1578 FALLTHROUGH;
1579 }
1580 case aco_opcode::p_parallelcopy: /* propagate */
1581 if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_vec() &&
1582 instr->operands[0].regClass() != instr->definitions[0].regClass()) {
1583 /* We might not be able to copy-propagate if it's a SGPR->VGPR copy, so
1584 * duplicate the vector instead.
1585 */
1586 Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1587 aco_ptr<Instruction> old_copy = std::move(instr);
1588
1589 instr.reset(create_instruction<Pseudo_instruction>(
1590 aco_opcode::p_create_vector, Format::PSEUDO, vec->operands.size(), 1));
1591 instr->definitions[0] = old_copy->definitions[0];
1592 std::copy(vec->operands.begin(), vec->operands.end(), instr->operands.begin());
1593 for (unsigned i = 0; i < vec->operands.size(); i++) {
1594 Operand& op = instr->operands[i];
1595 if (op.isTemp() && ctx.info[op.tempId()].is_temp() &&
1596 ctx.info[op.tempId()].temp.type() == instr->definitions[0].regClass().type())
1597 op.setTemp(ctx.info[op.tempId()].temp);
1598 }
1599 ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1600 break;
1601 }
1602 FALLTHROUGH;
1603 case aco_opcode::p_as_uniform:
1604 if (instr->definitions[0].isFixed()) {
1605 /* don't copy-propagate copies into fixed registers */
1606 } else if (instr->usesModifiers()) {
1607 // TODO
1608 } else if (instr->operands[0].isConstant()) {
1609 ctx.info[instr->definitions[0].tempId()].set_constant(
1610 ctx.program->chip_class, instr->operands[0].constantValue64());
1611 } else if (instr->operands[0].isTemp()) {
1612 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1613 if (ctx.info[instr->operands[0].tempId()].is_canonicalized())
1614 ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1615 } else {
1616 assert(instr->operands[0].isFixed());
1617 }
1618 break;
1619 case aco_opcode::v_mov_b32:
1620 if (instr->isDPP16()) {
1621 /* anything else doesn't make sense in SSA */
1622 assert(instr->dpp16().row_mask == 0xf && instr->dpp16().bank_mask == 0xf);
1623 ctx.info[instr->definitions[0].tempId()].set_dpp16(instr.get());
1624 } else if (instr->isDPP8()) {
1625 ctx.info[instr->definitions[0].tempId()].set_dpp8(instr.get());
1626 }
1627 break;
1628 case aco_opcode::p_is_helper:
1629 if (!ctx.program->needs_wqm)
1630 ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->chip_class, 0u);
1631 break;
1632 case aco_opcode::v_mul_f64: ctx.info[instr->definitions[0].tempId()].set_mul(instr.get()); break;
1633 case aco_opcode::v_mul_f16:
1634 case aco_opcode::v_mul_f32:
1635 case aco_opcode::v_mul_legacy_f32: { /* omod */
1636 ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
1637
1638 /* TODO: try to move the negate/abs modifier to the consumer instead */
1639 bool uses_mods = instr->usesModifiers();
1640 bool fp16 = instr->opcode == aco_opcode::v_mul_f16;
1641
1642 for (unsigned i = 0; i < 2; i++) {
1643 if (instr->operands[!i].isConstant() && instr->operands[i].isTemp()) {
1644 if (!instr->isDPP() && !instr->isSDWA() &&
1645 (instr->operands[!i].constantEquals(fp16 ? 0x3c00 : 0x3f800000) || /* 1.0 */
1646 instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u))) { /* -1.0 */
1647 bool neg1 = instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u);
1648
1649 VOP3_instruction* vop3 = instr->isVOP3() ? &instr->vop3() : NULL;
1650 if (vop3 && (vop3->abs[!i] || vop3->neg[!i] || vop3->clamp || vop3->omod))
1651 continue;
1652
1653 bool abs = vop3 && vop3->abs[i];
1654 bool neg = neg1 ^ (vop3 && vop3->neg[i]);
1655
1656 Temp other = instr->operands[i].getTemp();
1657 if (abs && neg && other.type() == RegType::vgpr)
1658 ctx.info[instr->definitions[0].tempId()].set_neg_abs(other);
1659 else if (abs && !neg && other.type() == RegType::vgpr)
1660 ctx.info[instr->definitions[0].tempId()].set_abs(other);
1661 else if (!abs && neg && other.type() == RegType::vgpr)
1662 ctx.info[instr->definitions[0].tempId()].set_neg(other);
1663 else if (!abs && !neg)
1664 ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other);
1665 } else if (uses_mods) {
1666 continue;
1667 } else if (instr->operands[!i].constantValue() ==
1668 (fp16 ? 0x4000 : 0x40000000)) { /* 2.0 */
1669 ctx.info[instr->operands[i].tempId()].set_omod2(instr.get());
1670 } else if (instr->operands[!i].constantValue() ==
1671 (fp16 ? 0x4400 : 0x40800000)) { /* 4.0 */
1672 ctx.info[instr->operands[i].tempId()].set_omod4(instr.get());
1673 } else if (instr->operands[!i].constantValue() ==
1674 (fp16 ? 0x3800 : 0x3f000000)) { /* 0.5 */
1675 ctx.info[instr->operands[i].tempId()].set_omod5(instr.get());
1676 } else if (instr->operands[!i].constantValue() == 0u &&
1677 (!(fp16 ? ctx.fp_mode.preserve_signed_zero_inf_nan16_64
1678 : ctx.fp_mode.preserve_signed_zero_inf_nan32) ||
1679 instr->opcode == aco_opcode::v_mul_legacy_f32)) { /* 0.0 */
1680 ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->chip_class, 0u);
1681 } else {
1682 continue;
1683 }
1684 break;
1685 }
1686 }
1687 break;
1688 }
1689 case aco_opcode::v_mul_lo_u16:
1690 case aco_opcode::v_mul_lo_u16_e64:
1691 case aco_opcode::v_mul_u32_u24:
1692 ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1693 break;
1694 case aco_opcode::v_med3_f16:
1695 case aco_opcode::v_med3_f32: { /* clamp */
1696 VOP3_instruction& vop3 = instr->vop3();
1697 if (vop3.abs[0] || vop3.abs[1] || vop3.abs[2] || vop3.neg[0] || vop3.neg[1] || vop3.neg[2] ||
1698 vop3.omod != 0 || vop3.opsel != 0)
1699 break;
1700
1701 unsigned idx = 0;
1702 bool found_zero = false, found_one = false;
1703 bool is_fp16 = instr->opcode == aco_opcode::v_med3_f16;
1704 for (unsigned i = 0; i < 3; i++) {
1705 if (instr->operands[i].constantEquals(0))
1706 found_zero = true;
1707 else if (instr->operands[i].constantEquals(is_fp16 ? 0x3c00 : 0x3f800000)) /* 1.0 */
1708 found_one = true;
1709 else
1710 idx = i;
1711 }
1712 if (found_zero && found_one && instr->operands[idx].isTemp())
1713 ctx.info[instr->operands[idx].tempId()].set_clamp(instr.get());
1714 break;
1715 }
1716 case aco_opcode::v_cndmask_b32:
1717 if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(0xFFFFFFFF))
1718 ctx.info[instr->definitions[0].tempId()].set_vcc(instr->operands[2].getTemp());
1719 else if (instr->operands[0].constantEquals(0) &&
1720 instr->operands[1].constantEquals(0x3f800000u))
1721 ctx.info[instr->definitions[0].tempId()].set_b2f(instr->operands[2].getTemp());
1722 else if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(1))
1723 ctx.info[instr->definitions[0].tempId()].set_b2i(instr->operands[2].getTemp());
1724
1725 ctx.info[instr->operands[2].tempId()].set_vcc_hint();
1726 break;
1727 case aco_opcode::v_cmp_lg_u32:
1728 if (instr->format == Format::VOPC && /* don't optimize VOP3 / SDWA / DPP */
1729 instr->operands[0].constantEquals(0) && instr->operands[1].isTemp() &&
1730 ctx.info[instr->operands[1].tempId()].is_vcc())
1731 ctx.info[instr->definitions[0].tempId()].set_temp(
1732 ctx.info[instr->operands[1].tempId()].temp);
1733 break;
1734 case aco_opcode::p_linear_phi: {
1735 /* lower_bool_phis() can create phis like this */
1736 bool all_same_temp = instr->operands[0].isTemp();
1737 /* this check is needed when moving uniform loop counters out of a divergent loop */
1738 if (all_same_temp)
1739 all_same_temp = instr->definitions[0].regClass() == instr->operands[0].regClass();
1740 for (unsigned i = 1; all_same_temp && (i < instr->operands.size()); i++) {
1741 if (!instr->operands[i].isTemp() ||
1742 instr->operands[i].tempId() != instr->operands[0].tempId())
1743 all_same_temp = false;
1744 }
1745 if (all_same_temp) {
1746 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1747 } else {
1748 bool all_undef = instr->operands[0].isUndefined();
1749 for (unsigned i = 1; all_undef && (i < instr->operands.size()); i++) {
1750 if (!instr->operands[i].isUndefined())
1751 all_undef = false;
1752 }
1753 if (all_undef)
1754 ctx.info[instr->definitions[0].tempId()].set_undefined();
1755 }
1756 break;
1757 }
1758 case aco_opcode::v_add_u32:
1759 case aco_opcode::v_add_co_u32:
1760 case aco_opcode::v_add_co_u32_e64:
1761 case aco_opcode::s_add_i32:
1762 case aco_opcode::s_add_u32:
1763 case aco_opcode::v_subbrev_co_u32:
1764 ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
1765 break;
1766 case aco_opcode::s_not_b32:
1767 case aco_opcode::s_not_b64:
1768 if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
1769 ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1770 ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1771 ctx.info[instr->operands[0].tempId()].temp);
1772 } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
1773 ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1774 ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1775 ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1776 }
1777 ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1778 break;
1779 case aco_opcode::s_and_b32:
1780 case aco_opcode::s_and_b64:
1781 if (fixed_to_exec(instr->operands[1]) && instr->operands[0].isTemp()) {
1782 if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
1783 /* Try to get rid of the superfluous s_cselect + s_and_b64 that comes from turning a
1784 * uniform bool into divergent */
1785 ctx.info[instr->definitions[1].tempId()].set_temp(
1786 ctx.info[instr->operands[0].tempId()].temp);
1787 ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
1788 ctx.info[instr->operands[0].tempId()].temp);
1789 break;
1790 } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
1791 /* Try to get rid of the superfluous s_and_b64, since the uniform bitwise instruction
1792 * already produces the same SCC */
1793 ctx.info[instr->definitions[1].tempId()].set_temp(
1794 ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1795 ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
1796 ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1797 break;
1798 } else if ((ctx.program->stage.num_sw_stages() > 1 ||
1799 ctx.program->stage.hw == HWStage::NGG) &&
1800 instr->pass_flags == 1) {
1801 /* In case of merged shaders, pass_flags=1 means that all lanes are active (exec=-1), so
1802 * s_and is unnecessary. */
1803 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1804 break;
1805 } else if (ctx.info[instr->operands[0].tempId()].is_vopc()) {
1806 Instruction* vopc_instr = ctx.info[instr->operands[0].tempId()].instr;
1807 /* Remove superfluous s_and when the VOPC instruction uses the same exec and thus
1808 * already produces the same result */
1809 if (vopc_instr->pass_flags == instr->pass_flags) {
1810 assert(instr->pass_flags > 0);
1811 ctx.info[instr->definitions[0].tempId()].set_temp(
1812 vopc_instr->definitions[0].getTemp());
1813 break;
1814 }
1815 }
1816 }
1817 FALLTHROUGH;
1818 case aco_opcode::s_or_b32:
1819 case aco_opcode::s_or_b64:
1820 case aco_opcode::s_xor_b32:
1821 case aco_opcode::s_xor_b64:
1822 if (std::all_of(instr->operands.begin(), instr->operands.end(),
1823 [&ctx](const Operand& op)
1824 {
1825 return op.isTemp() && (ctx.info[op.tempId()].is_uniform_bool() ||
1826 ctx.info[op.tempId()].is_uniform_bitwise());
1827 })) {
1828 ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1829 }
1830 FALLTHROUGH;
1831 case aco_opcode::s_lshl_b32:
1832 case aco_opcode::v_or_b32:
1833 case aco_opcode::v_lshlrev_b32:
1834 case aco_opcode::v_bcnt_u32_b32:
1835 case aco_opcode::v_and_b32:
1836 case aco_opcode::v_xor_b32:
1837 ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1838 break;
1839 case aco_opcode::v_min_f32:
1840 case aco_opcode::v_min_f16:
1841 case aco_opcode::v_min_u32:
1842 case aco_opcode::v_min_i32:
1843 case aco_opcode::v_min_u16:
1844 case aco_opcode::v_min_i16:
1845 case aco_opcode::v_max_f32:
1846 case aco_opcode::v_max_f16:
1847 case aco_opcode::v_max_u32:
1848 case aco_opcode::v_max_i32:
1849 case aco_opcode::v_max_u16:
1850 case aco_opcode::v_max_i16:
1851 ctx.info[instr->definitions[0].tempId()].set_minmax(instr.get());
1852 break;
1853 case aco_opcode::s_cselect_b64:
1854 case aco_opcode::s_cselect_b32:
1855 if (instr->operands[0].constantEquals((unsigned)-1) && instr->operands[1].constantEquals(0)) {
1856 /* Found a cselect that operates on a uniform bool that comes from eg. s_cmp */
1857 ctx.info[instr->definitions[0].tempId()].set_uniform_bool(instr->operands[2].getTemp());
1858 }
1859 if (instr->operands[2].isTemp() && ctx.info[instr->operands[2].tempId()].is_scc_invert()) {
1860 /* Flip the operands to get rid of the scc_invert instruction */
1861 std::swap(instr->operands[0], instr->operands[1]);
1862 instr->operands[2].setTemp(ctx.info[instr->operands[2].tempId()].temp);
1863 }
1864 break;
1865 case aco_opcode::p_wqm:
1866 if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
1867 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1868 }
1869 break;
1870 case aco_opcode::s_mul_i32:
1871 /* Testing every uint32_t shows that 0x3f800000*n is never a denormal.
1872 * This pattern is created from a uniform nir_op_b2f. */
1873 if (instr->operands[0].constantEquals(0x3f800000u))
1874 ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1875 break;
1876 case aco_opcode::p_extract: {
1877 if (instr->definitions[0].bytes() == 4) {
1878 ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1879 if (instr->operands[0].regClass() == v1 && parse_insert(instr.get()))
1880 ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
1881 }
1882 break;
1883 }
1884 case aco_opcode::p_insert: {
1885 if (instr->operands[0].bytes() == 4) {
1886 if (instr->operands[0].regClass() == v1)
1887 ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
1888 if (parse_extract(instr.get()))
1889 ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1890 ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1891 }
1892 break;
1893 }
1894 case aco_opcode::ds_read_u8:
1895 case aco_opcode::ds_read_u8_d16:
1896 case aco_opcode::ds_read_u16:
1897 case aco_opcode::ds_read_u16_d16: {
1898 ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1899 break;
1900 }
1901 case aco_opcode::v_cvt_f16_f32: {
1902 if (instr->operands[0].isTemp())
1903 ctx.info[instr->operands[0].tempId()].set_f2f16(instr.get());
1904 break;
1905 }
1906 case aco_opcode::v_cvt_f32_f16: {
1907 if (instr->operands[0].isTemp())
1908 ctx.info[instr->definitions[0].tempId()].set_f2f32(instr.get());
1909 break;
1910 }
1911 default: break;
1912 }
1913
1914 /* Don't remove label_extract if we can't apply the extract to
1915 * neg/abs instructions because we'll likely combine it into another valu. */
1916 if (!(ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)))
1917 check_sdwa_extract(ctx, instr);
1918 }
1919
1920 unsigned
original_temp_id(opt_ctx & ctx,Temp tmp)1921 original_temp_id(opt_ctx& ctx, Temp tmp)
1922 {
1923 if (ctx.info[tmp.id()].is_temp())
1924 return ctx.info[tmp.id()].temp.id();
1925 else
1926 return tmp.id();
1927 }
1928
1929 void
decrease_uses(opt_ctx & ctx,Instruction * instr)1930 decrease_uses(opt_ctx& ctx, Instruction* instr)
1931 {
1932 if (!--ctx.uses[instr->definitions[0].tempId()]) {
1933 for (const Operand& op : instr->operands) {
1934 if (op.isTemp())
1935 ctx.uses[op.tempId()]--;
1936 }
1937 }
1938 }
1939
1940 Instruction*
follow_operand(opt_ctx & ctx,Operand op,bool ignore_uses=false)1941 follow_operand(opt_ctx& ctx, Operand op, bool ignore_uses = false)
1942 {
1943 if (!op.isTemp() || !(ctx.info[op.tempId()].label & instr_usedef_labels))
1944 return nullptr;
1945 if (!ignore_uses && ctx.uses[op.tempId()] > 1)
1946 return nullptr;
1947
1948 Instruction* instr = ctx.info[op.tempId()].instr;
1949
1950 if (instr->definitions.size() == 2) {
1951 assert(instr->definitions[0].isTemp() && instr->definitions[0].tempId() == op.tempId());
1952 if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
1953 return nullptr;
1954 }
1955
1956 return instr;
1957 }
1958
/* s_or_b64(neq(a, a), neq(b, b)) -> v_cmp_u_f32(a, b)
 * s_and_b64(eq(a, a), eq(b, b)) -> v_cmp_o_f32(a, b) */
bool
combine_ordering_test(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   if (instr->definitions[0].regClass() != ctx.program->lane_mask)
      return false;
   /* The SCC result of the SALU instruction must be dead. */
   if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
      return false;

   bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;

   bool neg[2] = {false, false};
   bool abs[2] = {false, false};
   uint8_t opsel = 0;
   Instruction* op_instr[2];
   Temp op[2];

   unsigned bitsize = 0;
   for (unsigned i = 0; i < 2; i++) {
      op_instr[i] = follow_operand(ctx, instr->operands[i], true);
      if (!op_instr[i])
         return false;

      aco_opcode expected_cmp = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
      unsigned op_bitsize = get_cmp_bitsize(op_instr[i]->opcode);

      /* Each source must be a self-comparison of the expected kind, and both
       * sources must use the same comparison bit size. */
      if (get_f32_cmp(op_instr[i]->opcode) != expected_cmp)
         return false;
      if (bitsize && op_bitsize != bitsize)
         return false;
      if (!op_instr[i]->operands[0].isTemp() || !op_instr[i]->operands[1].isTemp())
         return false;

      if (op_instr[i]->isVOP3()) {
         VOP3_instruction& vop3 = op_instr[i]->vop3();
         /* Both sources of the self-comparison need identical modifiers,
          * otherwise it isn't really comparing a value with itself. */
         if (vop3.neg[0] != vop3.neg[1] || vop3.abs[0] != vop3.abs[1] || vop3.opsel == 1 ||
             vop3.opsel == 2)
            return false;
         neg[i] = vop3.neg[0];
         abs[i] = vop3.abs[0];
         opsel |= (vop3.opsel & 1) << i;
      } else if (op_instr[i]->isSDWA()) {
         return false;
      }

      /* The two operands must resolve to the same original temp. */
      Temp op0 = op_instr[i]->operands[0].getTemp();
      Temp op1 = op_instr[i]->operands[1].getTemp();
      if (original_temp_id(ctx, op0) != original_temp_id(ctx, op1))
         return false;

      op[i] = op1;
      bitsize = op_bitsize;
   }

   /* Put an SGPR operand second so the SGPR-limit check below is simple. */
   if (op[1].type() == RegType::sgpr)
      std::swap(op[0], op[1]);
   unsigned num_sgprs = (op[0].type() == RegType::sgpr) + (op[1].type() == RegType::sgpr);
   if (num_sgprs > (ctx.program->chip_class >= GFX10 ? 2 : 1))
      return false;

   /* Add uses for the new operands before killing the old comparisons. */
   ctx.uses[op[0].id()]++;
   ctx.uses[op[1].id()]++;
   decrease_uses(ctx, op_instr[0]);
   decrease_uses(ctx, op_instr[1]);

   aco_opcode new_op = aco_opcode::num_opcodes;
   switch (bitsize) {
   case 16: new_op = is_or ? aco_opcode::v_cmp_u_f16 : aco_opcode::v_cmp_o_f16; break;
   case 32: new_op = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32; break;
   case 64: new_op = is_or ? aco_opcode::v_cmp_u_f64 : aco_opcode::v_cmp_o_f64; break;
   }
   Instruction* new_instr;
   if (neg[0] || neg[1] || abs[0] || abs[1] || opsel || num_sgprs > 1) {
      /* Modifiers or a second SGPR require the VOP3 encoding. */
      VOP3_instruction* vop3 =
         create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
      for (unsigned i = 0; i < 2; i++) {
         vop3->neg[i] = neg[i];
         vop3->abs[i] = abs[i];
      }
      vop3->opsel = opsel;
      new_instr = static_cast<Instruction*>(vop3);
   } else {
      new_instr = create_instruction<VOPC_instruction>(new_op, Format::VOPC, 2, 1);
      instr->definitions[0].setHint(vcc);
   }
   new_instr->operands[0] = Operand(op[0]);
   new_instr->operands[1] = Operand(op[1]);
   new_instr->definitions[0] = instr->definitions[0];

   /* Re-label the definition as the result of the new VOPC instruction. */
   ctx.info[instr->definitions[0].tempId()].label = 0;
   ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);

   instr.reset(new_instr);

   return true;
}
2056
/* s_or_b64(v_cmp_u_f32(a, b), cmp(a, b)) -> get_unordered(cmp)(a, b)
 * s_and_b64(v_cmp_o_f32(a, b), cmp(a, b)) -> get_ordered(cmp)(a, b) */
bool
combine_comparison_ordering(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   if (instr->definitions[0].regClass() != ctx.program->lane_mask)
      return false;
   /* The SCC result of the SALU instruction must be dead. */
   if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
      return false;

   bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
   aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32;

   Instruction* nan_test = follow_operand(ctx, instr->operands[0], true);
   Instruction* cmp = follow_operand(ctx, instr->operands[1], true);
   if (!nan_test || !cmp)
      return false;
   if (nan_test->isSDWA() || cmp->isSDWA())
      return false;

   /* The NaN test may be either operand; normalize so nan_test holds it. */
   if (get_f32_cmp(cmp->opcode) == expected_nan_test)
      std::swap(nan_test, cmp);
   else if (get_f32_cmp(nan_test->opcode) != expected_nan_test)
      return false;

   if (!is_cmp(cmp->opcode) || get_cmp_bitsize(cmp->opcode) != get_cmp_bitsize(nan_test->opcode))
      return false;

   if (!nan_test->operands[0].isTemp() || !nan_test->operands[1].isTemp())
      return false;
   if (!cmp->operands[0].isTemp() || !cmp->operands[1].isTemp())
      return false;

   /* Both instructions must compare the same values, modulo copies. */
   unsigned prop_cmp0 = original_temp_id(ctx, cmp->operands[0].getTemp());
   unsigned prop_cmp1 = original_temp_id(ctx, cmp->operands[1].getTemp());
   unsigned prop_nan0 = original_temp_id(ctx, nan_test->operands[0].getTemp());
   unsigned prop_nan1 = original_temp_id(ctx, nan_test->operands[1].getTemp());
   if (prop_cmp0 != prop_nan0 && prop_cmp0 != prop_nan1)
      return false;
   if (prop_cmp1 != prop_nan0 && prop_cmp1 != prop_nan1)
      return false;

   /* Add uses for the new operands before killing the old instructions. */
   ctx.uses[cmp->operands[0].tempId()]++;
   ctx.uses[cmp->operands[1].tempId()]++;
   decrease_uses(ctx, nan_test);
   decrease_uses(ctx, cmp);

   aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
   Instruction* new_instr;
   if (cmp->isVOP3()) {
      /* Preserve the VOP3 modifiers of the original comparison. */
      VOP3_instruction* new_vop3 =
         create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
      VOP3_instruction& cmp_vop3 = cmp->vop3();
      memcpy(new_vop3->abs, cmp_vop3.abs, sizeof(new_vop3->abs));
      memcpy(new_vop3->neg, cmp_vop3.neg, sizeof(new_vop3->neg));
      new_vop3->clamp = cmp_vop3.clamp;
      new_vop3->omod = cmp_vop3.omod;
      new_vop3->opsel = cmp_vop3.opsel;
      new_instr = new_vop3;
   } else {
      new_instr = create_instruction<VOPC_instruction>(new_op, Format::VOPC, 2, 1);
      instr->definitions[0].setHint(vcc);
   }
   new_instr->operands[0] = cmp->operands[0];
   new_instr->operands[1] = cmp->operands[1];
   new_instr->definitions[0] = instr->definitions[0];

   /* Re-label the definition as the result of the new VOPC instruction. */
   ctx.info[instr->definitions[0].tempId()].label = 0;
   ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);

   instr.reset(new_instr);

   return true;
}
2131
2132 bool
is_operand_constant(opt_ctx & ctx,Operand op,unsigned bit_size,uint64_t * value)2133 is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value)
2134 {
2135 if (op.isConstant()) {
2136 *value = op.constantValue64();
2137 return true;
2138 } else if (op.isTemp()) {
2139 unsigned id = original_temp_id(ctx, op.getTemp());
2140 if (!ctx.info[id].is_constant_or_literal(bit_size))
2141 return false;
2142 *value = get_constant_op(ctx, ctx.info[id], bit_size).constantValue64();
2143 return true;
2144 }
2145 return false;
2146 }
2147
/* Returns whether `value`, interpreted as an IEEE-754 float of the given bit
 * size (16, 32, otherwise 64), encodes a NaN: all exponent bits set and a
 * non-zero mantissa. */
bool
is_constant_nan(uint64_t value, unsigned bit_size)
{
   unsigned mant_bits = bit_size == 16 ? 10 : (bit_size == 32 ? 23 : 52);
   unsigned exp_bits = bit_size == 16 ? 5 : (bit_size == 32 ? 8 : 11);
   uint64_t exp_all_ones = (UINT64_C(1) << exp_bits) - 1;
   uint64_t exponent = (value >> mant_bits) & exp_all_ones;
   uint64_t mantissa = value & ((UINT64_C(1) << mant_bits) - 1);
   return exponent == exp_all_ones && mantissa != 0;
}
2158
2159 /* s_or_b64(v_cmp_neq_f32(a, a), cmp(a, #b)) and b is not NaN -> get_unordered(cmp)(a, b)
2160 * s_and_b64(v_cmp_eq_f32(a, a), cmp(a, #b)) and b is not NaN -> get_ordered(cmp)(a, b) */
2161 bool
combine_constant_comparison_ordering(opt_ctx & ctx,aco_ptr<Instruction> & instr)2162 combine_constant_comparison_ordering(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2163 {
2164 if (instr->definitions[0].regClass() != ctx.program->lane_mask)
2165 return false;
2166 if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2167 return false;
2168
2169 bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
2170
2171 Instruction* nan_test = follow_operand(ctx, instr->operands[0], true);
2172 Instruction* cmp = follow_operand(ctx, instr->operands[1], true);
2173
2174 if (!nan_test || !cmp || nan_test->isSDWA() || cmp->isSDWA())
2175 return false;
2176 if (nan_test->isSDWA() || cmp->isSDWA())
2177 return false;
2178
2179 aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
2180 if (get_f32_cmp(cmp->opcode) == expected_nan_test)
2181 std::swap(nan_test, cmp);
2182 else if (get_f32_cmp(nan_test->opcode) != expected_nan_test)
2183 return false;
2184
2185 unsigned bit_size = get_cmp_bitsize(cmp->opcode);
2186 if (!is_cmp(cmp->opcode) || get_cmp_bitsize(nan_test->opcode) != bit_size)
2187 return false;
2188
2189 if (!nan_test->operands[0].isTemp() || !nan_test->operands[1].isTemp())
2190 return false;
2191 if (!cmp->operands[0].isTemp() && !cmp->operands[1].isTemp())
2192 return false;
2193
2194 unsigned prop_nan0 = original_temp_id(ctx, nan_test->operands[0].getTemp());
2195 unsigned prop_nan1 = original_temp_id(ctx, nan_test->operands[1].getTemp());
2196 if (prop_nan0 != prop_nan1)
2197 return false;
2198
2199 if (nan_test->isVOP3()) {
2200 VOP3_instruction& vop3 = nan_test->vop3();
2201 if (vop3.neg[0] != vop3.neg[1] || vop3.abs[0] != vop3.abs[1] || vop3.opsel == 1 ||
2202 vop3.opsel == 2)
2203 return false;
2204 }
2205
2206 int constant_operand = -1;
2207 for (unsigned i = 0; i < 2; i++) {
2208 if (cmp->operands[i].isTemp() &&
2209 original_temp_id(ctx, cmp->operands[i].getTemp()) == prop_nan0) {
2210 constant_operand = !i;
2211 break;
2212 }
2213 }
2214 if (constant_operand == -1)
2215 return false;
2216
2217 uint64_t constant_value;
2218 if (!is_operand_constant(ctx, cmp->operands[constant_operand], bit_size, &constant_value))
2219 return false;
2220 if (is_constant_nan(constant_value, bit_size))
2221 return false;
2222
2223 if (cmp->operands[0].isTemp())
2224 ctx.uses[cmp->operands[0].tempId()]++;
2225 if (cmp->operands[1].isTemp())
2226 ctx.uses[cmp->operands[1].tempId()]++;
2227 decrease_uses(ctx, nan_test);
2228 decrease_uses(ctx, cmp);
2229
2230 aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
2231 Instruction* new_instr;
2232 if (cmp->isVOP3()) {
2233 VOP3_instruction* new_vop3 =
2234 create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
2235 VOP3_instruction& cmp_vop3 = cmp->vop3();
2236 memcpy(new_vop3->abs, cmp_vop3.abs, sizeof(new_vop3->abs));
2237 memcpy(new_vop3->neg, cmp_vop3.neg, sizeof(new_vop3->neg));
2238 new_vop3->clamp = cmp_vop3.clamp;
2239 new_vop3->omod = cmp_vop3.omod;
2240 new_vop3->opsel = cmp_vop3.opsel;
2241 new_instr = new_vop3;
2242 } else {
2243 new_instr = create_instruction<VOPC_instruction>(new_op, Format::VOPC, 2, 1);
2244 instr->definitions[0].setHint(vcc);
2245 }
2246 new_instr->operands[0] = cmp->operands[0];
2247 new_instr->operands[1] = cmp->operands[1];
2248 new_instr->definitions[0] = instr->definitions[0];
2249
2250 ctx.info[instr->definitions[0].tempId()].label = 0;
2251 ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);
2252
2253 instr.reset(new_instr);
2254
2255 return true;
2256 }
2257
/* s_andn2(exec, cmp(a, b)) -> get_inverse(cmp)(a, b) */
bool
combine_inverse_comparison(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* Only match when the first operand is the exec mask. */
   if (!instr->operands[0].isFixed() || instr->operands[0].physReg() != exec)
      return false;
   /* The SCC result must be dead. */
   if (ctx.uses[instr->definitions[1].tempId()])
      return false;

   Instruction* cmp = follow_operand(ctx, instr->operands[1]);
   if (!cmp)
      return false;

   /* Not every comparison has an inverse opcode. */
   aco_opcode new_opcode = get_inverse(cmp->opcode);
   if (new_opcode == aco_opcode::num_opcodes)
      return false;

   /* Add uses for the reused operands before killing the old comparison. */
   if (cmp->operands[0].isTemp())
      ctx.uses[cmp->operands[0].tempId()]++;
   if (cmp->operands[1].isTemp())
      ctx.uses[cmp->operands[1].tempId()]++;
   decrease_uses(ctx, cmp);

   /* This creates a new instruction instead of modifying the existing
    * comparison so that the comparison is done with the correct exec mask. */
   Instruction* new_instr;
   if (cmp->isVOP3()) {
      /* Copy all VOP3 modifiers to the inverted comparison. */
      VOP3_instruction* new_vop3 =
         create_instruction<VOP3_instruction>(new_opcode, asVOP3(Format::VOPC), 2, 1);
      VOP3_instruction& cmp_vop3 = cmp->vop3();
      memcpy(new_vop3->abs, cmp_vop3.abs, sizeof(new_vop3->abs));
      memcpy(new_vop3->neg, cmp_vop3.neg, sizeof(new_vop3->neg));
      new_vop3->clamp = cmp_vop3.clamp;
      new_vop3->omod = cmp_vop3.omod;
      new_vop3->opsel = cmp_vop3.opsel;
      new_instr = new_vop3;
   } else if (cmp->isSDWA()) {
      /* Copy all SDWA modifiers and selections. */
      SDWA_instruction* new_sdwa = create_instruction<SDWA_instruction>(
         new_opcode, (Format)((uint16_t)Format::SDWA | (uint16_t)Format::VOPC), 2, 1);
      SDWA_instruction& cmp_sdwa = cmp->sdwa();
      memcpy(new_sdwa->abs, cmp_sdwa.abs, sizeof(new_sdwa->abs));
      memcpy(new_sdwa->sel, cmp_sdwa.sel, sizeof(new_sdwa->sel));
      memcpy(new_sdwa->neg, cmp_sdwa.neg, sizeof(new_sdwa->neg));
      new_sdwa->dst_sel = cmp_sdwa.dst_sel;
      new_sdwa->clamp = cmp_sdwa.clamp;
      new_sdwa->omod = cmp_sdwa.omod;
      new_instr = new_sdwa;
   } else if (cmp->isDPP16()) {
      /* Copy the DPP16 control fields and modifiers. */
      DPP16_instruction* new_dpp = create_instruction<DPP16_instruction>(
         new_opcode, (Format)((uint16_t)Format::DPP16 | (uint16_t)Format::VOPC), 2, 1);
      DPP16_instruction& cmp_dpp = cmp->dpp16();
      memcpy(new_dpp->abs, cmp_dpp.abs, sizeof(new_dpp->abs));
      memcpy(new_dpp->neg, cmp_dpp.neg, sizeof(new_dpp->neg));
      new_dpp->dpp_ctrl = cmp_dpp.dpp_ctrl;
      new_dpp->row_mask = cmp_dpp.row_mask;
      new_dpp->bank_mask = cmp_dpp.bank_mask;
      new_dpp->bound_ctrl = cmp_dpp.bound_ctrl;
      new_instr = new_dpp;
   } else if (cmp->isDPP8()) {
      /* Copy the DPP8 lane selection. */
      DPP8_instruction* new_dpp = create_instruction<DPP8_instruction>(
         new_opcode, (Format)((uint16_t)Format::DPP8 | (uint16_t)Format::VOPC), 2, 1);
      DPP8_instruction& cmp_dpp = cmp->dpp8();
      memcpy(new_dpp->lane_sel, cmp_dpp.lane_sel, sizeof(new_dpp->lane_sel));
      new_instr = new_dpp;
   } else {
      new_instr = create_instruction<VOPC_instruction>(new_opcode, Format::VOPC, 2, 1);
      instr->definitions[0].setHint(vcc);
   }
   new_instr->operands[0] = cmp->operands[0];
   new_instr->operands[1] = cmp->operands[1];
   new_instr->definitions[0] = instr->definitions[0];

   /* Re-label the definition as the result of the new VOPC instruction. */
   ctx.info[instr->definitions[0].tempId()].label = 0;
   ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr);

   instr.reset(new_instr);

   return true;
}
2337
/* op1(op2(1, 2), 0) if swap = false
 * op1(0, op2(1, 2)) if swap = true
 *
 * Tries to match the pattern above and gather the three sources plus their
 * VOP3 modifiers so that the pair of instructions can be fused into a single
 * three-operand VALU op. `shuffle_str` describes, per digit, which position
 * each gathered source takes in the fused instruction. Output parameters
 * that are NULL mean the corresponding "inbetween" modifier (applied by op1
 * to op2's result) is not supported and causes the match to fail. Returns
 * true and fills operands/neg/abs/opsel/clamp/omod/precise on success. */
bool
match_op3_for_vop3(opt_ctx& ctx, aco_opcode op1, aco_opcode op2, Instruction* op1_instr, bool swap,
                   const char* shuffle_str, Operand operands[3], bool neg[3], bool abs[3],
                   uint8_t* opsel, bool* op1_clamp, uint8_t* op1_omod, bool* inbetween_neg,
                   bool* inbetween_abs, bool* inbetween_opsel, bool* precise)
{
   /* checks */
   if (op1_instr->opcode != op1)
      return false;

   Instruction* op2_instr = follow_operand(ctx, op1_instr->operands[swap]);
   if (!op2_instr || op2_instr->opcode != op2)
      return false;
   if (fixed_to_exec(op2_instr->operands[0]) || fixed_to_exec(op2_instr->operands[1]))
      return false;

   VOP3_instruction* op1_vop3 = op1_instr->isVOP3() ? &op1_instr->vop3() : NULL;
   VOP3_instruction* op2_vop3 = op2_instr->isVOP3() ? &op2_instr->vop3() : NULL;

   /* SDWA and DPP encodings can't be fused here. */
   if (op1_instr->isSDWA() || op2_instr->isSDWA())
      return false;
   if (op1_instr->isDPP() || op2_instr->isDPP())
      return false;

   /* don't support inbetween clamp/omod */
   if (op2_vop3 && (op2_vop3->clamp || op2_vop3->omod))
      return false;

   /* get operands and modifiers and check inbetween modifiers */
   *op1_clamp = op1_vop3 ? op1_vop3->clamp : false;
   *op1_omod = op1_vop3 ? op1_vop3->omod : 0u;

   /* For each inbetween modifier: report it through the out-pointer if the
    * caller supports it, otherwise fail when it is present. */
   if (inbetween_neg)
      *inbetween_neg = op1_vop3 ? op1_vop3->neg[swap] : false;
   else if (op1_vop3 && op1_vop3->neg[swap])
      return false;

   if (inbetween_abs)
      *inbetween_abs = op1_vop3 ? op1_vop3->abs[swap] : false;
   else if (op1_vop3 && op1_vop3->abs[swap])
      return false;

   if (inbetween_opsel)
      *inbetween_opsel = op1_vop3 ? op1_vop3->opsel & (1 << (unsigned)swap) : false;
   else if (op1_vop3 && op1_vop3->opsel & (1 << (unsigned)swap))
      return false;

   *precise = op1_instr->definitions[0].isPrecise() || op2_instr->definitions[0].isPrecise();

   /* Build the position mapping from the shuffle string's digits. */
   int shuffle[3];
   shuffle[shuffle_str[0] - '0'] = 0;
   shuffle[shuffle_str[1] - '0'] = 1;
   shuffle[shuffle_str[2] - '0'] = 2;

   /* op1's non-followed operand, with its modifiers. */
   operands[shuffle[0]] = op1_instr->operands[!swap];
   neg[shuffle[0]] = op1_vop3 ? op1_vop3->neg[!swap] : false;
   abs[shuffle[0]] = op1_vop3 ? op1_vop3->abs[!swap] : false;
   if (op1_vop3 && (op1_vop3->opsel & (1 << (unsigned)!swap)))
      *opsel |= 1 << shuffle[0];

   /* op2's two operands, with their modifiers. */
   for (unsigned i = 0; i < 2; i++) {
      operands[shuffle[i + 1]] = op2_instr->operands[i];
      neg[shuffle[i + 1]] = op2_vop3 ? op2_vop3->neg[i] : false;
      abs[shuffle[i + 1]] = op2_vop3 ? op2_vop3->abs[i] : false;
      if (op2_vop3 && op2_vop3->opsel & (1 << i))
         *opsel |= 1 << shuffle[i + 1];
   }

   /* check operands */
   if (!check_vop3_operands(ctx, 3, operands))
      return false;

   return true;
}
2414
2415 void
create_vop3_for_op3(opt_ctx & ctx,aco_opcode opcode,aco_ptr<Instruction> & instr,Operand operands[3],bool neg[3],bool abs[3],uint8_t opsel,bool clamp,unsigned omod)2416 create_vop3_for_op3(opt_ctx& ctx, aco_opcode opcode, aco_ptr<Instruction>& instr,
2417 Operand operands[3], bool neg[3], bool abs[3], uint8_t opsel, bool clamp,
2418 unsigned omod)
2419 {
2420 VOP3_instruction* new_instr = create_instruction<VOP3_instruction>(opcode, Format::VOP3, 3, 1);
2421 memcpy(new_instr->abs, abs, sizeof(bool[3]));
2422 memcpy(new_instr->neg, neg, sizeof(bool[3]));
2423 new_instr->clamp = clamp;
2424 new_instr->omod = omod;
2425 new_instr->opsel = opsel;
2426 new_instr->operands[0] = operands[0];
2427 new_instr->operands[1] = operands[1];
2428 new_instr->operands[2] = operands[2];
2429 new_instr->definitions[0] = instr->definitions[0];
2430 ctx.info[instr->definitions[0].tempId()].label = 0;
2431
2432 instr.reset(new_instr);
2433 }
2434
2435 bool
combine_three_valu_op(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode op2,aco_opcode new_op,const char * shuffle,uint8_t ops)2436 combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode op2, aco_opcode new_op,
2437 const char* shuffle, uint8_t ops)
2438 {
2439 for (unsigned swap = 0; swap < 2; swap++) {
2440 if (!((1 << swap) & ops))
2441 continue;
2442
2443 Operand operands[3];
2444 bool neg[3], abs[3], clamp, precise;
2445 uint8_t opsel = 0, omod = 0;
2446 if (match_op3_for_vop3(ctx, instr->opcode, op2, instr.get(), swap, shuffle, operands, neg,
2447 abs, &opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2448 ctx.uses[instr->operands[swap].tempId()]--;
2449 create_vop3_for_op3(ctx, new_op, instr, operands, neg, abs, opsel, clamp, omod);
2450 return true;
2451 }
2452 }
2453 return false;
2454 }
2455
/* creates v_lshl_add_u32, v_lshl_or_b32 or v_and_or_b32 */
bool
combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   bool is_or = instr->opcode == aco_opcode::v_or_b32;
   aco_opcode new_op_lshl = is_or ? aco_opcode::v_lshl_or_b32 : aco_opcode::v_lshl_add_u32;

   /* First try the generic three-operand fusions with and/lshl producers. */
   if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32,
                                      "120", 1 | 2))
      return true;
   if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32,
                                      "120", 1 | 2))
      return true;
   if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, new_op_lshl, "120", 1 | 2))
      return true;
   if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, new_op_lshl, "210", 1 | 2))
      return true;

   if (instr->isSDWA() || instr->isDPP())
      return false;

   /* v_or_b32(p_extract(a, 0, 8/16, 0), b) -> v_and_or_b32(a, 0xff/0xffff, b)
    * v_or_b32(p_insert(a, 0, 8/16), b) -> v_and_or_b32(a, 0xff/0xffff, b)
    * v_or_b32(p_insert(a, 24/16, 8/16), b) -> v_lshl_or_b32(a, 24/16, b)
    * v_add_u32(p_insert(a, 24/16, 8/16), b) -> v_lshl_add_b32(a, 24/16, b)
    */
   for (unsigned i = 0; i < 2; i++) {
      Instruction* extins = follow_operand(ctx, instr->operands[i]);
      if (!extins)
         continue;

      aco_opcode op;
      Operand operands[3];

      if (extins->opcode == aco_opcode::p_insert &&
          (extins->operands[1].constantValue() + 1) * extins->operands[2].constantValue() == 32) {
         /* Insert into the topmost bits: equivalent to a left shift. */
         op = new_op_lshl;
         operands[1] =
            Operand::c32(extins->operands[1].constantValue() * extins->operands[2].constantValue());
      } else if (is_or &&
                 (extins->opcode == aco_opcode::p_insert ||
                  (extins->opcode == aco_opcode::p_extract &&
                   extins->operands[3].constantEquals(0))) &&
                 extins->operands[1].constantEquals(0)) {
         /* Extract/insert of the lowest byte/word: equivalent to an AND mask. */
         op = aco_opcode::v_and_or_b32;
         operands[1] = Operand::c32(extins->operands[2].constantEquals(8) ? 0xffu : 0xffffu);
      } else {
         continue;
      }

      operands[0] = extins->operands[0];
      operands[2] = instr->operands[!i];

      if (!check_vop3_operands(ctx, 3, operands))
         continue;

      bool neg[3] = {}, abs[3] = {};
      uint8_t opsel = 0, omod = 0;
      bool clamp = false;
      if (instr->isVOP3())
         clamp = instr->vop3().clamp;

      /* The followed operand is consumed by the fused instruction. */
      ctx.uses[instr->operands[i].tempId()]--;
      create_vop3_for_op3(ctx, op, instr, operands, neg, abs, opsel, clamp, omod);
      return true;
   }

   return false;
}
2525
/* Fuses nested min/max instructions into the three-operand min3/max3 forms:
 * min(min(a, b), c) -> min3(a, b, c) (and likewise for max), plus the
 * negated cross combinations described below. `opposite` is the opposite
 * min/max opcode and `minmax3` the three-operand target opcode. */
bool
combine_minmax(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode opposite, aco_opcode minmax3)
{
   /* TODO: this can handle SDWA min/max instructions by using opsel */
   if (combine_three_valu_op(ctx, instr, instr->opcode, minmax3, "012", 1 | 2))
      return true;

   /* min(-max(a, b), c) -> min3(c, -a, -b) *
    * max(-min(a, b), c) -> max3(c, -a, -b) */
   for (unsigned swap = 0; swap < 2; swap++) {
      Operand operands[3];
      bool neg[3], abs[3], clamp, precise;
      uint8_t opsel = 0, omod = 0;
      bool inbetween_neg;
      /* Only valid when the inner result was negated: negating distributes
       * over min/max by swapping it to the opposite operation. */
      if (match_op3_for_vop3(ctx, instr->opcode, opposite, instr.get(), swap, "012", operands, neg,
                             abs, &opsel, &clamp, &omod, &inbetween_neg, NULL, NULL, &precise) &&
          inbetween_neg) {
         ctx.uses[instr->operands[swap].tempId()]--;
         /* Push the negation onto the inner operands. */
         neg[1] = !neg[1];
         neg[2] = !neg[2];
         create_vop3_for_op3(ctx, minmax3, instr, operands, neg, abs, opsel, clamp, omod);
         return true;
      }
   }
   return false;
}
2552
2553 /* s_not_b32(s_and_b32(a, b)) -> s_nand_b32(a, b)
2554 * s_not_b32(s_or_b32(a, b)) -> s_nor_b32(a, b)
2555 * s_not_b32(s_xor_b32(a, b)) -> s_xnor_b32(a, b)
2556 * s_not_b64(s_and_b64(a, b)) -> s_nand_b64(a, b)
2557 * s_not_b64(s_or_b64(a, b)) -> s_nor_b64(a, b)
2558 * s_not_b64(s_xor_b64(a, b)) -> s_xnor_b64(a, b) */
2559 bool
combine_salu_not_bitwise(opt_ctx & ctx,aco_ptr<Instruction> & instr)2560 combine_salu_not_bitwise(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2561 {
2562 /* checks */
2563 if (!instr->operands[0].isTemp())
2564 return false;
2565 if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2566 return false;
2567
2568 Instruction* op2_instr = follow_operand(ctx, instr->operands[0]);
2569 if (!op2_instr)
2570 return false;
2571 switch (op2_instr->opcode) {
2572 case aco_opcode::s_and_b32:
2573 case aco_opcode::s_or_b32:
2574 case aco_opcode::s_xor_b32:
2575 case aco_opcode::s_and_b64:
2576 case aco_opcode::s_or_b64:
2577 case aco_opcode::s_xor_b64: break;
2578 default: return false;
2579 }
2580
2581 /* create instruction */
2582 std::swap(instr->definitions[0], op2_instr->definitions[0]);
2583 std::swap(instr->definitions[1], op2_instr->definitions[1]);
2584 ctx.uses[instr->operands[0].tempId()]--;
2585 ctx.info[op2_instr->definitions[0].tempId()].label = 0;
2586
2587 switch (op2_instr->opcode) {
2588 case aco_opcode::s_and_b32: op2_instr->opcode = aco_opcode::s_nand_b32; break;
2589 case aco_opcode::s_or_b32: op2_instr->opcode = aco_opcode::s_nor_b32; break;
2590 case aco_opcode::s_xor_b32: op2_instr->opcode = aco_opcode::s_xnor_b32; break;
2591 case aco_opcode::s_and_b64: op2_instr->opcode = aco_opcode::s_nand_b64; break;
2592 case aco_opcode::s_or_b64: op2_instr->opcode = aco_opcode::s_nor_b64; break;
2593 case aco_opcode::s_xor_b64: op2_instr->opcode = aco_opcode::s_xnor_b64; break;
2594 default: break;
2595 }
2596
2597 return true;
2598 }
2599
2600 /* s_and_b32(a, s_not_b32(b)) -> s_andn2_b32(a, b)
2601 * s_or_b32(a, s_not_b32(b)) -> s_orn2_b32(a, b)
2602 * s_and_b64(a, s_not_b64(b)) -> s_andn2_b64(a, b)
2603 * s_or_b64(a, s_not_b64(b)) -> s_orn2_b64(a, b) */
2604 bool
combine_salu_n2(opt_ctx & ctx,aco_ptr<Instruction> & instr)2605 combine_salu_n2(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2606 {
2607 if (instr->definitions[0].isTemp() && ctx.info[instr->definitions[0].tempId()].is_uniform_bool())
2608 return false;
2609
2610 for (unsigned i = 0; i < 2; i++) {
2611 Instruction* op2_instr = follow_operand(ctx, instr->operands[i]);
2612 if (!op2_instr || (op2_instr->opcode != aco_opcode::s_not_b32 &&
2613 op2_instr->opcode != aco_opcode::s_not_b64))
2614 continue;
2615 if (ctx.uses[op2_instr->definitions[1].tempId()] || fixed_to_exec(op2_instr->operands[0]))
2616 continue;
2617
2618 if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2619 instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2620 continue;
2621
2622 ctx.uses[instr->operands[i].tempId()]--;
2623 instr->operands[0] = instr->operands[!i];
2624 instr->operands[1] = op2_instr->operands[0];
2625 ctx.info[instr->definitions[0].tempId()].label = 0;
2626
2627 switch (instr->opcode) {
2628 case aco_opcode::s_and_b32: instr->opcode = aco_opcode::s_andn2_b32; break;
2629 case aco_opcode::s_or_b32: instr->opcode = aco_opcode::s_orn2_b32; break;
2630 case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_andn2_b64; break;
2631 case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_orn2_b64; break;
2632 default: break;
2633 }
2634
2635 return true;
2636 }
2637 return false;
2638 }
2639
2640 /* s_add_{i32,u32}(a, s_lshl_b32(b, <n>)) -> s_lshl<n>_add_u32(a, b) */
2641 bool
combine_salu_lshl_add(opt_ctx & ctx,aco_ptr<Instruction> & instr)2642 combine_salu_lshl_add(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2643 {
2644 if (instr->opcode == aco_opcode::s_add_i32 && ctx.uses[instr->definitions[1].tempId()])
2645 return false;
2646
2647 for (unsigned i = 0; i < 2; i++) {
2648 Instruction* op2_instr = follow_operand(ctx, instr->operands[i], true);
2649 if (!op2_instr || op2_instr->opcode != aco_opcode::s_lshl_b32 ||
2650 ctx.uses[op2_instr->definitions[1].tempId()])
2651 continue;
2652 if (!op2_instr->operands[1].isConstant() || fixed_to_exec(op2_instr->operands[0]))
2653 continue;
2654
2655 uint32_t shift = op2_instr->operands[1].constantValue();
2656 if (shift < 1 || shift > 4)
2657 continue;
2658
2659 if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2660 instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2661 continue;
2662
2663 ctx.uses[instr->operands[i].tempId()]--;
2664 instr->operands[1] = instr->operands[!i];
2665 instr->operands[0] = op2_instr->operands[0];
2666 ctx.info[instr->definitions[0].tempId()].label = 0;
2667
2668 instr->opcode = std::array<aco_opcode, 4>{
2669 aco_opcode::s_lshl1_add_u32, aco_opcode::s_lshl2_add_u32, aco_opcode::s_lshl3_add_u32,
2670 aco_opcode::s_lshl4_add_u32}[shift - 1];
2671
2672 return true;
2673 }
2674 return false;
2675 }
2676
/* Fold a single-use bool-to-int (b2i) operand of an add/sub into the
 * instruction by rewriting it as new_op(0, other, bool) — presumably the
 * boolean becomes the carry input of a carry-consuming opcode chosen by the
 * caller. 'ops' is a bitmask selecting which of the two source operands may
 * be the b2i value. Returns true if instr was replaced. */
bool
combine_add_sub_b2i(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode new_op, uint8_t ops)
{
   if (instr->usesModifiers())
      return false;

   for (unsigned i = 0; i < 2; i++) {
      if (!((1 << i) & ops))
         continue;
      /* Operand i must be a single-use b2i result. */
      if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2i() &&
          ctx.uses[instr->operands[i].tempId()] == 1) {

         /* Choose the encoding: VOP2 when the other operand is a VGPR,
          * otherwise VOP3 (always allowed on GFX10+, or when the other
          * operand is a non-literal constant). */
         aco_ptr<Instruction> new_instr;
         if (instr->operands[!i].isTemp() &&
             instr->operands[!i].getTemp().type() == RegType::vgpr) {
            new_instr.reset(create_instruction<VOP2_instruction>(new_op, Format::VOP2, 3, 2));
         } else if (ctx.program->chip_class >= GFX10 ||
                    (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
            new_instr.reset(
               create_instruction<VOP3_instruction>(new_op, asVOP3(Format::VOP2), 3, 2));
         } else {
            return false;
         }
         ctx.uses[instr->operands[i].tempId()]--;
         new_instr->definitions[0] = instr->definitions[0];
         if (instr->definitions.size() == 2) {
            new_instr->definitions[1] = instr->definitions[1];
         } else {
            /* The original add/sub has no second definition: allocate a
             * fresh lane-mask temporary for the new instruction's second
             * result. */
            new_instr->definitions[1] =
               Definition(ctx.program->allocateTmp(ctx.program->lane_mask));
            /* Make sure the uses vector is large enough and the number of
             * uses properly initialized to 0.
             */
            ctx.uses.push_back(0);
         }
         new_instr->definitions[1].setHint(vcc);
         new_instr->operands[0] = Operand::zero();
         new_instr->operands[1] = instr->operands[!i];
         /* Use the original boolean (the b2i's source) as the third operand. */
         new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
         instr = std::move(new_instr);
         ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
         return true;
      }
   }

   return false;
}
2724
2725 bool
combine_add_bcnt(opt_ctx & ctx,aco_ptr<Instruction> & instr)2726 combine_add_bcnt(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2727 {
2728 if (instr->usesModifiers())
2729 return false;
2730
2731 for (unsigned i = 0; i < 2; i++) {
2732 Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
2733 if (op_instr && op_instr->opcode == aco_opcode::v_bcnt_u32_b32 &&
2734 !op_instr->usesModifiers() && op_instr->operands[0].isTemp() &&
2735 op_instr->operands[0].getTemp().type() == RegType::vgpr &&
2736 op_instr->operands[1].constantEquals(0)) {
2737 aco_ptr<Instruction> new_instr{
2738 create_instruction<VOP3_instruction>(aco_opcode::v_bcnt_u32_b32, Format::VOP3, 2, 1)};
2739 ctx.uses[instr->operands[i].tempId()]--;
2740 new_instr->operands[0] = op_instr->operands[0];
2741 new_instr->operands[1] = instr->operands[!i];
2742 new_instr->definitions[0] = instr->definitions[0];
2743 instr = std::move(new_instr);
2744 ctx.info[instr->definitions[0].tempId()].label = 0;
2745
2746 return true;
2747 }
2748 }
2749
2750 return false;
2751 }
2752
2753 bool
get_minmax_info(aco_opcode op,aco_opcode * min,aco_opcode * max,aco_opcode * min3,aco_opcode * max3,aco_opcode * med3,bool * some_gfx9_only)2754 get_minmax_info(aco_opcode op, aco_opcode* min, aco_opcode* max, aco_opcode* min3, aco_opcode* max3,
2755 aco_opcode* med3, bool* some_gfx9_only)
2756 {
2757 switch (op) {
2758 #define MINMAX(type, gfx9) \
2759 case aco_opcode::v_min_##type: \
2760 case aco_opcode::v_max_##type: \
2761 case aco_opcode::v_med3_##type: \
2762 *min = aco_opcode::v_min_##type; \
2763 *max = aco_opcode::v_max_##type; \
2764 *med3 = aco_opcode::v_med3_##type; \
2765 *min3 = aco_opcode::v_min3_##type; \
2766 *max3 = aco_opcode::v_max3_##type; \
2767 *some_gfx9_only = gfx9; \
2768 return true;
2769 MINMAX(f32, false)
2770 MINMAX(u32, false)
2771 MINMAX(i32, false)
2772 MINMAX(f16, true)
2773 MINMAX(u16, true)
2774 MINMAX(i16, true)
2775 #undef MINMAX
2776 default: return false;
2777 }
2778 }
2779
2780 /* when ub > lb:
2781 * v_min_{f,u,i}{16,32}(v_max_{f,u,i}{16,32}(a, lb), ub) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
2782 * v_max_{f,u,i}{16,32}(v_min_{f,u,i}{16,32}(a, ub), lb) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
2783 */
bool
combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode min, aco_opcode max,
              aco_opcode med)
{
   /* TODO: GLSL's clamp(x, minVal, maxVal) and SPIR-V's
    * FClamp(x, minVal, maxVal)/NClamp(x, minVal, maxVal) are undefined if
    * minVal > maxVal, which means we can always select it to a v_med3_f32 */
   aco_opcode other_op;
   if (instr->opcode == min)
      other_op = max;
   else if (instr->opcode == max)
      other_op = min;
   else
      return false;

   for (unsigned swap = 0; swap < 2; swap++) {
      Operand operands[3];
      bool neg[3], abs[3], clamp, precise;
      uint8_t opsel = 0, omod = 0;
      if (match_op3_for_vop3(ctx, instr->opcode, other_op, instr.get(), swap, "012", operands, neg,
                             abs, &opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
         /* max(min(src, upper), lower) returns upper if src is NaN, but
          * med3(src, lower, upper) returns lower.
          */
         if (precise && instr->opcode != min)
            continue;

         /* Find the two constant bounds among the three operands; a constant
          * may be an inline constant or a temporary known to hold one. */
         int const0_idx = -1, const1_idx = -1;
         uint32_t const0 = 0, const1 = 0;
         for (int i = 0; i < 3; i++) {
            uint32_t val;
            if (operands[i].isConstant()) {
               val = operands[i].constantValue();
            } else if (operands[i].isTemp() &&
                       ctx.info[operands[i].tempId()].is_constant_or_literal(32)) {
               val = ctx.info[operands[i].tempId()].val;
            } else {
               continue;
            }
            if (const0_idx >= 0) {
               const1_idx = i;
               const1 = val;
            } else {
               const0_idx = i;
               const0 = val;
            }
         }
         /* A clamp needs exactly two constant bounds. */
         if (const0_idx < 0 || const1_idx < 0)
            continue;

         /* An opsel bit selects the high 16 bits of the constant. */
         if (opsel & (1 << const0_idx))
            const0 >>= 16;
         if (opsel & (1 << const1_idx))
            const1 >>= 16;

         /* Determine which constant is the lower bound, interpreting the
          * raw bits per the min opcode's type and applying any neg/abs
          * modifiers for the float case. */
         int lower_idx = const0_idx;
         switch (min) {
         case aco_opcode::v_min_f32:
         case aco_opcode::v_min_f16: {
            float const0_f, const1_f;
            if (min == aco_opcode::v_min_f32) {
               memcpy(&const0_f, &const0, 4);
               memcpy(&const1_f, &const1, 4);
            } else {
               const0_f = _mesa_half_to_float(const0);
               const1_f = _mesa_half_to_float(const1);
            }
            if (abs[const0_idx])
               const0_f = fabsf(const0_f);
            if (abs[const1_idx])
               const1_f = fabsf(const1_f);
            if (neg[const0_idx])
               const0_f = -const0_f;
            if (neg[const1_idx])
               const1_f = -const1_f;
            lower_idx = const0_f < const1_f ? const0_idx : const1_idx;
            break;
         }
         case aco_opcode::v_min_u32: {
            lower_idx = const0 < const1 ? const0_idx : const1_idx;
            break;
         }
         case aco_opcode::v_min_u16: {
            lower_idx = (uint16_t)const0 < (uint16_t)const1 ? const0_idx : const1_idx;
            break;
         }
         case aco_opcode::v_min_i32: {
            /* Reconstruct the signed value explicitly instead of relying on
             * implementation-defined unsigned-to-signed conversion. */
            int32_t const0_i =
               const0 & 0x80000000u ? -2147483648 + (int32_t)(const0 & 0x7fffffffu) : const0;
            int32_t const1_i =
               const1 & 0x80000000u ? -2147483648 + (int32_t)(const1 & 0x7fffffffu) : const1;
            lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
            break;
         }
         case aco_opcode::v_min_i16: {
            int16_t const0_i = const0 & 0x8000u ? -32768 + (int16_t)(const0 & 0x7fffu) : const0;
            int16_t const1_i = const1 & 0x8000u ? -32768 + (int16_t)(const1 & 0x7fffu) : const1;
            lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
            break;
         }
         default: break;
         }
         int upper_idx = lower_idx == const0_idx ? const1_idx : const0_idx;

         /* Operand 0 appears to be the outer instruction's own bound:
          * for an outer min it must be the upper bound, for an outer max
          * the lower bound — otherwise this isn't a clamp.
          * NOTE(review): operand ordering depends on match_op3_for_vop3's
          * "012" mapping — confirm against its implementation. */
         if (instr->opcode == min) {
            if (upper_idx != 0 || lower_idx == 0)
               return false;
         } else {
            if (upper_idx == 0 || lower_idx != 0)
               return false;
         }

         ctx.uses[instr->operands[swap].tempId()]--;
         create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod);

         return true;
      }
   }

   return false;
}
2905
2906 void
apply_sgprs(opt_ctx & ctx,aco_ptr<Instruction> & instr)2907 apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2908 {
2909 bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
2910 instr->opcode == aco_opcode::v_lshrrev_b64 ||
2911 instr->opcode == aco_opcode::v_ashrrev_i64;
2912
2913 /* find candidates and create the set of sgprs already read */
2914 unsigned sgpr_ids[2] = {0, 0};
2915 uint32_t operand_mask = 0;
2916 bool has_literal = false;
2917 for (unsigned i = 0; i < instr->operands.size(); i++) {
2918 if (instr->operands[i].isLiteral())
2919 has_literal = true;
2920 if (!instr->operands[i].isTemp())
2921 continue;
2922 if (instr->operands[i].getTemp().type() == RegType::sgpr) {
2923 if (instr->operands[i].tempId() != sgpr_ids[0])
2924 sgpr_ids[!!sgpr_ids[0]] = instr->operands[i].tempId();
2925 }
2926 ssa_info& info = ctx.info[instr->operands[i].tempId()];
2927 if (is_copy_label(ctx, instr, info) && info.temp.type() == RegType::sgpr)
2928 operand_mask |= 1u << i;
2929 if (info.is_extract() && info.instr->operands[0].getTemp().type() == RegType::sgpr)
2930 operand_mask |= 1u << i;
2931 }
2932 unsigned max_sgprs = 1;
2933 if (ctx.program->chip_class >= GFX10 && !is_shift64)
2934 max_sgprs = 2;
2935 if (has_literal)
2936 max_sgprs--;
2937
2938 unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
2939
2940 /* keep on applying sgprs until there is nothing left to be done */
2941 while (operand_mask) {
2942 uint32_t sgpr_idx = 0;
2943 uint32_t sgpr_info_id = 0;
2944 uint32_t mask = operand_mask;
2945 /* choose a sgpr */
2946 while (mask) {
2947 unsigned i = u_bit_scan(&mask);
2948 uint16_t uses = ctx.uses[instr->operands[i].tempId()];
2949 if (sgpr_info_id == 0 || uses < ctx.uses[sgpr_info_id]) {
2950 sgpr_idx = i;
2951 sgpr_info_id = instr->operands[i].tempId();
2952 }
2953 }
2954 operand_mask &= ~(1u << sgpr_idx);
2955
2956 ssa_info& info = ctx.info[sgpr_info_id];
2957
2958 /* Applying two sgprs require making it VOP3, so don't do it unless it's
2959 * definitively beneficial.
2960 * TODO: this is too conservative because later the use count could be reduced to 1 */
2961 if (!info.is_extract() && num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3() &&
2962 !instr->isSDWA() && instr->format != Format::VOP3P)
2963 break;
2964
2965 Temp sgpr = info.is_extract() ? info.instr->operands[0].getTemp() : info.temp;
2966 bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1];
2967 if (new_sgpr && num_sgprs >= max_sgprs)
2968 continue;
2969
2970 if (sgpr_idx == 0)
2971 instr->format = withoutDPP(instr->format);
2972
2973 if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA() || instr->isVOP3P() ||
2974 info.is_extract()) {
2975 /* can_apply_extract() checks SGPR encoding restrictions */
2976 if (info.is_extract() && can_apply_extract(ctx, instr, sgpr_idx, info))
2977 apply_extract(ctx, instr, sgpr_idx, info);
2978 else if (info.is_extract())
2979 continue;
2980 instr->operands[sgpr_idx] = Operand(sgpr);
2981 } else if (can_swap_operands(instr, &instr->opcode)) {
2982 instr->operands[sgpr_idx] = instr->operands[0];
2983 instr->operands[0] = Operand(sgpr);
2984 /* swap bits using a 4-entry LUT */
2985 uint32_t swapped = (0x3120 >> (operand_mask & 0x3)) & 0xf;
2986 operand_mask = (operand_mask & ~0x3) | swapped;
2987 } else if (can_use_VOP3(ctx, instr) && !info.is_extract()) {
2988 to_VOP3(ctx, instr);
2989 instr->operands[sgpr_idx] = Operand(sgpr);
2990 } else {
2991 continue;
2992 }
2993
2994 if (new_sgpr)
2995 sgpr_ids[num_sgprs++] = sgpr.id();
2996 ctx.uses[sgpr_info_id]--;
2997 ctx.uses[sgpr.id()]++;
2998
2999 /* TODO: handle when it's a VGPR */
3000 if ((ctx.info[sgpr.id()].label & (label_extract | label_temp)) &&
3001 ctx.info[sgpr.id()].temp.type() == RegType::sgpr)
3002 operand_mask |= 1u << sgpr_idx;
3003 }
3004 }
3005
3006 template <typename T>
3007 bool
apply_omod_clamp_helper(opt_ctx & ctx,T * instr,ssa_info & def_info)3008 apply_omod_clamp_helper(opt_ctx& ctx, T* instr, ssa_info& def_info)
3009 {
3010 if (!def_info.is_clamp() && (instr->clamp || instr->omod))
3011 return false;
3012
3013 if (def_info.is_omod2())
3014 instr->omod = 1;
3015 else if (def_info.is_omod4())
3016 instr->omod = 2;
3017 else if (def_info.is_omod5())
3018 instr->omod = 3;
3019 else if (def_info.is_clamp())
3020 instr->clamp = true;
3021
3022 return true;
3023 }
3024
3025 /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
bool
apply_omod_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* Only fold if instr's single user is the omod/clamp and the opcode
    * supports output modifiers at all. */
   if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1 ||
       !instr_info.can_use_output_modifiers[(int)instr->opcode])
      return false;

   bool can_vop3 = can_use_VOP3(ctx, instr);
   bool is_mad_mix =
      instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16;
   if (!instr->isSDWA() && !is_mad_mix && !can_vop3)
      return false;

   /* omod flushes -0 to +0 and has no effect if denormals are enabled. SDWA omod is GFX9+. */
   bool can_use_omod = (can_vop3 || ctx.program->chip_class >= GFX9) && !instr->isVOP3P();
   if (instr->definitions[0].bytes() == 4)
      can_use_omod =
         can_use_omod && ctx.fp_mode.denorm32 == 0 && !ctx.fp_mode.preserve_signed_zero_inf_nan32;
   else
      can_use_omod = can_use_omod && ctx.fp_mode.denorm16_64 == 0 &&
                     !ctx.fp_mode.preserve_signed_zero_inf_nan16_64;

   ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];

   /* The result must be labelled clamp, or one of the omod labels while omod
    * is usable in the current FP mode. */
   uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
   if (!def_info.is_clamp() && !(can_use_omod && (def_info.label & omod_labels)))
      return false;
   /* if the omod/clamp instruction is dead, then the single user of this
    * instruction is a different instruction */
   if (!ctx.uses[def_info.instr->definitions[0].tempId()])
      return false;

   /* The omod/clamp user must produce a result of the same size. */
   if (def_info.instr->definitions[0].bytes() != instr->definitions[0].bytes())
      return false;

   /* MADs/FMAs are created later, so we don't have to update the original add */
   assert(!ctx.info[instr->definitions[0].tempId()].is_mad());

   if (instr->isSDWA()) {
      if (!apply_omod_clamp_helper(ctx, &instr->sdwa(), def_info))
         return false;
   } else if (instr->isVOP3P()) {
      /* VOP3P was excluded from omod above, so only clamp can reach here. */
      assert(def_info.is_clamp());
      instr->vop3p().clamp = true;
   } else {
      to_VOP3(ctx, instr);
      if (!apply_omod_clamp_helper(ctx, &instr->vop3(), def_info))
         return false;
   }

   /* Take over the omod/clamp instruction's result and mark it dead. */
   instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
   ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_insert | label_f2f16;
   ctx.uses[def_info.instr->definitions[0].tempId()]--;

   return true;
}
3082
3083 /* Combine an p_insert (or p_extract, in some cases) instruction with instr.
3084 * p_insert(instr(...)) -> instr_insert().
3085 */
bool
apply_insert(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* Only fold if the insert is the sole user of instr's result. */
   if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1)
      return false;

   ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
   if (!def_info.is_insert())
      return false;
   /* if the insert instruction is dead, then the single user of this
    * instruction is a different instruction */
   if (!ctx.uses[def_info.instr->definitions[0].tempId()])
      return false;

   /* MADs/FMAs are created later, so we don't have to update the original add */
   assert(!ctx.info[instr->definitions[0].tempId()].is_mad());

   SubdwordSel sel = parse_insert(def_info.instr);
   assert(sel);

   if (instr->isVOP3() && sel.size() == 2 && !sel.sign_extend() &&
       can_use_opsel(ctx.program->chip_class, instr->opcode, -1)) {
      /* A 16-bit insert can be expressed via the destination opsel bit
       * (bit 3) on VOP3, provided it isn't already set. */
      if (instr->vop3().opsel & (1 << 3))
         return false;
      if (sel.offset())
         instr->vop3().opsel |= 1 << 3;
   } else {
      if (!can_use_SDWA(ctx.program->chip_class, instr, true))
         return false;

      /* Otherwise express the insert as an SDWA destination selection.
       * Only a full-dword dst_sel can be replaced by the parsed selection. */
      to_SDWA(ctx, instr);
      if (instr->sdwa().dst_sel.size() != 4)
         return false;
      static_cast<SDWA_instruction*>(instr.get())->dst_sel = sel;
   }

   /* Take over the insert's result and mark the insert dead. */
   instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
   ctx.info[instr->definitions[0].tempId()].label = 0;
   ctx.uses[def_info.instr->definitions[0].tempId()]--;

   return true;
}
3128
3129 /* Remove superfluous extract after ds_read like so:
3130 * p_extract(ds_read_uN(), 0, N, 0) -> ds_read_uN()
3131 */
3132 bool
apply_ds_extract(opt_ctx & ctx,aco_ptr<Instruction> & extract)3133 apply_ds_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract)
3134 {
3135 /* Check if p_extract has a usedef operand and is the only user. */
3136 if (!ctx.info[extract->operands[0].tempId()].is_usedef() ||
3137 ctx.uses[extract->operands[0].tempId()] > 1)
3138 return false;
3139
3140 /* Check if the usedef is a DS instruction. */
3141 Instruction* ds = ctx.info[extract->operands[0].tempId()].instr;
3142 if (ds->format != Format::DS)
3143 return false;
3144
3145 unsigned extract_idx = extract->operands[1].constantValue();
3146 unsigned bits_extracted = extract->operands[2].constantValue();
3147 unsigned sign_ext = extract->operands[3].constantValue();
3148 unsigned dst_bitsize = extract->definitions[0].bytes() * 8u;
3149
3150 /* TODO: These are doable, but probably don't occour too often. */
3151 if (extract_idx || sign_ext || dst_bitsize != 32)
3152 return false;
3153
3154 unsigned bits_loaded = 0;
3155 if (ds->opcode == aco_opcode::ds_read_u8 || ds->opcode == aco_opcode::ds_read_u8_d16)
3156 bits_loaded = 8;
3157 else if (ds->opcode == aco_opcode::ds_read_u16 || ds->opcode == aco_opcode::ds_read_u16_d16)
3158 bits_loaded = 16;
3159 else
3160 return false;
3161
3162 /* Shrink the DS load if the extracted bit size is smaller. */
3163 bits_loaded = MIN2(bits_loaded, bits_extracted);
3164
3165 /* Change the DS opcode so it writes the full register. */
3166 if (bits_loaded == 8)
3167 ds->opcode = aco_opcode::ds_read_u8;
3168 else if (bits_loaded == 16)
3169 ds->opcode = aco_opcode::ds_read_u16;
3170 else
3171 unreachable("Forgot to add DS opcode above.");
3172
3173 /* The DS now produces the exact same thing as the extract, remove the extract. */
3174 std::swap(ds->definitions[0], extract->definitions[0]);
3175 ctx.uses[extract->definitions[0].tempId()] = 0;
3176 ctx.info[ds->definitions[0].tempId()].label = 0;
3177 return true;
3178 }
3179
3180 /* v_and(a, v_subbrev_co(0, 0, vcc)) -> v_cndmask(0, a, vcc) */
bool
combine_and_subbrev(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   if (instr->usesModifiers())
      return false;

   for (unsigned i = 0; i < 2; i++) {
      /* Look for v_subbrev_co(0, 0, cond), which (per the comment above)
       * produces an all-ones/all-zeros mask from the condition. */
      Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
      if (op_instr && op_instr->opcode == aco_opcode::v_subbrev_co_u32 &&
          op_instr->operands[0].constantEquals(0) && op_instr->operands[1].constantEquals(0) &&
          !op_instr->usesModifiers()) {

         /* Choose the encoding: VOP2 when the other operand is a VGPR,
          * otherwise VOP3 (always allowed on GFX10+, or when the other
          * operand is a non-literal constant). */
         aco_ptr<Instruction> new_instr;
         if (instr->operands[!i].isTemp() &&
             instr->operands[!i].getTemp().type() == RegType::vgpr) {
            new_instr.reset(
               create_instruction<VOP2_instruction>(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1));
         } else if (ctx.program->chip_class >= GFX10 ||
                    (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
            new_instr.reset(create_instruction<VOP3_instruction>(aco_opcode::v_cndmask_b32,
                                                                 asVOP3(Format::VOP2), 3, 1));
         } else {
            return false;
         }

         /* Drop one use of the subbrev result; if it is still used elsewhere,
          * its condition gains an extra user through the new cndmask. */
         ctx.uses[instr->operands[i].tempId()]--;
         if (ctx.uses[instr->operands[i].tempId()])
            ctx.uses[op_instr->operands[2].tempId()]++;

         new_instr->operands[0] = Operand::zero();
         new_instr->operands[1] = instr->operands[!i];
         new_instr->operands[2] = Operand(op_instr->operands[2]);
         new_instr->definitions[0] = instr->definitions[0];
         instr = std::move(new_instr);
         ctx.info[instr->definitions[0].tempId()].label = 0;
         return true;
      }
   }

   return false;
}
3222
3223 /* v_add_co(c, s_lshl(a, b)) -> v_mad_u32_u24(a, 1<<b, c)
3224 * v_add_co(c, v_lshlrev(a, b)) -> v_mad_u32_u24(b, 1<<a, c)
3225 * v_sub(c, s_lshl(a, b)) -> v_mad_i32_i24(a, -(1<<b), c)
3226 * v_sub(c, v_lshlrev(a, b)) -> v_mad_i32_i24(b, -(1<<a), c)
3227 */
bool
combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool is_sub)
{
   if (instr->usesModifiers())
      return false;

   /* Substractions: start at operand 1 to avoid mixup such as
    * turning v_sub(v_lshlrev(a, b), c) into v_mad_i32_i24(b, -(1<<a), c)
    */
   unsigned start_op_idx = is_sub ? 1 : 0;

   /* Don't allow 24-bit operands on subtraction because
    * v_mad_i32_i24 applies a sign extension.
    */
   bool allow_24bit = !is_sub;

   for (unsigned i = start_op_idx; i < 2; i++) {
      Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
      if (!op_instr)
         continue;

      if (op_instr->opcode != aco_opcode::s_lshl_b32 &&
          op_instr->opcode != aco_opcode::v_lshlrev_b32)
         continue;

      /* s_lshl takes the shift amount as operand 1, v_lshlrev as operand 0. */
      int shift_op_idx = op_instr->opcode == aco_opcode::s_lshl_b32 ? 1 : 0;

      if (op_instr->operands[shift_op_idx].isConstant() &&
          ((allow_24bit && op_instr->operands[!shift_op_idx].is24bit()) ||
           op_instr->operands[!shift_op_idx].is16bit())) {
         /* Turn the constant shift into a multiplication, negated (as an
          * unsigned two's-complement value) for the subtraction case. */
         uint32_t multiplier = 1 << (op_instr->operands[shift_op_idx].constantValue() % 32u);
         if (is_sub)
            multiplier = -multiplier;
         /* The multiplier must fit in 24 bits: up to 0xffffff for
          * v_mad_u32_u24, or a negative i24 (>= 0xff800000 when viewed as
          * unsigned) for v_mad_i32_i24. */
         if (is_sub ? (multiplier < 0xff800000) : (multiplier > 0xffffff))
            continue;

         Operand ops[3] = {
            op_instr->operands[!shift_op_idx],
            Operand::c32(multiplier),
            instr->operands[!i],
         };
         if (!check_vop3_operands(ctx, 3, ops))
            return false;

         ctx.uses[instr->operands[i].tempId()]--;

         aco_opcode mad_op = is_sub ? aco_opcode::v_mad_i32_i24 : aco_opcode::v_mad_u32_u24;
         aco_ptr<VOP3_instruction> new_instr{
            create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
         for (unsigned op_idx = 0; op_idx < 3; ++op_idx)
            new_instr->operands[op_idx] = ops[op_idx];
         new_instr->definitions[0] = instr->definitions[0];
         instr = std::move(new_instr);
         ctx.info[instr->definitions[0].tempId()].label = 0;
         return true;
      }
   }

   return false;
}
3288
3289 void
propagate_swizzles(VOP3P_instruction * instr,uint8_t opsel_lo,uint8_t opsel_hi)3290 propagate_swizzles(VOP3P_instruction* instr, uint8_t opsel_lo, uint8_t opsel_hi)
3291 {
3292 /* propagate swizzles which apply to a result down to the instruction's operands:
3293 * result = a.xy + b.xx -> result.yx = a.yx + b.xx */
3294 assert((opsel_lo & 1) == opsel_lo);
3295 assert((opsel_hi & 1) == opsel_hi);
3296 uint8_t tmp_lo = instr->opsel_lo;
3297 uint8_t tmp_hi = instr->opsel_hi;
3298 bool neg_lo[3] = {instr->neg_lo[0], instr->neg_lo[1], instr->neg_lo[2]};
3299 bool neg_hi[3] = {instr->neg_hi[0], instr->neg_hi[1], instr->neg_hi[2]};
3300 if (opsel_lo == 1) {
3301 instr->opsel_lo = tmp_hi;
3302 for (unsigned i = 0; i < 3; i++)
3303 instr->neg_lo[i] = neg_hi[i];
3304 }
3305 if (opsel_hi == 0) {
3306 instr->opsel_hi = tmp_lo;
3307 for (unsigned i = 0; i < 3; i++)
3308 instr->neg_hi[i] = neg_lo[i];
3309 }
3310 }
3311
3312 void
combine_vop3p(opt_ctx & ctx,aco_ptr<Instruction> & instr)3313 combine_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3314 {
3315 VOP3P_instruction* vop3p = &instr->vop3p();
3316
3317 /* apply clamp */
3318 if (instr->opcode == aco_opcode::v_pk_mul_f16 && instr->operands[1].constantEquals(0x3C00) &&
3319 vop3p->clamp && instr->operands[0].isTemp() && ctx.uses[instr->operands[0].tempId()] == 1) {
3320
3321 ssa_info& info = ctx.info[instr->operands[0].tempId()];
3322 if (info.is_vop3p() && instr_info.can_use_output_modifiers[(int)info.instr->opcode]) {
3323 VOP3P_instruction* candidate = &ctx.info[instr->operands[0].tempId()].instr->vop3p();
3324 candidate->clamp = true;
3325 propagate_swizzles(candidate, vop3p->opsel_lo, vop3p->opsel_hi);
3326 instr->definitions[0].swapTemp(candidate->definitions[0]);
3327 ctx.info[candidate->definitions[0].tempId()].instr = candidate;
3328 ctx.uses[instr->definitions[0].tempId()]--;
3329 return;
3330 }
3331 }
3332
3333 /* check for fneg modifiers */
3334 if (instr_info.can_use_input_modifiers[(int)instr->opcode]) {
3335 for (unsigned i = 0; i < instr->operands.size(); i++) {
3336 Operand& op = instr->operands[i];
3337 if (!op.isTemp())
3338 continue;
3339
3340 ssa_info& info = ctx.info[op.tempId()];
3341 if (info.is_vop3p() && info.instr->opcode == aco_opcode::v_pk_mul_f16 &&
3342 info.instr->operands[1].constantEquals(0x3C00)) {
3343 Operand ops[3];
3344 for (unsigned j = 0; j < instr->operands.size(); j++)
3345 ops[j] = instr->operands[j];
3346 ops[i] = info.instr->operands[0];
3347 if (!check_vop3_operands(ctx, instr->operands.size(), ops))
3348 continue;
3349
3350 VOP3P_instruction* fneg = &info.instr->vop3p();
3351 if (fneg->clamp)
3352 continue;
3353 instr->operands[i] = fneg->operands[0];
3354
3355 /* opsel_lo/hi is either 0 or 1:
3356 * if 0 - pick selection from fneg->lo
3357 * if 1 - pick selection from fneg->hi
3358 */
3359 bool opsel_lo = (vop3p->opsel_lo >> i) & 1;
3360 bool opsel_hi = (vop3p->opsel_hi >> i) & 1;
3361 bool neg_lo = fneg->neg_lo[0] ^ fneg->neg_lo[1];
3362 bool neg_hi = fneg->neg_hi[0] ^ fneg->neg_hi[1];
3363 vop3p->neg_lo[i] ^= opsel_lo ? neg_hi : neg_lo;
3364 vop3p->neg_hi[i] ^= opsel_hi ? neg_hi : neg_lo;
3365 vop3p->opsel_lo ^= ((opsel_lo ? ~fneg->opsel_hi : fneg->opsel_lo) & 1) << i;
3366 vop3p->opsel_hi ^= ((opsel_hi ? ~fneg->opsel_hi : fneg->opsel_lo) & 1) << i;
3367
3368 if (--ctx.uses[fneg->definitions[0].tempId()])
3369 ctx.uses[fneg->operands[0].tempId()]++;
3370 }
3371 }
3372 }
3373
3374 if (instr->opcode == aco_opcode::v_pk_add_f16 || instr->opcode == aco_opcode::v_pk_add_u16) {
3375 bool fadd = instr->opcode == aco_opcode::v_pk_add_f16;
3376 if (fadd && instr->definitions[0].isPrecise())
3377 return;
3378
3379 Instruction* mul_instr = nullptr;
3380 unsigned add_op_idx = 0;
3381 uint8_t opsel_lo = 0, opsel_hi = 0;
3382 uint32_t uses = UINT32_MAX;
3383
3384 /* find the 'best' mul instruction to combine with the add */
3385 for (unsigned i = 0; i < 2; i++) {
3386 if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_vop3p())
3387 continue;
3388 ssa_info& info = ctx.info[instr->operands[i].tempId()];
3389 if (fadd) {
3390 if (info.instr->opcode != aco_opcode::v_pk_mul_f16 ||
3391 info.instr->definitions[0].isPrecise())
3392 continue;
3393 } else {
3394 if (info.instr->opcode != aco_opcode::v_pk_mul_lo_u16)
3395 continue;
3396 }
3397
3398 Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]};
3399 if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
3400 continue;
3401
3402 /* no clamp allowed between mul and add */
3403 if (info.instr->vop3p().clamp)
3404 continue;
3405
3406 mul_instr = info.instr;
3407 add_op_idx = 1 - i;
3408 opsel_lo = (vop3p->opsel_lo >> i) & 1;
3409 opsel_hi = (vop3p->opsel_hi >> i) & 1;
3410 uses = ctx.uses[instr->operands[i].tempId()];
3411 }
3412
3413 if (!mul_instr)
3414 return;
3415
3416 /* convert to mad */
3417 Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1], instr->operands[add_op_idx]};
3418 ctx.uses[mul_instr->definitions[0].tempId()]--;
3419 if (ctx.uses[mul_instr->definitions[0].tempId()]) {
3420 if (op[0].isTemp())
3421 ctx.uses[op[0].tempId()]++;
3422 if (op[1].isTemp())
3423 ctx.uses[op[1].tempId()]++;
3424 }
3425
3426 /* turn packed mul+add into v_pk_fma_f16 */
3427 assert(mul_instr->isVOP3P());
3428 aco_opcode mad = fadd ? aco_opcode::v_pk_fma_f16 : aco_opcode::v_pk_mad_u16;
3429 aco_ptr<VOP3P_instruction> fma{
3430 create_instruction<VOP3P_instruction>(mad, Format::VOP3P, 3, 1)};
3431 VOP3P_instruction* mul = &mul_instr->vop3p();
3432 for (unsigned i = 0; i < 2; i++) {
3433 fma->operands[i] = op[i];
3434 fma->neg_lo[i] = mul->neg_lo[i];
3435 fma->neg_hi[i] = mul->neg_hi[i];
3436 }
3437 fma->operands[2] = op[2];
3438 fma->clamp = vop3p->clamp;
3439 fma->opsel_lo = mul->opsel_lo;
3440 fma->opsel_hi = mul->opsel_hi;
3441 propagate_swizzles(fma.get(), opsel_lo, opsel_hi);
3442 fma->opsel_lo |= (vop3p->opsel_lo << (2 - add_op_idx)) & 0x4;
3443 fma->opsel_hi |= (vop3p->opsel_hi << (2 - add_op_idx)) & 0x4;
3444 fma->neg_lo[2] = vop3p->neg_lo[add_op_idx];
3445 fma->neg_hi[2] = vop3p->neg_hi[add_op_idx];
3446 fma->neg_lo[1] = fma->neg_lo[1] ^ vop3p->neg_lo[1 - add_op_idx];
3447 fma->neg_hi[1] = fma->neg_hi[1] ^ vop3p->neg_hi[1 - add_op_idx];
3448 fma->definitions[0] = instr->definitions[0];
3449 instr = std::move(fma);
3450 ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
3451 return;
3452 }
3453 }
3454
3455 bool
can_use_mad_mix(opt_ctx & ctx,aco_ptr<Instruction> & instr)3456 can_use_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3457 {
3458 if (ctx.program->chip_class < GFX9)
3459 return false;
3460
3461 switch (instr->opcode) {
3462 case aco_opcode::v_add_f32:
3463 case aco_opcode::v_sub_f32:
3464 case aco_opcode::v_subrev_f32:
3465 case aco_opcode::v_mul_f32:
3466 case aco_opcode::v_fma_f32: break;
3467 case aco_opcode::v_fma_mix_f32:
3468 case aco_opcode::v_fma_mixlo_f16: return true;
3469 default: return false;
3470 }
3471
3472 if (instr->opcode == aco_opcode::v_fma_f32 && !ctx.program->dev.fused_mad_mix &&
3473 instr->definitions[0].isPrecise())
3474 return false;
3475
3476 if (instr->isVOP3())
3477 return !instr->vop3().omod && !(instr->vop3().opsel & 0x8);
3478
3479 return instr->format == Format::VOP2;
3480 }
3481
void
to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* Rewrite a 32-bit add/sub/mul/fma VALU instruction as v_fma_mix_f32 (a*b+c),
    * so that f16<->f32 conversions on operands/result can later be folded into
    * the mix instruction's opsel bits.
    *
    * For adds/subs the instruction becomes 1.0*b+c: the original operands shift
    * up by one slot (is_add acts as the slot offset). For muls it becomes
    * a*b+(-0.0). v_fma_f32 already has three operands and keeps its layout.
    */
   bool is_add = instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;

   aco_ptr<VOP3P_instruction> vop3p{
      create_instruction<VOP3P_instruction>(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1)};

   /* Carry the VOP3 source opsel bits over into opsel_lo, shifted by one slot for
    * adds since the operands move from [0,1] to [1,2]. Bit 0x8 (dest opsel) was
    * already rejected by can_use_mad_mix(). */
   vop3p->opsel_lo = instr->isVOP3() ? ((instr->vop3().opsel & 0x7) << (is_add ? 1 : 0)) : 0x0;
   /* opsel_hi == 0 marks every source as a full f32 operand. */
   vop3p->opsel_hi = 0x0;
   for (unsigned i = 0; i < instr->operands.size(); i++) {
      vop3p->operands[is_add + i] = instr->operands[i];
      /* neg_lo/neg_hi act as neg/abs modifiers on mix instructions; merge in
       * whichever of VOP3 or SDWA modifiers the source instruction used. */
      vop3p->neg_lo[is_add + i] = instr->isVOP3() && instr->vop3().neg[i];
      vop3p->neg_lo[is_add + i] |= instr->isSDWA() && instr->sdwa().neg[i];
      vop3p->neg_hi[is_add + i] = instr->isVOP3() && instr->vop3().abs[i];
      vop3p->neg_hi[is_add + i] |= instr->isSDWA() && instr->sdwa().abs[i];
      /* An SDWA selection with a non-zero offset picks the high half; translate
       * that to the opsel_lo bit of the corresponding mix operand. */
      vop3p->opsel_lo |= (instr->isSDWA() && instr->sdwa().sel[i].offset()) << (is_add + i);
   }
   if (instr->opcode == aco_opcode::v_mul_f32) {
      /* Note: opsel_hi is still 0 here, so this mask is effectively a no-op. */
      vop3p->opsel_hi &= 0x3;
      /* Addend c = -0.0, presumably so the sign of a zero a*b result is
       * preserved — TODO confirm against fma semantics. */
      vop3p->operands[2] = Operand::zero();
      vop3p->neg_lo[2] = true;
   } else if (is_add) {
      /* Note: opsel_hi is still 0 here, so this mask is effectively a no-op. */
      vop3p->opsel_hi &= 0x6;
      /* Multiplicand a = 1.0f (0x3f800000), turning a*b+c into b+c. */
      vop3p->operands[0] = Operand::c32(0x3f800000);
      /* Subtractions negate the appropriate source via neg_lo. */
      if (instr->opcode == aco_opcode::v_sub_f32)
         vop3p->neg_lo[2] ^= true;
      else if (instr->opcode == aco_opcode::v_subrev_f32)
         vop3p->neg_lo[1] ^= true;
   }
   vop3p->definitions[0] = instr->definitions[0];
   vop3p->clamp = instr->isVOP3() && instr->vop3().clamp;
   instr = std::move(vop3p);

   /* Only these labels remain meaningful for the rewritten instruction; the
    * label_mul info pointer must be redirected to the new instruction. */
   ctx.info[instr->definitions[0].tempId()].label &= label_f2f16 | label_clamp | label_mul;
   if (ctx.info[instr->definitions[0].tempId()].label & label_mul)
      ctx.info[instr->definitions[0].tempId()].instr = instr.get();
}
3520
3521 bool
combine_output_conversion(opt_ctx & ctx,aco_ptr<Instruction> & instr)3522 combine_output_conversion(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3523 {
3524 ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3525 if (!def_info.is_f2f16())
3526 return false;
3527 Instruction* conv = def_info.instr;
3528
3529 if (!can_use_mad_mix(ctx, instr) || ctx.uses[instr->definitions[0].tempId()] != 1)
3530 return false;
3531
3532 if (!ctx.uses[conv->definitions[0].tempId()])
3533 return false;
3534
3535 if (conv->usesModifiers())
3536 return false;
3537
3538 if (!instr->isVOP3P())
3539 to_mad_mix(ctx, instr);
3540
3541 instr->opcode = aco_opcode::v_fma_mixlo_f16;
3542 instr->definitions[0].swapTemp(conv->definitions[0]);
3543 if (conv->definitions[0].isPrecise())
3544 instr->definitions[0].setPrecise(true);
3545 ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
3546 ctx.uses[conv->definitions[0].tempId()]--;
3547
3548 return true;
3549 }
3550
void
combine_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* Fold f16->f32 conversions (label_f2f32) on the operands of this instruction
    * into a v_fma_mix_f32, using the mix opsel bits to select the f16 source
    * halves directly. The instruction is converted to VOP3P form on the first
    * successful fold. */
   if (!can_use_mad_mix(ctx, instr))
      return;

   for (unsigned i = 0; i < instr->operands.size(); i++) {
      if (!instr->operands[i].isTemp())
         continue;
      Temp tmp = instr->operands[i].getTemp();
      if (!ctx.info[tmp.id()].is_f2f32())
         continue;

      /* The conversion may only use modifiers that the mix encoding can express:
       * full-dword SDWA destination, 2-byte source selection, no clamp/omod, no DPP. */
      Instruction* conv = ctx.info[tmp.id()].instr;
      if (conv->isSDWA() && (conv->sdwa().dst_sel.size() != 4 || conv->sdwa().sel[0].size() != 2 ||
                             conv->sdwa().clamp || conv->sdwa().omod)) {
         continue;
      } else if (conv->isVOP3() && (conv->vop3().clamp || conv->vop3().omod)) {
         continue;
      } else if (conv->isDPP()) {
         continue;
      }

      /* Only full 32-bit operand slots can take a folded conversion. */
      if (get_operand_size(instr, i) != 32)
         continue;

      /* Conversion to VOP3P will add inline constant operands, but that shouldn't affect
       * check_vop3_operands(). */
      Operand op[3];
      for (unsigned j = 0; j < instr->operands.size(); j++)
         op[j] = instr->operands[j];
      op[i] = conv->operands[0];
      if (!check_vop3_operands(ctx, instr->operands.size(), op))
         continue;

      if (!instr->isVOP3P()) {
         bool is_add =
            instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
         to_mad_mix(ctx, instr);
         /* to_mad_mix() shifts add/sub operands up by one slot; follow the
          * operand we are working on. */
         i += is_add;
      }

      /* The conversion result loses one use; its source gains one if the
       * conversion remains live for other users. */
      if (--ctx.uses[tmp.id()])
         ctx.uses[conv->operands[0].tempId()]++;
      instr->operands[i].setTemp(conv->operands[0].getTemp());
      if (conv->definitions[0].isPrecise())
         instr->definitions[0].setPrecise(true);
      /* Flip opsel_hi: this operand is now read as f16 instead of f32. */
      instr->vop3p().opsel_hi ^= 1u << i;
      /* An SDWA byte offset of 2 selected the high half of the source dword. */
      if (conv->isSDWA() && conv->sdwa().sel[0].offset() == 2)
         instr->vop3p().opsel_lo |= 1u << i;
      /* Merge the conversion's own neg/abs modifiers, unless this mix operand
       * already has abs (neg_hi) set, in which case abs would apply after neg
       * and the merge would change semantics. */
      bool neg = (conv->isVOP3() && conv->vop3().neg[0]) || (conv->isSDWA() && conv->sdwa().neg[0]);
      bool abs = (conv->isVOP3() && conv->vop3().abs[0]) || (conv->isSDWA() && conv->sdwa().abs[0]);
      if (!instr->vop3p().neg_hi[i]) {
         instr->vop3p().neg_lo[i] ^= neg;
         instr->vop3p().neg_hi[i] = abs;
      }
   }
}
3609
// TODO: we could possibly move the whole label_instruction pass to combine_instruction:
// this would mean that we'd have to fix up the instruction use counts during value propagation
3612
void
combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* Second phase of the optimizer: combine instructions labeled by
    * label_instruction() into more powerful ones — extracts, SGPR operands,
    * mad_mix, omod/clamp folding, mul+add -> mad/fma, min/max/med3 and various
    * 3-operand VALU/SALU combinations. ctx.uses is kept exact while operands
    * are rewritten so later passes see correct use counts. */
   if (instr->definitions.empty() || is_dead(ctx.uses, instr.get()))
      return;

   if (instr->isVALU()) {
      /* Apply SDWA. Do this after label_instruction() so it can remove
       * label_extract if not all instructions can take SDWA. */
      for (unsigned i = 0; i < instr->operands.size(); i++) {
         Operand& op = instr->operands[i];
         if (!op.isTemp())
            continue;
         ssa_info& info = ctx.info[op.tempId()];
         if (!info.is_extract())
            continue;
         /* if there are that many uses, there are likely better combinations */
         // TODO: delay applying extract to a point where we know better
         if (ctx.uses[op.tempId()] > 4) {
            info.label &= ~label_extract;
            continue;
         }
         if (info.is_extract() &&
             (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
              instr->operands[i].getTemp().type() == RegType::sgpr) &&
             can_apply_extract(ctx, instr, i, info)) {
            /* Increase use count of the extract's operand if the extract still has uses. */
            apply_extract(ctx, instr, i, info);
            if (--ctx.uses[instr->operands[i].tempId()])
               ctx.uses[info.instr->operands[0].tempId()]++;
            instr->operands[i].setTemp(info.instr->operands[0].getTemp());
         }
      }

      if (can_apply_sgprs(ctx, instr))
         apply_sgprs(ctx, instr);
      combine_mad_mix(ctx, instr);
      /* Non-short-circuiting '|' is deliberate: both combiners run every
       * iteration, and we loop while either of them makes progress. */
      while (apply_omod_clamp(ctx, instr) | combine_output_conversion(ctx, instr))
         ;
      apply_insert(ctx, instr);
   }

   /* Packed VOP3P instructions have their own combiner; v_fma_mix_* fall
    * through and are handled by the mul+add logic below. */
   if (instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
       instr->opcode != aco_opcode::v_fma_mixlo_f16)
      return combine_vop3p(ctx, instr);

   if (ctx.info[instr->definitions[0].tempId()].is_vcc_hint()) {
      instr->definitions[0].setHint(vcc);
   }

   /* The combinations below can't deal with SDWA/DPP encodings. */
   if (instr->isSDWA() || instr->isDPP())
      return;

   if (instr->opcode == aco_opcode::p_extract) {
      /* Fold an extract-of-extract into a single extract where possible. */
      ssa_info& info = ctx.info[instr->operands[0].tempId()];
      if (info.is_extract() && can_apply_extract(ctx, instr, 0, info)) {
         apply_extract(ctx, instr, 0, info);
         if (--ctx.uses[instr->operands[0].tempId()])
            ctx.uses[info.instr->operands[0].tempId()]++;
         instr->operands[0].setTemp(info.instr->operands[0].getTemp());
      }

      apply_ds_extract(ctx, instr);
   }

   /* TODO: There are still some peephole optimizations that could be done:
    * - abs(a - b) -> s_absdiff_i32
    * - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
    * - patterns for v_alignbit_b32 and v_alignbyte_b32
    * These aren't probably too interesting though.
    * There are also patterns for v_cmp_class_f{16,32,64}. This is difficult but
    * probably more useful than the previously mentioned optimizations.
    * The various comparison optimizations also currently only work with 32-bit
    * floats. */

   /* neg(mul(a, b)) -> mul(neg(a), b), abs(mul(a, b)) -> mul(abs(a), abs(b)) */
   if ((ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)) &&
       ctx.uses[instr->operands[1].tempId()] == 1) {
      /* val is the temp the neg/abs label points at — presumably the value
       * being negated (operand 1 of this instruction); verify against
       * label_instruction(). */
      Temp val = ctx.info[instr->definitions[0].tempId()].temp;

      if (!ctx.info[val.id()].is_mul())
         return;

      Instruction* mul_instr = ctx.info[val.id()].instr;

      /* The mul must be re-creatable as a VOP3 with modifiers: no literal in
       * src0, no clamp, no SDWA/DPP/VOP3P encodings. */
      if (mul_instr->operands[0].isLiteral())
         return;
      if (mul_instr->isVOP3() && mul_instr->vop3().clamp)
         return;
      if (mul_instr->isSDWA() || mul_instr->isDPP() || mul_instr->isVOP3P())
         return;
      if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32 &&
          ctx.fp_mode.preserve_signed_zero_inf_nan32)
         return;
      if (mul_instr->definitions[0].bytes() != instr->definitions[0].bytes())
         return;

      /* convert to mul(neg(a), b), mul(abs(a), abs(b)) or mul(neg(abs(a)), abs(b)) */
      ctx.uses[mul_instr->definitions[0].tempId()]--;
      Definition def = instr->definitions[0];
      bool is_neg = ctx.info[instr->definitions[0].tempId()].is_neg();
      bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs();
      instr.reset(
         create_instruction<VOP3_instruction>(mul_instr->opcode, asVOP3(Format::VOP2), 2, 1));
      instr->operands[0] = mul_instr->operands[0];
      instr->operands[1] = mul_instr->operands[1];
      instr->definitions[0] = def;
      VOP3_instruction& new_mul = instr->vop3();
      if (mul_instr->isVOP3()) {
         VOP3_instruction& mul = mul_instr->vop3();
         new_mul.neg[0] = mul.neg[0];
         new_mul.neg[1] = mul.neg[1];
         new_mul.abs[0] = mul.abs[0];
         new_mul.abs[1] = mul.abs[1];
         new_mul.omod = mul.omod;
      }
      /* abs discards any pre-existing neg on the sources. */
      if (is_abs) {
         new_mul.neg[0] = new_mul.neg[1] = false;
         new_mul.abs[0] = new_mul.abs[1] = true;
      }
      new_mul.neg[0] ^= is_neg;
      new_mul.clamp = false;

      ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
      return;
   }

   /* combine mul+add -> mad */
   /* A v_fma_mix with src0 == 1.0 (f32 inline constant, or f16 1.0 = 0x3C00
    * selected via opsel) and no neg on src0 is effectively an add: b + c. */
   bool is_add_mix =
      (instr->opcode == aco_opcode::v_fma_mix_f32 ||
       instr->opcode == aco_opcode::v_fma_mixlo_f16) &&
      !instr->vop3p().neg_lo[0] &&
      ((instr->operands[0].constantEquals(0x3f800000) && (instr->vop3p().opsel_hi & 0x1) == 0) ||
       (instr->operands[0].constantEquals(0x3C00) && (instr->vop3p().opsel_hi & 0x1) &&
        !(instr->vop3p().opsel_lo & 0x1)));
   bool mad32 = instr->opcode == aco_opcode::v_add_f32 || instr->opcode == aco_opcode::v_sub_f32 ||
                instr->opcode == aco_opcode::v_subrev_f32;
   bool mad16 = instr->opcode == aco_opcode::v_add_f16 || instr->opcode == aco_opcode::v_sub_f16 ||
                instr->opcode == aco_opcode::v_subrev_f16;
   bool mad64 = instr->opcode == aco_opcode::v_add_f64;
   if (is_add_mix || mad16 || mad32 || mad64) {
      Instruction* mul_instr = nullptr;
      unsigned add_op_idx = 0;
      uint32_t uses = UINT32_MAX;
      bool emit_fma = false;
      /* find the 'best' mul instruction to combine with the add */
      for (unsigned i = is_add_mix ? 1 : 0; i < instr->operands.size(); i++) {
         if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
            continue;
         ssa_info& info = ctx.info[instr->operands[i].tempId()];

         /* no clamp/omod allowed between mul and add */
         if (info.instr->isVOP3() && (info.instr->vop3().clamp || info.instr->vop3().omod))
            continue;
         if (info.instr->isVOP3P() && info.instr->vop3p().clamp)
            continue;
         /* v_fma_mix_f32/etc can't do omod */
         if (info.instr->isVOP3P() && instr->isVOP3() && instr->vop3().omod)
            continue;
         /* don't promote fp16 to fp32 or remove fp32->fp16->fp32 conversions */
         if (is_add_mix && info.instr->definitions[0].bytes() == 2)
            continue;

         if (get_operand_size(instr, i) != info.instr->definitions[0].bytes() * 8)
            continue;

         bool legacy = info.instr->opcode == aco_opcode::v_mul_legacy_f32;
         bool mad_mix = is_add_mix || info.instr->isVOP3P();

         /* Decide whether this chip has a suitable fused (fma) and/or unfused
          * (mad) opcode for this width/legacy/mix combination. */
         bool has_fma = mad16 || mad64 || (legacy && ctx.program->chip_class >= GFX10_3) ||
                        (mad32 && !legacy && !mad_mix && ctx.program->dev.has_fast_fma32) ||
                        (mad_mix && ctx.program->dev.fused_mad_mix);
         bool has_mad = mad_mix ? !ctx.program->dev.fused_mad_mix
                                : ((mad32 && ctx.program->chip_class < GFX10_3) ||
                                   (mad16 && ctx.program->chip_class <= GFX9));
         /* Fusing changes rounding of the intermediate product, so 'precise'
          * results may not use fma; unfused mad flushes denormals. */
         bool can_use_fma = has_fma && !info.instr->definitions[0].isPrecise() &&
                            !instr->definitions[0].isPrecise();
         bool can_use_mad =
            has_mad && (mad_mix || mad32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == 0;
         if (mad_mix && legacy)
            continue;
         if (!can_use_fma && !can_use_mad)
            continue;

         /* For a mix add the operands are [1.0, b, c], so the addend is at 3-i. */
         unsigned candidate_add_op_idx = is_add_mix ? (3 - i) : (1 - i);
         Operand op[3] = {info.instr->operands[0], info.instr->operands[1],
                          instr->operands[candidate_add_op_idx]};
         if (info.instr->isSDWA() || info.instr->isDPP() || !check_vop3_operands(ctx, 3, op) ||
             ctx.uses[instr->operands[i].tempId()] > uses)
            continue;

         /* On a tie in use counts, prefer the higher temp id for determinism. */
         if (ctx.uses[instr->operands[i].tempId()] == uses) {
            unsigned cur_idx = mul_instr->definitions[0].tempId();
            unsigned new_idx = info.instr->definitions[0].tempId();
            if (cur_idx > new_idx)
               continue;
         }

         mul_instr = info.instr;
         add_op_idx = candidate_add_op_idx;
         uses = ctx.uses[instr->operands[i].tempId()];
         emit_fma = !can_use_mad;
      }

      if (mul_instr) {
         /* turn mul+add into v_mad/v_fma */
         Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1],
                          instr->operands[add_op_idx]};
         /* The mul loses one use; if it stays live, its sources gain one use
          * each since the mad reads them too. */
         ctx.uses[mul_instr->definitions[0].tempId()]--;
         if (ctx.uses[mul_instr->definitions[0].tempId()]) {
            if (op[0].isTemp())
               ctx.uses[op[0].tempId()]++;
            if (op[1].isTemp())
               ctx.uses[op[1].tempId()]++;
         }

         /* Gather neg/abs/omod/clamp/opsel modifiers from both instructions. */
         bool neg[3] = {false, false, false};
         bool abs[3] = {false, false, false};
         unsigned omod = 0;
         bool clamp = false;
         uint8_t opsel_lo = 0;
         uint8_t opsel_hi = 0;

         if (mul_instr->isVOP3()) {
            VOP3_instruction& vop3 = mul_instr->vop3();
            neg[0] = vop3.neg[0];
            neg[1] = vop3.neg[1];
            abs[0] = vop3.abs[0];
            abs[1] = vop3.abs[1];
         } else if (mul_instr->isVOP3P()) {
            VOP3P_instruction& vop3p = mul_instr->vop3p();
            neg[0] = vop3p.neg_lo[0];
            neg[1] = vop3p.neg_lo[1];
            abs[0] = vop3p.neg_hi[0];
            abs[1] = vop3p.neg_hi[1];
            opsel_lo = vop3p.opsel_lo & 0x3;
            opsel_hi = vop3p.opsel_hi & 0x3;
         }

         if (instr->isVOP3()) {
            VOP3_instruction& vop3 = instr->vop3();
            neg[2] = vop3.neg[add_op_idx];
            abs[2] = vop3.abs[add_op_idx];
            omod = vop3.omod;
            clamp = vop3.clamp;
            /* abs of the multiplication result */
            if (vop3.abs[1 - add_op_idx]) {
               neg[0] = false;
               neg[1] = false;
               abs[0] = true;
               abs[1] = true;
            }
            /* neg of the multiplication result */
            neg[1] = neg[1] ^ vop3.neg[1 - add_op_idx];
         } else if (instr->isVOP3P()) {
            /* Mix add: the mul result was read at operand index 3-add_op_idx. */
            VOP3P_instruction& vop3p = instr->vop3p();
            neg[2] = vop3p.neg_lo[add_op_idx];
            abs[2] = vop3p.neg_hi[add_op_idx];
            opsel_lo |= vop3p.opsel_lo & (1 << add_op_idx) ? 0x4 : 0x0;
            opsel_hi |= vop3p.opsel_hi & (1 << add_op_idx) ? 0x4 : 0x0;
            clamp = vop3p.clamp;
            /* abs of the multiplication result */
            if (vop3p.neg_hi[3 - add_op_idx]) {
               neg[0] = false;
               neg[1] = false;
               abs[0] = true;
               abs[1] = true;
            }
            /* neg of the multiplication result */
            neg[1] = neg[1] ^ vop3p.neg_lo[3 - add_op_idx];
         }

         /* Subtractions negate either the addend or the product. */
         if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == aco_opcode::v_sub_f16)
            neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
         else if (instr->opcode == aco_opcode::v_subrev_f32 ||
                  instr->opcode == aco_opcode::v_subrev_f16)
            neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;

         aco_ptr<Instruction> add_instr = std::move(instr);
         if (add_instr->isVOP3P() || mul_instr->isVOP3P()) {
            assert(!omod);

            aco_opcode mad_op = add_instr->definitions[0].bytes() == 2 ? aco_opcode::v_fma_mixlo_f16
                                                                       : aco_opcode::v_fma_mix_f32;
            aco_ptr<VOP3P_instruction> mad{
               create_instruction<VOP3P_instruction>(mad_op, Format::VOP3P, 3, 1)};
            for (unsigned i = 0; i < 3; i++) {
               mad->operands[i] = op[i];
               mad->neg_lo[i] = neg[i];
               mad->neg_hi[i] = abs[i];
            }
            mad->clamp = clamp;
            mad->opsel_lo = opsel_lo;
            mad->opsel_hi = opsel_hi;

            instr = std::move(mad);
         } else {
            /* Pick the opcode matching width, legacy semantics and chip. */
            aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
            if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
               assert(emit_fma == (ctx.program->chip_class >= GFX10_3));
               mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
            } else if (mad16) {
               mad_op = emit_fma ? (ctx.program->chip_class == GFX8 ? aco_opcode::v_fma_legacy_f16
                                                                    : aco_opcode::v_fma_f16)
                                 : (ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_f16
                                                                    : aco_opcode::v_mad_f16);
            } else if (mad64) {
               mad_op = aco_opcode::v_fma_f64;
            }

            aco_ptr<VOP3_instruction> mad{
               create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
            for (unsigned i = 0; i < 3; i++) {
               mad->operands[i] = op[i];
               mad->neg[i] = neg[i];
               mad->abs[i] = abs[i];
            }
            mad->omod = omod;
            mad->clamp = clamp;

            instr = std::move(mad);
         }
         instr->definitions[0] = add_instr->definitions[0];

         /* mark this ssa_def to be re-checked for profitability and literals */
         ctx.mad_infos.emplace_back(std::move(add_instr), mul_instr->definitions[0].tempId());
         ctx.info[instr->definitions[0].tempId()].set_mad(instr.get(), ctx.mad_infos.size() - 1);
         return;
      }
   }
   /* v_mul_f32(v_cndmask_b32(0, 1.0, cond), a) -> v_cndmask_b32(0, a, cond) */
   else if (((instr->opcode == aco_opcode::v_mul_f32 &&
              !ctx.fp_mode.preserve_signed_zero_inf_nan32) ||
             instr->opcode == aco_opcode::v_mul_legacy_f32) &&
            !instr->usesModifiers() && !ctx.fp_mode.must_flush_denorms32) {
      for (unsigned i = 0; i < 2; i++) {
         if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2f() &&
             ctx.uses[instr->operands[i].tempId()] == 1 && instr->operands[!i].isTemp() &&
             instr->operands[!i].getTemp().type() == RegType::vgpr) {
            ctx.uses[instr->operands[i].tempId()]--;
            ctx.uses[ctx.info[instr->operands[i].tempId()].temp.id()]++;

            aco_ptr<VOP2_instruction> new_instr{
               create_instruction<VOP2_instruction>(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1)};
            new_instr->operands[0] = Operand::zero();
            new_instr->operands[1] = instr->operands[!i];
            new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
            new_instr->definitions[0] = instr->definitions[0];
            instr = std::move(new_instr);
            ctx.info[instr->definitions[0].tempId()].label = 0;
            return;
         }
      }
   } else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->chip_class >= GFX9) {
      /* or(or(a, b), c) / or(s_or, c) -> v_or3_b32; also add+or+and+lshl patterns. */
      if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012",
                                1 | 2)) {
      } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32,
                                       "012", 1 | 2)) {
      } else if (combine_add_or_then_and_lshl(ctx, instr)) {
      }
   } else if (instr->opcode == aco_opcode::v_xor_b32 && ctx.program->chip_class >= GFX10) {
      /* xor(xor(a, b), c) -> v_xor3_b32 (GFX10+). */
      if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xor3_b32, "012",
                                1 | 2)) {
      } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xor3_b32,
                                       "012", 1 | 2)) {
      }
   } else if (instr->opcode == aco_opcode::v_add_u16) {
      /* add(mul_lo(a, b), c) -> v_mad_u16 (legacy opcode on GFX8). */
      combine_three_valu_op(
         ctx, instr, aco_opcode::v_mul_lo_u16,
         ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_u16 : aco_opcode::v_mad_u16,
         "120", 1 | 2);
   } else if (instr->opcode == aco_opcode::v_add_u16_e64) {
      combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16_e64, aco_opcode::v_mad_u16, "120",
                            1 | 2);
   } else if (instr->opcode == aco_opcode::v_add_u32) {
      /* Various add combinations: b2i carry adds, bcnt, u24 mads, xad/add3/lshl_add. */
      if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
      } else if (combine_add_bcnt(ctx, instr)) {
      } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
                                       aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
      } else if (ctx.program->chip_class >= GFX9 && !instr->usesModifiers()) {
         if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xad_u32, "120",
                                   1 | 2)) {
         } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xad_u32,
                                          "120", 1 | 2)) {
         } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32,
                                          "012", 1 | 2)) {
         } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32,
                                          "012", 1 | 2)) {
         } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32,
                                          "012", 1 | 2)) {
         } else if (combine_add_or_then_and_lshl(ctx, instr)) {
         }
      }
   } else if (instr->opcode == aco_opcode::v_add_co_u32 ||
              instr->opcode == aco_opcode::v_add_co_u32_e64) {
      /* Most combinations are invalid if the carry-out of this add is used. */
      bool carry_out = ctx.uses[instr->definitions[1].tempId()] > 0;
      if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
      } else if (!carry_out && combine_add_bcnt(ctx, instr)) {
      } else if (!carry_out && combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
                                                     aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
      } else if (!carry_out && combine_add_lshl(ctx, instr, false)) {
      }
   } else if (instr->opcode == aco_opcode::v_sub_u32 || instr->opcode == aco_opcode::v_sub_co_u32 ||
              instr->opcode == aco_opcode::v_sub_co_u32_e64) {
      bool carry_out =
         instr->opcode != aco_opcode::v_sub_u32 && ctx.uses[instr->definitions[1].tempId()] > 0;
      if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 2)) {
      } else if (!carry_out && combine_add_lshl(ctx, instr, true)) {
      }
   } else if (instr->opcode == aco_opcode::v_subrev_u32 ||
              instr->opcode == aco_opcode::v_subrev_co_u32 ||
              instr->opcode == aco_opcode::v_subrev_co_u32_e64) {
      combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 1);
   } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && ctx.program->chip_class >= GFX9) {
      /* lshl(add(a, b), c) -> v_add_lshl_u32. */
      combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add_lshl_u32, "120",
                            2);
   } else if ((instr->opcode == aco_opcode::s_add_u32 || instr->opcode == aco_opcode::s_add_i32) &&
              ctx.program->chip_class >= GFX9) {
      combine_salu_lshl_add(ctx, instr);
   } else if (instr->opcode == aco_opcode::s_not_b32 || instr->opcode == aco_opcode::s_not_b64) {
      combine_salu_not_bitwise(ctx, instr);
   } else if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_or_b32 ||
              instr->opcode == aco_opcode::s_and_b64 || instr->opcode == aco_opcode::s_or_b64) {
      /* Comparison-ordering combinations on boolean masks. */
      if (combine_ordering_test(ctx, instr)) {
      } else if (combine_comparison_ordering(ctx, instr)) {
      } else if (combine_constant_comparison_ordering(ctx, instr)) {
      } else if (combine_salu_n2(ctx, instr)) {
      }
   } else if (instr->opcode == aco_opcode::v_and_b32) {
      combine_and_subbrev(ctx, instr);
   } else if (instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) {
      /* set existing v_fma_f32 with label_mad so we can create v_fmamk_f32/v_fmaak_f32.
       * since ctx.uses[mad_info::mul_temp_id] is always 0, we don't have to worry about
       * select_instruction() using mad_info::add_instr.
       */
      ctx.mad_infos.emplace_back(nullptr, 0);
      ctx.info[instr->definitions[0].tempId()].set_mad(instr.get(), ctx.mad_infos.size() - 1);
   } else {
      /* min/max/med3 and clamp combinations for any min/max-family opcode. */
      aco_opcode min, max, min3, max3, med3;
      bool some_gfx9_only;
      if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &some_gfx9_only) &&
          (!some_gfx9_only || ctx.program->chip_class >= GFX9)) {
         if (combine_minmax(ctx, instr, instr->opcode == min ? max : min,
                            instr->opcode == min ? min3 : max3)) {
         } else {
            combine_clamp(ctx, instr, min, max, med3);
         }
      }
   }

   /* do this after combine_salu_n2() */
   if (instr->opcode == aco_opcode::s_andn2_b32 || instr->opcode == aco_opcode::s_andn2_b64)
      combine_inverse_comparison(ctx, instr);
}
4067
4068 bool
to_uniform_bool_instr(opt_ctx & ctx,aco_ptr<Instruction> & instr)4069 to_uniform_bool_instr(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4070 {
4071 /* Check every operand to make sure they are suitable. */
4072 for (Operand& op : instr->operands) {
4073 if (!op.isTemp())
4074 return false;
4075 if (!ctx.info[op.tempId()].is_uniform_bool() && !ctx.info[op.tempId()].is_uniform_bitwise())
4076 return false;
4077 }
4078
4079 switch (instr->opcode) {
4080 case aco_opcode::s_and_b32:
4081 case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_and_b32; break;
4082 case aco_opcode::s_or_b32:
4083 case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_or_b32; break;
4084 case aco_opcode::s_xor_b32:
4085 case aco_opcode::s_xor_b64: instr->opcode = aco_opcode::s_absdiff_i32; break;
4086 default:
4087 /* Don't transform other instructions. They are very unlikely to appear here. */
4088 return false;
4089 }
4090
4091 for (Operand& op : instr->operands) {
4092 ctx.uses[op.tempId()]--;
4093
4094 if (ctx.info[op.tempId()].is_uniform_bool()) {
4095 /* Just use the uniform boolean temp. */
4096 op.setTemp(ctx.info[op.tempId()].temp);
4097 } else if (ctx.info[op.tempId()].is_uniform_bitwise()) {
4098 /* Use the SCC definition of the predecessor instruction.
4099 * This allows the predecessor to get picked up by the same optimization (if it has no
4100 * divergent users), and it also makes sure that the current instruction will keep working
4101 * even if the predecessor won't be transformed.
4102 */
4103 Instruction* pred_instr = ctx.info[op.tempId()].instr;
4104 assert(pred_instr->definitions.size() >= 2);
4105 assert(pred_instr->definitions[1].isFixed() &&
4106 pred_instr->definitions[1].physReg() == scc);
4107 op.setTemp(pred_instr->definitions[1].getTemp());
4108 } else {
4109 unreachable("Invalid operand on uniform bitwise instruction.");
4110 }
4111
4112 ctx.uses[op.tempId()]++;
4113 }
4114
4115 instr->definitions[0].setTemp(Temp(instr->definitions[0].tempId(), s1));
4116 assert(instr->operands[0].regClass() == s1);
4117 assert(instr->operands[1].regClass() == s1);
4118 return true;
4119 }
4120
4121 void
select_instruction(opt_ctx & ctx,aco_ptr<Instruction> & instr)4122 select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4123 {
4124 const uint32_t threshold = 4;
4125
4126 if (is_dead(ctx.uses, instr.get())) {
4127 instr.reset();
4128 return;
4129 }
4130
4131 /* convert split_vector into a copy or extract_vector if only one definition is ever used */
4132 if (instr->opcode == aco_opcode::p_split_vector) {
4133 unsigned num_used = 0;
4134 unsigned idx = 0;
4135 unsigned split_offset = 0;
4136 for (unsigned i = 0, offset = 0; i < instr->definitions.size();
4137 offset += instr->definitions[i++].bytes()) {
4138 if (ctx.uses[instr->definitions[i].tempId()]) {
4139 num_used++;
4140 idx = i;
4141 split_offset = offset;
4142 }
4143 }
4144 bool done = false;
4145 if (num_used == 1 && ctx.info[instr->operands[0].tempId()].is_vec() &&
4146 ctx.uses[instr->operands[0].tempId()] == 1) {
4147 Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
4148
4149 unsigned off = 0;
4150 Operand op;
4151 for (Operand& vec_op : vec->operands) {
4152 if (off == split_offset) {
4153 op = vec_op;
4154 break;
4155 }
4156 off += vec_op.bytes();
4157 }
4158 if (off != instr->operands[0].bytes() && op.bytes() == instr->definitions[idx].bytes()) {
4159 ctx.uses[instr->operands[0].tempId()]--;
4160 for (Operand& vec_op : vec->operands) {
4161 if (vec_op.isTemp())
4162 ctx.uses[vec_op.tempId()]--;
4163 }
4164 if (op.isTemp())
4165 ctx.uses[op.tempId()]++;
4166
4167 aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(
4168 aco_opcode::p_create_vector, Format::PSEUDO, 1, 1)};
4169 extract->operands[0] = op;
4170 extract->definitions[0] = instr->definitions[idx];
4171 instr = std::move(extract);
4172
4173 done = true;
4174 }
4175 }
4176
4177 if (!done && num_used == 1 &&
4178 instr->operands[0].bytes() % instr->definitions[idx].bytes() == 0 &&
4179 split_offset % instr->definitions[idx].bytes() == 0) {
4180 aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(
4181 aco_opcode::p_extract_vector, Format::PSEUDO, 2, 1)};
4182 extract->operands[0] = instr->operands[0];
4183 extract->operands[1] =
4184 Operand::c32((uint32_t)split_offset / instr->definitions[idx].bytes());
4185 extract->definitions[0] = instr->definitions[idx];
4186 instr = std::move(extract);
4187 }
4188 }
4189
4190 mad_info* mad_info = NULL;
4191 if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
4192 mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].instr->pass_flags];
4193 /* re-check mad instructions */
4194 if (ctx.uses[mad_info->mul_temp_id] && mad_info->add_instr) {
4195 ctx.uses[mad_info->mul_temp_id]++;
4196 if (instr->operands[0].isTemp())
4197 ctx.uses[instr->operands[0].tempId()]--;
4198 if (instr->operands[1].isTemp())
4199 ctx.uses[instr->operands[1].tempId()]--;
4200 instr.swap(mad_info->add_instr);
4201 mad_info = NULL;
4202 }
4203 /* check literals */
4204 else if (!instr->usesModifiers() && !instr->isVOP3P() &&
4205 instr->opcode != aco_opcode::v_fma_f64 &&
4206 instr->opcode != aco_opcode::v_mad_legacy_f32 &&
4207 instr->opcode != aco_opcode::v_fma_legacy_f32) {
4208 /* FMA can only take literals on GFX10+ */
4209 if ((instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) &&
4210 ctx.program->chip_class < GFX10)
4211 return;
4212 /* There are no v_fmaak_legacy_f16/v_fmamk_legacy_f16 and on chips where VOP3 can take
4213 * literals (GFX10+), these instructions don't exist.
4214 */
4215 if (instr->opcode == aco_opcode::v_fma_legacy_f16)
4216 return;
4217
4218 uint32_t literal_idx = 0;
4219 uint32_t literal_uses = UINT32_MAX;
4220
4221 /* Try using v_madak/v_fmaak */
4222 if (instr->operands[2].isTemp() &&
4223 ctx.info[instr->operands[2].tempId()].is_literal(get_operand_size(instr, 2))) {
4224 bool has_sgpr = false;
4225 bool has_vgpr = false;
4226 for (unsigned i = 0; i < 2; i++) {
4227 if (!instr->operands[i].isTemp())
4228 continue;
4229 has_sgpr |= instr->operands[i].getTemp().type() == RegType::sgpr;
4230 has_vgpr |= instr->operands[i].getTemp().type() == RegType::vgpr;
4231 }
4232 /* Encoding limitations requires a VGPR operand. The constant bus limitations before
4233 * GFX10 disallows SGPRs.
4234 */
4235 if ((!has_sgpr || ctx.program->chip_class >= GFX10) && has_vgpr) {
4236 literal_idx = 2;
4237 literal_uses = ctx.uses[instr->operands[2].tempId()];
4238 }
4239 }
4240
4241 /* Try using v_madmk/v_fmamk */
4242 /* Encoding limitations requires a VGPR operand. */
4243 if (instr->operands[2].isTemp() && instr->operands[2].getTemp().type() == RegType::vgpr) {
4244 for (unsigned i = 0; i < 2; i++) {
4245 if (!instr->operands[i].isTemp())
4246 continue;
4247
4248 /* The constant bus limitations before GFX10 disallows SGPRs. */
4249 if (ctx.program->chip_class < GFX10 && instr->operands[!i].isTemp() &&
4250 instr->operands[!i].getTemp().type() == RegType::sgpr)
4251 continue;
4252
4253 if (ctx.info[instr->operands[i].tempId()].is_literal(get_operand_size(instr, i)) &&
4254 ctx.uses[instr->operands[i].tempId()] < literal_uses) {
4255 literal_idx = i;
4256 literal_uses = ctx.uses[instr->operands[i].tempId()];
4257 }
4258 }
4259 }
4260
4261 /* Limit the number of literals to apply to not increase the code
4262 * size too much, but always apply literals for v_mad->v_madak
4263 * because both instructions are 64-bit and this doesn't increase
4264 * code size.
4265 * TODO: try to apply the literals earlier to lower the number of
4266 * uses below threshold
4267 */
4268 if (literal_uses < threshold || literal_idx == 2) {
4269 ctx.uses[instr->operands[literal_idx].tempId()]--;
4270 mad_info->check_literal = true;
4271 mad_info->literal_idx = literal_idx;
4272 return;
4273 }
4274 }
4275 }
4276
4277 /* Mark SCC needed, so the uniform boolean transformation won't swap the definitions
4278 * when it isn't beneficial */
4279 if (instr->isBranch() && instr->operands.size() && instr->operands[0].isTemp() &&
4280 instr->operands[0].isFixed() && instr->operands[0].physReg() == scc) {
4281 ctx.info[instr->operands[0].tempId()].set_scc_needed();
4282 return;
4283 } else if ((instr->opcode == aco_opcode::s_cselect_b64 ||
4284 instr->opcode == aco_opcode::s_cselect_b32) &&
4285 instr->operands[2].isTemp()) {
4286 ctx.info[instr->operands[2].tempId()].set_scc_needed();
4287 } else if (instr->opcode == aco_opcode::p_wqm && instr->operands[0].isTemp() &&
4288 ctx.info[instr->definitions[0].tempId()].is_scc_needed()) {
4289 /* Propagate label so it is correctly detected by the uniform bool transform */
4290 ctx.info[instr->operands[0].tempId()].set_scc_needed();
4291
4292 /* Fix definition to SCC, this will prevent RA from adding superfluous moves */
4293 instr->definitions[0].setFixed(scc);
4294 }
4295
4296 /* check for literals */
4297 if (!instr->isSALU() && !instr->isVALU())
4298 return;
4299
4300 /* Transform uniform bitwise boolean operations to 32-bit when there are no divergent uses. */
4301 if (instr->definitions.size() && ctx.uses[instr->definitions[0].tempId()] == 0 &&
4302 ctx.info[instr->definitions[0].tempId()].is_uniform_bitwise()) {
4303 bool transform_done = to_uniform_bool_instr(ctx, instr);
4304
4305 if (transform_done && !ctx.info[instr->definitions[1].tempId()].is_scc_needed()) {
4306 /* Swap the two definition IDs in order to avoid overusing the SCC.
4307 * This reduces extra moves generated by RA. */
4308 uint32_t def0_id = instr->definitions[0].getTemp().id();
4309 uint32_t def1_id = instr->definitions[1].getTemp().id();
4310 instr->definitions[0].setTemp(Temp(def1_id, s1));
4311 instr->definitions[1].setTemp(Temp(def0_id, s1));
4312 }
4313
4314 return;
4315 }
4316
4317 /* Combine DPP copies into VALU. This should be done after creating MAD/FMA. */
4318 if (instr->isVALU()) {
4319 for (unsigned i = 0; i < instr->operands.size(); i++) {
4320 if (!instr->operands[i].isTemp())
4321 continue;
4322 ssa_info info = ctx.info[instr->operands[i].tempId()];
4323
4324 aco_opcode swapped_op;
4325 if (info.is_dpp() && info.instr->pass_flags == instr->pass_flags &&
4326 (i == 0 || can_swap_operands(instr, &swapped_op)) &&
4327 can_use_DPP(instr, true, info.is_dpp8()) && !instr->isDPP()) {
4328 bool dpp8 = info.is_dpp8();
4329 convert_to_DPP(instr, dpp8);
4330 if (dpp8) {
4331 DPP8_instruction* dpp = &instr->dpp8();
4332 for (unsigned j = 0; j < 8; ++j)
4333 dpp->lane_sel[j] = info.instr->dpp8().lane_sel[j];
4334 if (i) {
4335 instr->opcode = swapped_op;
4336 std::swap(instr->operands[0], instr->operands[1]);
4337 }
4338 } else {
4339 DPP16_instruction* dpp = &instr->dpp16();
4340 if (i) {
4341 instr->opcode = swapped_op;
4342 std::swap(instr->operands[0], instr->operands[1]);
4343 std::swap(dpp->neg[0], dpp->neg[1]);
4344 std::swap(dpp->abs[0], dpp->abs[1]);
4345 }
4346 dpp->dpp_ctrl = info.instr->dpp16().dpp_ctrl;
4347 dpp->bound_ctrl = info.instr->dpp16().bound_ctrl;
4348 dpp->neg[0] ^= info.instr->dpp16().neg[0] && !dpp->abs[0];
4349 dpp->abs[0] |= info.instr->dpp16().abs[0];
4350 }
4351 if (--ctx.uses[info.instr->definitions[0].tempId()])
4352 ctx.uses[info.instr->operands[0].tempId()]++;
4353 instr->operands[0].setTemp(info.instr->operands[0].getTemp());
4354 break;
4355 }
4356 }
4357 }
4358
4359 if (instr->isSDWA() || (instr->isVOP3() && ctx.program->chip_class < GFX10) ||
4360 (instr->isVOP3P() && ctx.program->chip_class < GFX10))
4361 return; /* some encodings can't ever take literals */
4362
4363 /* we do not apply the literals yet as we don't know if it is profitable */
4364 Operand current_literal(s1);
4365
4366 unsigned literal_id = 0;
4367 unsigned literal_uses = UINT32_MAX;
4368 Operand literal(s1);
4369 unsigned num_operands = 1;
4370 if (instr->isSALU() ||
4371 (ctx.program->chip_class >= GFX10 && (can_use_VOP3(ctx, instr) || instr->isVOP3P())))
4372 num_operands = instr->operands.size();
4373 /* catch VOP2 with a 3rd SGPR operand (e.g. v_cndmask_b32, v_addc_co_u32) */
4374 else if (instr->isVALU() && instr->operands.size() >= 3)
4375 return;
4376
4377 unsigned sgpr_ids[2] = {0, 0};
4378 bool is_literal_sgpr = false;
4379 uint32_t mask = 0;
4380
4381 /* choose a literal to apply */
4382 for (unsigned i = 0; i < num_operands; i++) {
4383 Operand op = instr->operands[i];
4384 unsigned bits = get_operand_size(instr, i);
4385
4386 if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
4387 op.tempId() != sgpr_ids[0])
4388 sgpr_ids[!!sgpr_ids[0]] = op.tempId();
4389
4390 if (op.isLiteral()) {
4391 current_literal = op;
4392 continue;
4393 } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal(bits)) {
4394 continue;
4395 }
4396
4397 if (!alu_can_accept_constant(instr->opcode, i))
4398 continue;
4399
4400 if (ctx.uses[op.tempId()] < literal_uses) {
4401 is_literal_sgpr = op.getTemp().type() == RegType::sgpr;
4402 mask = 0;
4403 literal = Operand::c32(ctx.info[op.tempId()].val);
4404 literal_uses = ctx.uses[op.tempId()];
4405 literal_id = op.tempId();
4406 }
4407
4408 mask |= (op.tempId() == literal_id) << i;
4409 }
4410
4411 /* don't go over the constant bus limit */
4412 bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
4413 instr->opcode == aco_opcode::v_lshrrev_b64 ||
4414 instr->opcode == aco_opcode::v_ashrrev_i64;
4415 unsigned const_bus_limit = instr->isVALU() ? 1 : UINT32_MAX;
4416 if (ctx.program->chip_class >= GFX10 && !is_shift64)
4417 const_bus_limit = 2;
4418
4419 unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
4420 if (num_sgprs == const_bus_limit && !is_literal_sgpr)
4421 return;
4422
4423 if (literal_id && literal_uses < threshold &&
4424 (current_literal.isUndefined() ||
4425 (current_literal.size() == literal.size() &&
4426 current_literal.constantValue() == literal.constantValue()))) {
4427 /* mark the literal to be applied */
4428 while (mask) {
4429 unsigned i = u_bit_scan(&mask);
4430 if (instr->operands[i].isTemp() && instr->operands[i].tempId() == literal_id)
4431 ctx.uses[instr->operands[i].tempId()]--;
4432 }
4433 }
4434 }
4435
void
apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* Final pass (phase 4): rewrite instructions whose operands were marked as
    * inline-able literals by select_instruction(), then emit the (possibly
    * replaced) instruction into ctx.instructions.
    */

   /* Cleanup Dead Instructions: select_instruction() resets dead instrs to null. */
   if (!instr)
      return;

   /* apply literals on MAD: convert v_mad/v_fma into the two-dword
    * v_madak/v_madmk (or fmaak/fmamk) forms that embed one 32-bit literal. */
   if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
      mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].instr->pass_flags];
      /* literal_idx == 2 (the addend, -> madak) is always applied; other
       * operands only when no other use of the temp remains. */
      if (info->check_literal &&
          (ctx.uses[instr->operands[info->literal_idx].tempId()] == 0 || info->literal_idx == 2)) {
         aco_ptr<Instruction> new_mad;

         /* Pick the *ak (literal addend) or *mk (literal multiplicand) opcode
          * matching the original mad/fma f32/f16 variant. */
         aco_opcode new_op =
            info->literal_idx == 2 ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32;
         if (instr->opcode == aco_opcode::v_fma_f32)
            new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32;
         else if (instr->opcode == aco_opcode::v_mad_f16 ||
                  instr->opcode == aco_opcode::v_mad_legacy_f16)
            new_op = info->literal_idx == 2 ? aco_opcode::v_madak_f16 : aco_opcode::v_madmk_f16;
         else if (instr->opcode == aco_opcode::v_fma_f16)
            new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f16 : aco_opcode::v_fmamk_f16;

         /* madak/madmk are VOP2 encodings: 2 explicit operands + trailing literal. */
         new_mad.reset(create_instruction<VOP2_instruction>(new_op, Format::VOP2, 3, 1));
         if (info->literal_idx == 2) { /* add literal -> madak */
            new_mad->operands[0] = instr->operands[0];
            new_mad->operands[1] = instr->operands[1];
            /* VOP2 src1 must be a VGPR; swap if src1 is a constant or SGPR. */
            if (!new_mad->operands[1].isTemp() ||
                new_mad->operands[1].getTemp().type() == RegType::sgpr)
               std::swap(new_mad->operands[0], new_mad->operands[1]);
         } else { /* mul literal -> madmk */
            /* The non-literal multiplicand becomes src0, the addend src1. */
            new_mad->operands[0] = instr->operands[1 - info->literal_idx];
            new_mad->operands[1] = instr->operands[2];
         }
         /* operands[2] holds the embedded 32-bit literal constant. */
         new_mad->operands[2] =
            Operand::c32(ctx.info[instr->operands[info->literal_idx].tempId()].val);
         new_mad->definitions[0] = instr->definitions[0];
         ctx.instructions.emplace_back(std::move(new_mad));
         return;
      }
   }

   /* apply literals on other SALU/VALU: substitute operands whose temp was
    * labelled a literal and whose use count was already dropped to zero by
    * select_instruction(). */
   if (instr->isSALU() || instr->isVALU()) {
      for (unsigned i = 0; i < instr->operands.size(); i++) {
         Operand op = instr->operands[i];
         unsigned bits = get_operand_size(instr, i);
         if (op.isTemp() && ctx.info[op.tempId()].is_literal(bits) && ctx.uses[op.tempId()] == 0) {
            Operand literal = Operand::c32(ctx.info[op.tempId()].val);
            /* DPP encodings can't take literals; drop the DPP format bits. */
            instr->format = withoutDPP(instr->format);
            /* VOP2 only allows a literal in src0; promote to VOP3 for other
             * operand positions (VOP3P handles this natively). */
            if (instr->isVALU() && i > 0 && instr->format != Format::VOP3P)
               to_VOP3(ctx, instr);
            instr->operands[i] = literal;
         }
      }
   }

   ctx.instructions.emplace_back(std::move(instr));
}
4496
4497 void
optimize(Program * program)4498 optimize(Program* program)
4499 {
4500 opt_ctx ctx;
4501 ctx.program = program;
4502 std::vector<ssa_info> info(program->peekAllocationId());
4503 ctx.info = info.data();
4504
4505 /* 1. Bottom-Up DAG pass (forward) to label all ssa-defs */
4506 for (Block& block : program->blocks) {
4507 ctx.fp_mode = block.fp_mode;
4508 for (aco_ptr<Instruction>& instr : block.instructions)
4509 label_instruction(ctx, instr);
4510 }
4511
4512 ctx.uses = dead_code_analysis(program);
4513
4514 /* 2. Combine v_mad, omod, clamp and propagate sgpr on VALU instructions */
4515 for (Block& block : program->blocks) {
4516 ctx.fp_mode = block.fp_mode;
4517 for (aco_ptr<Instruction>& instr : block.instructions)
4518 combine_instruction(ctx, instr);
4519 }
4520
4521 /* 3. Top-Down DAG pass (backward) to select instructions (includes DCE) */
4522 for (auto block_rit = program->blocks.rbegin(); block_rit != program->blocks.rend();
4523 ++block_rit) {
4524 Block* block = &(*block_rit);
4525 ctx.fp_mode = block->fp_mode;
4526 for (auto instr_rit = block->instructions.rbegin(); instr_rit != block->instructions.rend();
4527 ++instr_rit)
4528 select_instruction(ctx, *instr_rit);
4529 }
4530
4531 /* 4. Add literals to instructions */
4532 for (Block& block : program->blocks) {
4533 ctx.instructions.clear();
4534 ctx.fp_mode = block.fp_mode;
4535 for (aco_ptr<Instruction>& instr : block.instructions)
4536 apply_literals(ctx, instr);
4537 block.instructions.swap(ctx.instructions);
4538 }
4539 }
4540
4541 } // namespace aco
4542