1 /* This file is part of the dynarmic project.
2 * Copyright (c) 2016 MerryMage
3 * SPDX-License-Identifier: 0BSD
4 */
5
#include <limits>
#include <optional>

#include "common/assert.h"
#include "common/bit_util.h"
#include "common/common_types.h"
#include "common/safe_ops.h"
#include "frontend/ir/basic_block.h"
#include "frontend/ir/ir_emitter.h"
#include "frontend/ir/opcodes.h"
#include "ir_opt/passes.h"
16
17 namespace Dynarmic::Optimization {
18
19 using Op = Dynarmic::IR::Opcode;
20
21 namespace {
22
23 // Tiny helper to avoid the need to store based off the opcode
24 // bit size all over the place within folding functions.
ReplaceUsesWith(IR::Inst & inst,bool is_32_bit,u64 value)25 void ReplaceUsesWith(IR::Inst& inst, bool is_32_bit, u64 value) {
26 if (is_32_bit) {
27 inst.ReplaceUsesWith(IR::Value{static_cast<u32>(value)});
28 } else {
29 inst.ReplaceUsesWith(IR::Value{value});
30 }
31 }
32
Value(bool is_32_bit,u64 value)33 IR::Value Value(bool is_32_bit, u64 value) {
34 return is_32_bit ? IR::Value{static_cast<u32>(value)} : IR::Value{value};
35 }
36
37 template <typename ImmFn>
FoldCommutative(IR::Inst & inst,bool is_32_bit,ImmFn imm_fn)38 bool FoldCommutative(IR::Inst& inst, bool is_32_bit, ImmFn imm_fn) {
39 const auto lhs = inst.GetArg(0);
40 const auto rhs = inst.GetArg(1);
41
42 const bool is_lhs_immediate = lhs.IsImmediate();
43 const bool is_rhs_immediate = rhs.IsImmediate();
44
45 if (is_lhs_immediate && is_rhs_immediate) {
46 const u64 result = imm_fn(lhs.GetImmediateAsU64(), rhs.GetImmediateAsU64());
47 ReplaceUsesWith(inst, is_32_bit, result);
48 return false;
49 }
50
51 if (is_lhs_immediate && !is_rhs_immediate) {
52 const IR::Inst* rhs_inst = rhs.GetInstRecursive();
53 if (rhs_inst->GetOpcode() == inst.GetOpcode() && rhs_inst->GetArg(1).IsImmediate()) {
54 const u64 combined = imm_fn(lhs.GetImmediateAsU64(), rhs_inst->GetArg(1).GetImmediateAsU64());
55 inst.SetArg(0, rhs_inst->GetArg(0));
56 inst.SetArg(1, Value(is_32_bit, combined));
57 } else {
58 // Normalize
59 inst.SetArg(0, rhs);
60 inst.SetArg(1, lhs);
61 }
62 }
63
64 if (!is_lhs_immediate && is_rhs_immediate) {
65 const IR::Inst* lhs_inst = lhs.GetInstRecursive();
66 if (lhs_inst->GetOpcode() == inst.GetOpcode() && lhs_inst->GetArg(1).IsImmediate()) {
67 const u64 combined = imm_fn(rhs.GetImmediateAsU64(), lhs_inst->GetArg(1).GetImmediateAsU64());
68 inst.SetArg(0, lhs_inst->GetArg(0));
69 inst.SetArg(1, Value(is_32_bit, combined));
70 }
71 }
72
73 return true;
74 }
75
FoldAdd(IR::Inst & inst,bool is_32_bit)76 void FoldAdd(IR::Inst& inst, bool is_32_bit) {
77 const auto lhs = inst.GetArg(0);
78 const auto rhs = inst.GetArg(1);
79 const auto carry = inst.GetArg(2);
80
81 if (lhs.IsImmediate() && !rhs.IsImmediate()) {
82 // Normalize
83 inst.SetArg(0, rhs);
84 inst.SetArg(1, lhs);
85 FoldAdd(inst, is_32_bit);
86 return;
87 }
88
89 if (inst.HasAssociatedPseudoOperation()) {
90 return;
91 }
92
93 if (!lhs.IsImmediate() && rhs.IsImmediate()) {
94 const IR::Inst* lhs_inst = lhs.GetInstRecursive();
95 if (lhs_inst->GetOpcode() == inst.GetOpcode() && lhs_inst->GetArg(1).IsImmediate() && lhs_inst->GetArg(2).IsImmediate()) {
96 const u64 combined = rhs.GetImmediateAsU64() + lhs_inst->GetArg(1).GetImmediateAsU64() + lhs_inst->GetArg(2).GetU1();
97 inst.SetArg(0, lhs_inst->GetArg(0));
98 inst.SetArg(1, Value(is_32_bit, combined));
99 return;
100 }
101 }
102
103 if (inst.AreAllArgsImmediates()) {
104 const u64 result = lhs.GetImmediateAsU64() + rhs.GetImmediateAsU64() + carry.GetU1();
105 ReplaceUsesWith(inst, is_32_bit, result);
106 return;
107 }
108 }
109
110 // Folds AND operations based on the following:
111 //
112 // 1. imm_x & imm_y -> result
113 // 2. x & 0 -> 0
114 // 3. 0 & y -> 0
115 // 4. x & y -> y (where x has all bits set to 1)
116 // 5. x & y -> x (where y has all bits set to 1)
117 //
FoldAND(IR::Inst & inst,bool is_32_bit)118 void FoldAND(IR::Inst& inst, bool is_32_bit) {
119 if (FoldCommutative(inst, is_32_bit, [](u64 a, u64 b) { return a & b; })) {
120 const auto rhs = inst.GetArg(1);
121 if (rhs.IsZero()) {
122 ReplaceUsesWith(inst, is_32_bit, 0);
123 } else if (rhs.HasAllBitsSet()) {
124 inst.ReplaceUsesWith(inst.GetArg(0));
125 }
126 }
127 }
128
129 // Folds byte reversal opcodes based on the following:
130 //
131 // 1. imm -> swap(imm)
132 //
FoldByteReverse(IR::Inst & inst,Op op)133 void FoldByteReverse(IR::Inst& inst, Op op) {
134 const auto operand = inst.GetArg(0);
135
136 if (!operand.IsImmediate()) {
137 return;
138 }
139
140 if (op == Op::ByteReverseWord) {
141 const u32 result = Common::Swap32(static_cast<u32>(operand.GetImmediateAsU64()));
142 inst.ReplaceUsesWith(IR::Value{result});
143 } else if (op == Op::ByteReverseHalf) {
144 const u16 result = Common::Swap16(static_cast<u16>(operand.GetImmediateAsU64()));
145 inst.ReplaceUsesWith(IR::Value{result});
146 } else {
147 const u64 result = Common::Swap64(operand.GetImmediateAsU64());
148 inst.ReplaceUsesWith(IR::Value{result});
149 }
150 }
151
152 // Folds division operations based on the following:
153 //
154 // 1. x / 0 -> 0 (NOTE: This is an ARM-specific behavior defined in the architecture reference manual)
155 // 2. imm_x / imm_y -> result
156 // 3. x / 1 -> x
157 //
FoldDivide(IR::Inst & inst,bool is_32_bit,bool is_signed)158 void FoldDivide(IR::Inst& inst, bool is_32_bit, bool is_signed) {
159 const auto rhs = inst.GetArg(1);
160
161 if (rhs.IsZero()) {
162 ReplaceUsesWith(inst, is_32_bit, 0);
163 return;
164 }
165
166 const auto lhs = inst.GetArg(0);
167 if (lhs.IsImmediate() && rhs.IsImmediate()) {
168 if (is_signed) {
169 const s64 result = lhs.GetImmediateAsS64() / rhs.GetImmediateAsS64();
170 ReplaceUsesWith(inst, is_32_bit, static_cast<u64>(result));
171 } else {
172 const u64 result = lhs.GetImmediateAsU64() / rhs.GetImmediateAsU64();
173 ReplaceUsesWith(inst, is_32_bit, result);
174 }
175 } else if (rhs.IsUnsignedImmediate(1)) {
176 inst.ReplaceUsesWith(IR::Value{lhs});
177 }
178 }
179
180 // Folds EOR operations based on the following:
181 //
182 // 1. imm_x ^ imm_y -> result
183 // 2. x ^ 0 -> x
184 // 3. 0 ^ y -> y
185 //
FoldEOR(IR::Inst & inst,bool is_32_bit)186 void FoldEOR(IR::Inst& inst, bool is_32_bit) {
187 if (FoldCommutative(inst, is_32_bit, [](u64 a, u64 b) { return a ^ b; })) {
188 const auto rhs = inst.GetArg(1);
189 if (rhs.IsZero()) {
190 inst.ReplaceUsesWith(inst.GetArg(0));
191 }
192 }
193 }
194
FoldLeastSignificantByte(IR::Inst & inst)195 void FoldLeastSignificantByte(IR::Inst& inst) {
196 if (!inst.AreAllArgsImmediates()) {
197 return;
198 }
199
200 const auto operand = inst.GetArg(0);
201 inst.ReplaceUsesWith(IR::Value{static_cast<u8>(operand.GetImmediateAsU64())});
202 }
203
FoldLeastSignificantHalf(IR::Inst & inst)204 void FoldLeastSignificantHalf(IR::Inst& inst) {
205 if (!inst.AreAllArgsImmediates()) {
206 return;
207 }
208
209 const auto operand = inst.GetArg(0);
210 inst.ReplaceUsesWith(IR::Value{static_cast<u16>(operand.GetImmediateAsU64())});
211 }
212
FoldLeastSignificantWord(IR::Inst & inst)213 void FoldLeastSignificantWord(IR::Inst& inst) {
214 if (!inst.AreAllArgsImmediates()) {
215 return;
216 }
217
218 const auto operand = inst.GetArg(0);
219 inst.ReplaceUsesWith(IR::Value{static_cast<u32>(operand.GetImmediateAsU64())});
220 }
221
FoldMostSignificantBit(IR::Inst & inst)222 void FoldMostSignificantBit(IR::Inst& inst) {
223 if (!inst.AreAllArgsImmediates()) {
224 return;
225 }
226
227 const auto operand = inst.GetArg(0);
228 inst.ReplaceUsesWith(IR::Value{(operand.GetImmediateAsU64() >> 31) != 0});
229 }
230
FoldMostSignificantWord(IR::Inst & inst)231 void FoldMostSignificantWord(IR::Inst& inst) {
232 IR::Inst* carry_inst = inst.GetAssociatedPseudoOperation(Op::GetCarryFromOp);
233
234 if (!inst.AreAllArgsImmediates()) {
235 return;
236 }
237
238 const auto operand = inst.GetArg(0);
239 if (carry_inst) {
240 carry_inst->ReplaceUsesWith(IR::Value{Common::Bit<31>(operand.GetImmediateAsU64())});
241 }
242 inst.ReplaceUsesWith(IR::Value{static_cast<u32>(operand.GetImmediateAsU64() >> 32)});
243 }
244
245 // Folds multiplication operations based on the following:
246 //
247 // 1. imm_x * imm_y -> result
248 // 2. x * 0 -> 0
249 // 3. 0 * y -> 0
250 // 4. x * 1 -> x
251 // 5. 1 * y -> y
252 //
FoldMultiply(IR::Inst & inst,bool is_32_bit)253 void FoldMultiply(IR::Inst& inst, bool is_32_bit) {
254 if (FoldCommutative(inst, is_32_bit, [](u64 a, u64 b) { return a * b; })) {
255 const auto rhs = inst.GetArg(1);
256 if (rhs.IsZero()) {
257 ReplaceUsesWith(inst, is_32_bit, 0);
258 } else if (rhs.IsUnsignedImmediate(1)) {
259 inst.ReplaceUsesWith(inst.GetArg(0));
260 }
261 }
262 }
263
264 // Folds NOT operations if the contained value is an immediate.
FoldNOT(IR::Inst & inst,bool is_32_bit)265 void FoldNOT(IR::Inst& inst, bool is_32_bit) {
266 const auto operand = inst.GetArg(0);
267
268 if (!operand.IsImmediate()) {
269 return;
270 }
271
272 const u64 result = ~operand.GetImmediateAsU64();
273 ReplaceUsesWith(inst, is_32_bit, result);
274 }
275
276 // Folds OR operations based on the following:
277 //
278 // 1. imm_x | imm_y -> result
279 // 2. x | 0 -> x
280 // 3. 0 | y -> y
281 //
FoldOR(IR::Inst & inst,bool is_32_bit)282 void FoldOR(IR::Inst& inst, bool is_32_bit) {
283 if (FoldCommutative(inst, is_32_bit, [](u64 a, u64 b) { return a | b; })) {
284 const auto rhs = inst.GetArg(1);
285 if (rhs.IsZero()) {
286 inst.ReplaceUsesWith(inst.GetArg(0));
287 }
288 }
289 }
290
FoldShifts(IR::Inst & inst)291 bool FoldShifts(IR::Inst& inst) {
292 IR::Inst* carry_inst = inst.GetAssociatedPseudoOperation(Op::GetCarryFromOp);
293
294 // The 32-bit variants can contain 3 arguments, while the
295 // 64-bit variants only contain 2.
296 if (inst.NumArgs() == 3 && !carry_inst) {
297 inst.SetArg(2, IR::Value(false));
298 }
299
300 const auto shift_amount = inst.GetArg(1);
301 if (shift_amount.IsZero()) {
302 if (carry_inst) {
303 carry_inst->ReplaceUsesWith(inst.GetArg(2));
304 }
305 inst.ReplaceUsesWith(inst.GetArg(0));
306 return false;
307 }
308
309 if (!inst.AreAllArgsImmediates() || carry_inst) {
310 return false;
311 }
312
313 return true;
314 }
315
FoldSignExtendXToWord(IR::Inst & inst)316 void FoldSignExtendXToWord(IR::Inst& inst) {
317 if (!inst.AreAllArgsImmediates()) {
318 return;
319 }
320
321 const s64 value = inst.GetArg(0).GetImmediateAsS64();
322 inst.ReplaceUsesWith(IR::Value{static_cast<u32>(value)});
323 }
324
FoldSignExtendXToLong(IR::Inst & inst)325 void FoldSignExtendXToLong(IR::Inst& inst) {
326 if (!inst.AreAllArgsImmediates()) {
327 return;
328 }
329
330 const s64 value = inst.GetArg(0).GetImmediateAsS64();
331 inst.ReplaceUsesWith(IR::Value{static_cast<u64>(value)});
332 }
333
FoldSub(IR::Inst & inst,bool is_32_bit)334 void FoldSub(IR::Inst& inst, bool is_32_bit) {
335 if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) {
336 return;
337 }
338
339 const auto lhs = inst.GetArg(0);
340 const auto rhs = inst.GetArg(1);
341 const auto carry = inst.GetArg(2);
342
343 const u64 result = lhs.GetImmediateAsU64() + (~rhs.GetImmediateAsU64()) + carry.GetU1();
344 ReplaceUsesWith(inst, is_32_bit, result);
345 }
346
FoldZeroExtendXToWord(IR::Inst & inst)347 void FoldZeroExtendXToWord(IR::Inst& inst) {
348 if (!inst.AreAllArgsImmediates()) {
349 return;
350 }
351
352 const u64 value = inst.GetArg(0).GetImmediateAsU64();
353 inst.ReplaceUsesWith(IR::Value{static_cast<u32>(value)});
354 }
355
FoldZeroExtendXToLong(IR::Inst & inst)356 void FoldZeroExtendXToLong(IR::Inst& inst) {
357 if (!inst.AreAllArgsImmediates()) {
358 return;
359 }
360
361 const u64 value = inst.GetArg(0).GetImmediateAsU64();
362 inst.ReplaceUsesWith(IR::Value{value});
363 }
364 } // Anonymous namespace
365
// Performs constant propagation over a single basic block.
//
// Each instruction whose operands are known constants is replaced by an
// immediate, and simple algebraic identities (x & 0, x * 1, shift by zero,
// division by one, ...) are simplified via the Fold* helpers above.
// Instructions with non-constant operands are left untouched.
void ConstantPropagation(IR::Block& block) {
    for (auto& inst : block) {
        const auto opcode = inst.GetOpcode();

        switch (opcode) {
        // Sub-word extraction / bit queries.
        case Op::LeastSignificantWord:
            FoldLeastSignificantWord(inst);
            break;
        case Op::MostSignificantWord:
            FoldMostSignificantWord(inst);
            break;
        case Op::LeastSignificantHalf:
            FoldLeastSignificantHalf(inst);
            break;
        case Op::LeastSignificantByte:
            FoldLeastSignificantByte(inst);
            break;
        case Op::MostSignificantBit:
            FoldMostSignificantBit(inst);
            break;
        case Op::IsZero32:
            if (inst.AreAllArgsImmediates()) {
                inst.ReplaceUsesWith(IR::Value{inst.GetArg(0).GetU32() == 0});
            }
            break;
        case Op::IsZero64:
            if (inst.AreAllArgsImmediates()) {
                inst.ReplaceUsesWith(IR::Value{inst.GetArg(0).GetU64() == 0});
            }
            break;
        // Plain shifts/rotates: FoldShifts() handles the shift-by-zero
        // identity and carry pseudo-ops; when it returns true all arguments
        // are immediates and the result may be computed outright.
        // (Safe::* presumably guards out-of-range shift amounts — the
        // helpers live in common/safe_ops.h; verify there.)
        case Op::LogicalShiftLeft32:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, true, Safe::LogicalShiftLeft<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::LogicalShiftLeft64:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, false, Safe::LogicalShiftLeft<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::LogicalShiftRight32:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, true, Safe::LogicalShiftRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::LogicalShiftRight64:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, false, Safe::LogicalShiftRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::ArithmeticShiftRight32:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, true, Safe::ArithmeticShiftRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::ArithmeticShiftRight64:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, false, Safe::ArithmeticShiftRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::RotateRight32:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, true, Common::RotateRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::RotateRight64:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, false, Common::RotateRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
            }
            break;
        // Masked shifts: the shift amount is taken modulo the bit width
        // (& 0x1f / & 0x3f), so a fully-immediate instance folds directly
        // without the FoldShifts() preamble.
        case Op::LogicalShiftLeftMasked32:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, true, inst.GetArg(0).GetU32() << (inst.GetArg(1).GetU32() & 0x1f));
            }
            break;
        case Op::LogicalShiftLeftMasked64:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, false, inst.GetArg(0).GetU64() << (inst.GetArg(1).GetU64() & 0x3f));
            }
            break;
        case Op::LogicalShiftRightMasked32:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, true, inst.GetArg(0).GetU32() >> (inst.GetArg(1).GetU32() & 0x1f));
            }
            break;
        case Op::LogicalShiftRightMasked64:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, false, inst.GetArg(0).GetU64() >> (inst.GetArg(1).GetU64() & 0x3f));
            }
            break;
        case Op::ArithmeticShiftRightMasked32:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, true, static_cast<s32>(inst.GetArg(0).GetU32()) >> (inst.GetArg(1).GetU32() & 0x1f));
            }
            break;
        case Op::ArithmeticShiftRightMasked64:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, false, static_cast<s64>(inst.GetArg(0).GetU64()) >> (inst.GetArg(1).GetU64() & 0x3f));
            }
            break;
        case Op::RotateRightMasked32:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, true, Common::RotateRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU32()));
            }
            break;
        case Op::RotateRightMasked64:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, false, Common::RotateRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU64()));
            }
            break;
        // Arithmetic and bitwise operations share a 32-/64-bit folding
        // helper; the bool argument selects the 32-bit variant.
        case Op::Add32:
        case Op::Add64:
            FoldAdd(inst, opcode == Op::Add32);
            break;
        case Op::Sub32:
        case Op::Sub64:
            FoldSub(inst, opcode == Op::Sub32);
            break;
        case Op::Mul32:
        case Op::Mul64:
            FoldMultiply(inst, opcode == Op::Mul32);
            break;
        case Op::SignedDiv32:
        case Op::SignedDiv64:
            FoldDivide(inst, opcode == Op::SignedDiv32, true);
            break;
        case Op::UnsignedDiv32:
        case Op::UnsignedDiv64:
            FoldDivide(inst, opcode == Op::UnsignedDiv32, false);
            break;
        case Op::And32:
        case Op::And64:
            FoldAND(inst, opcode == Op::And32);
            break;
        case Op::Eor32:
        case Op::Eor64:
            FoldEOR(inst, opcode == Op::Eor32);
            break;
        case Op::Or32:
        case Op::Or64:
            FoldOR(inst, opcode == Op::Or32);
            break;
        case Op::Not32:
        case Op::Not64:
            FoldNOT(inst, opcode == Op::Not32);
            break;
        // Extensions and byte reversals.
        case Op::SignExtendByteToWord:
        case Op::SignExtendHalfToWord:
            FoldSignExtendXToWord(inst);
            break;
        case Op::SignExtendByteToLong:
        case Op::SignExtendHalfToLong:
        case Op::SignExtendWordToLong:
            FoldSignExtendXToLong(inst);
            break;
        case Op::ZeroExtendByteToWord:
        case Op::ZeroExtendHalfToWord:
            FoldZeroExtendXToWord(inst);
            break;
        case Op::ZeroExtendByteToLong:
        case Op::ZeroExtendHalfToLong:
        case Op::ZeroExtendWordToLong:
            FoldZeroExtendXToLong(inst);
            break;
        case Op::ByteReverseWord:
        case Op::ByteReverseHalf:
        case Op::ByteReverseDual:
            FoldByteReverse(inst, opcode);
            break;
        default:
            // All other opcodes are left unmodified.
            break;
        }
    }
}
540
541 } // namespace Dynarmic::Optimization
542