/* This file is part of the dynarmic project.
 * Copyright (c) 2016 MerryMage
 * SPDX-License-Identifier: 0BSD
 */

#include <optional>

#include "common/assert.h"
#include "common/bit_util.h"
#include "common/common_types.h"
#include "common/safe_ops.h"
#include "frontend/ir/basic_block.h"
#include "frontend/ir/ir_emitter.h"
#include "frontend/ir/opcodes.h"
#include "ir_opt/passes.h"

namespace Dynarmic::Optimization {

using Op = Dynarmic::IR::Opcode;

namespace {

// Tiny helper to avoid having to select the replacement value's width based on
// the opcode bit size all over the place within the folding functions.
void ReplaceUsesWith(IR::Inst& inst, bool is_32_bit, u64 value) {
    if (is_32_bit) {
        inst.ReplaceUsesWith(IR::Value{static_cast<u32>(value)});
    } else {
        inst.ReplaceUsesWith(IR::Value{value});
    }
}

IR::Value Value(bool is_32_bit, u64 value) {
    return is_32_bit ? IR::Value{static_cast<u32>(value)} : IR::Value{value};
}

template <typename ImmFn>
bool FoldCommutative(IR::Inst& inst, bool is_32_bit, ImmFn imm_fn) {
    const auto lhs = inst.GetArg(0);
    const auto rhs = inst.GetArg(1);

    const bool is_lhs_immediate = lhs.IsImmediate();
    const bool is_rhs_immediate = rhs.IsImmediate();

    if (is_lhs_immediate && is_rhs_immediate) {
        const u64 result = imm_fn(lhs.GetImmediateAsU64(), rhs.GetImmediateAsU64());
        ReplaceUsesWith(inst, is_32_bit, result);
        return false;
    }

    if (is_lhs_immediate && !is_rhs_immediate) {
        const IR::Inst* rhs_inst = rhs.GetInstRecursive();
        if (rhs_inst->GetOpcode() == inst.GetOpcode() && rhs_inst->GetArg(1).IsImmediate()) {
            const u64 combined = imm_fn(lhs.GetImmediateAsU64(), rhs_inst->GetArg(1).GetImmediateAsU64());
            inst.SetArg(0, rhs_inst->GetArg(0));
            inst.SetArg(1, Value(is_32_bit, combined));
        } else {
            // Normalize
            inst.SetArg(0, rhs);
            inst.SetArg(1, lhs);
        }
    }

    if (!is_lhs_immediate && is_rhs_immediate) {
        const IR::Inst* lhs_inst = lhs.GetInstRecursive();
        if (lhs_inst->GetOpcode() == inst.GetOpcode() && lhs_inst->GetArg(1).IsImmediate()) {
            const u64 combined = imm_fn(rhs.GetImmediateAsU64(), lhs_inst->GetArg(1).GetImmediateAsU64());
            inst.SetArg(0, lhs_inst->GetArg(0));
            inst.SetArg(1, Value(is_32_bit, combined));
        }
    }

    return true;
}
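
// Illustrative sketch of what FoldCommutative does (informal notation, for
// exposition only), taking imm_fn = operator& as an example:
//
//   imm_a & imm_b         ==>  folded to a constant; returns false
//   imm & x               ==>  x & imm                 (operands normalized)
//   imm_b & (x & imm_a)   ==>  x & (imm_a & imm_b)     (immediates combined)
//   (x & imm_a) & imm_b   ==>  x & (imm_a & imm_b)
//
// Returning true tells the caller that it may still apply op-specific identity
// folds against the (now right-hand) operand.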

void FoldAdd(IR::Inst& inst, bool is_32_bit) {
    const auto lhs = inst.GetArg(0);
    const auto rhs = inst.GetArg(1);
    const auto carry = inst.GetArg(2);

    if (lhs.IsImmediate() && !rhs.IsImmediate()) {
        // Normalize
        inst.SetArg(0, rhs);
        inst.SetArg(1, lhs);
        FoldAdd(inst, is_32_bit);
        return;
    }

    if (inst.HasAssociatedPseudoOperation()) {
        return;
    }

    if (!lhs.IsImmediate() && rhs.IsImmediate()) {
        const IR::Inst* lhs_inst = lhs.GetInstRecursive();
        if (lhs_inst->GetOpcode() == inst.GetOpcode() && lhs_inst->GetArg(1).IsImmediate() && lhs_inst->GetArg(2).IsImmediate()) {
            const u64 combined = rhs.GetImmediateAsU64() + lhs_inst->GetArg(1).GetImmediateAsU64() + lhs_inst->GetArg(2).GetU1();
            inst.SetArg(0, lhs_inst->GetArg(0));
            inst.SetArg(1, Value(is_32_bit, combined));
            return;
        }
    }

    if (inst.AreAllArgsImmediates()) {
        const u64 result = lhs.GetImmediateAsU64() + rhs.GetImmediateAsU64() + carry.GetU1();
        ReplaceUsesWith(inst, is_32_bit, result);
        return;
    }
}
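
// Illustrative example of the add folding above (informal SSA-style notation,
// for exposition only):
//
//   %a = Add32 %x, #0x10, #0
//   %b = Add32 %a, #0x20, #0   ==>   %b = Add32 %x, #0x30, #0
//
// Note that FoldAdd bails out early when the instruction has an associated
// pseudo-operation (e.g. GetCarryFromOp), so flag-producing adds are left
// untouched.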

// Folds AND operations based on the following:
//
// 1. imm_x & imm_y -> result
// 2. x & 0 -> 0
// 3. 0 & y -> 0
// 4. x & y -> y (where x has all bits set to 1)
// 5. x & y -> x (where y has all bits set to 1)
//
void FoldAND(IR::Inst& inst, bool is_32_bit) {
    if (FoldCommutative(inst, is_32_bit, [](u64 a, u64 b) { return a & b; })) {
        const auto rhs = inst.GetArg(1);
        if (rhs.IsZero()) {
            ReplaceUsesWith(inst, is_32_bit, 0);
        } else if (rhs.HasAllBitsSet()) {
            inst.ReplaceUsesWith(inst.GetArg(0));
        }
    }
}
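
// For example (informal notation, for exposition only):
//
//   %r = And32 %x, #0x00000000   ==>   uses of %r replaced with #0
//   %r = And32 %x, #0xffffffff   ==>   uses of %r replaced with %x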

// Folds byte reversal opcodes based on the following:
//
// 1. imm -> swap(imm)
//
void FoldByteReverse(IR::Inst& inst, Op op) {
    const auto operand = inst.GetArg(0);

    if (!operand.IsImmediate()) {
        return;
    }

    if (op == Op::ByteReverseWord) {
        const u32 result = Common::Swap32(static_cast<u32>(operand.GetImmediateAsU64()));
        inst.ReplaceUsesWith(IR::Value{result});
    } else if (op == Op::ByteReverseHalf) {
        const u16 result = Common::Swap16(static_cast<u16>(operand.GetImmediateAsU64()));
        inst.ReplaceUsesWith(IR::Value{result});
    } else {
        const u64 result = Common::Swap64(operand.GetImmediateAsU64());
        inst.ReplaceUsesWith(IR::Value{result});
    }
}
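
// For example, ByteReverseWord of the constant #0xaabbccdd folds to #0xddccbbaa,
// and ByteReverseHalf of #0x1234 folds to #0x3412.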

// Folds division operations based on the following:
//
// 1. x / 0 -> 0 (NOTE: This is an ARM-specific behavior defined in the architecture reference manual)
// 2. imm_x / imm_y -> result
// 3. x / 1 -> x
//
void FoldDivide(IR::Inst& inst, bool is_32_bit, bool is_signed) {
    const auto rhs = inst.GetArg(1);

    if (rhs.IsZero()) {
        ReplaceUsesWith(inst, is_32_bit, 0);
        return;
    }

    const auto lhs = inst.GetArg(0);
    if (lhs.IsImmediate() && rhs.IsImmediate()) {
        if (is_signed) {
            const s64 result = lhs.GetImmediateAsS64() / rhs.GetImmediateAsS64();
            ReplaceUsesWith(inst, is_32_bit, static_cast<u64>(result));
        } else {
            const u64 result = lhs.GetImmediateAsU64() / rhs.GetImmediateAsU64();
            ReplaceUsesWith(inst, is_32_bit, result);
        }
    } else if (rhs.IsUnsignedImmediate(1)) {
        inst.ReplaceUsesWith(IR::Value{lhs});
    }
}
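
// For example (informal notation, for exposition only):
//
//   %r = UnsignedDiv32 %x, #0   ==>   uses of %r replaced with #0 (ARM defines division by zero to yield zero)
//   %r = SignedDiv32 %x, #1     ==>   uses of %r replaced with %x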

// Folds EOR operations based on the following:
//
// 1. imm_x ^ imm_y -> result
// 2. x ^ 0 -> x
// 3. 0 ^ y -> y
//
void FoldEOR(IR::Inst& inst, bool is_32_bit) {
    if (FoldCommutative(inst, is_32_bit, [](u64 a, u64 b) { return a ^ b; })) {
        const auto rhs = inst.GetArg(1);
        if (rhs.IsZero()) {
            inst.ReplaceUsesWith(inst.GetArg(0));
        }
    }
}
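
// For example (informal notation): %r = Eor32 %x, #0 ==> uses of %r replaced with %x.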

void FoldLeastSignificantByte(IR::Inst& inst) {
    if (!inst.AreAllArgsImmediates()) {
        return;
    }

    const auto operand = inst.GetArg(0);
    inst.ReplaceUsesWith(IR::Value{static_cast<u8>(operand.GetImmediateAsU64())});
}

void FoldLeastSignificantHalf(IR::Inst& inst) {
    if (!inst.AreAllArgsImmediates()) {
        return;
    }

    const auto operand = inst.GetArg(0);
    inst.ReplaceUsesWith(IR::Value{static_cast<u16>(operand.GetImmediateAsU64())});
}

void FoldLeastSignificantWord(IR::Inst& inst) {
    if (!inst.AreAllArgsImmediates()) {
        return;
    }

    const auto operand = inst.GetArg(0);
    inst.ReplaceUsesWith(IR::Value{static_cast<u32>(operand.GetImmediateAsU64())});
}
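
// The three folds above simply truncate a constant operand to the requested
// width, e.g. LeastSignificantHalf of #0x12345678 folds to #0x5678.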

void FoldMostSignificantBit(IR::Inst& inst) {
    if (!inst.AreAllArgsImmediates()) {
        return;
    }

    const auto operand = inst.GetArg(0);
    inst.ReplaceUsesWith(IR::Value{(operand.GetImmediateAsU64() >> 31) != 0});
}

void FoldMostSignificantWord(IR::Inst& inst) {
    IR::Inst* carry_inst = inst.GetAssociatedPseudoOperation(Op::GetCarryFromOp);

    if (!inst.AreAllArgsImmediates()) {
        return;
    }

    const auto operand = inst.GetArg(0);
    if (carry_inst) {
        carry_inst->ReplaceUsesWith(IR::Value{Common::Bit<31>(operand.GetImmediateAsU64())});
    }
    inst.ReplaceUsesWith(IR::Value{static_cast<u32>(operand.GetImmediateAsU64() >> 32)});
}
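
// Note: unlike the other extraction folds, MostSignificantWord may have an
// associated GetCarryFromOp pseudo-operation; when present, it is folded to
// bit 31 of the 64-bit constant operand before the instruction itself is
// replaced.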

// Folds multiplication operations based on the following:
//
// 1. imm_x * imm_y -> result
// 2. x * 0 -> 0
// 3. 0 * y -> 0
// 4. x * 1 -> x
// 5. 1 * y -> y
//
void FoldMultiply(IR::Inst& inst, bool is_32_bit) {
    if (FoldCommutative(inst, is_32_bit, [](u64 a, u64 b) { return a * b; })) {
        const auto rhs = inst.GetArg(1);
        if (rhs.IsZero()) {
            ReplaceUsesWith(inst, is_32_bit, 0);
        } else if (rhs.IsUnsignedImmediate(1)) {
            inst.ReplaceUsesWith(inst.GetArg(0));
        }
    }
}
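
// For example (informal notation, for exposition only):
//
//   %r = Mul32 %x, #0   ==>   uses of %r replaced with #0
//   %r = Mul64 %x, #1   ==>   uses of %r replaced with %x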

// Folds NOT operations if the contained value is an immediate.
void FoldNOT(IR::Inst& inst, bool is_32_bit) {
    const auto operand = inst.GetArg(0);

    if (!operand.IsImmediate()) {
        return;
    }

    const u64 result = ~operand.GetImmediateAsU64();
    ReplaceUsesWith(inst, is_32_bit, result);
}

// Folds OR operations based on the following:
//
// 1. imm_x | imm_y -> result
// 2. x | 0 -> x
// 3. 0 | y -> y
//
void FoldOR(IR::Inst& inst, bool is_32_bit) {
    if (FoldCommutative(inst, is_32_bit, [](u64 a, u64 b) { return a | b; })) {
        const auto rhs = inst.GetArg(1);
        if (rhs.IsZero()) {
            inst.ReplaceUsesWith(inst.GetArg(0));
        }
    }
}
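
// For example (informal notation): %r = Or64 %x, #0 ==> uses of %r replaced with %x.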

bool FoldShifts(IR::Inst& inst) {
    IR::Inst* carry_inst = inst.GetAssociatedPseudoOperation(Op::GetCarryFromOp);

    // The 32-bit variants can contain 3 arguments, while the
    // 64-bit variants only contain 2.
    if (inst.NumArgs() == 3 && !carry_inst) {
        inst.SetArg(2, IR::Value(false));
    }

    const auto shift_amount = inst.GetArg(1);
    if (shift_amount.IsZero()) {
        if (carry_inst) {
            carry_inst->ReplaceUsesWith(inst.GetArg(2));
        }
        inst.ReplaceUsesWith(inst.GetArg(0));
        return false;
    }

    if (!inst.AreAllArgsImmediates() || carry_inst) {
        return false;
    }

    return true;
}
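
// FoldShifts only performs the shift-by-zero identity fold itself (forwarding the
// operand, and forwarding the carry-in to any GetCarryFromOp user). It returns
// true when the caller may evaluate the shift as a constant: all arguments are
// immediate and there is no carry pseudo-operation that would need updating.
// For example (informal notation): %r = LogicalShiftLeft32 #0x11, #4, #0 folds
// to #0x110.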

void FoldSignExtendXToWord(IR::Inst& inst) {
    if (!inst.AreAllArgsImmediates()) {
        return;
    }

    const s64 value = inst.GetArg(0).GetImmediateAsS64();
    inst.ReplaceUsesWith(IR::Value{static_cast<u32>(value)});
}

void FoldSignExtendXToLong(IR::Inst& inst) {
    if (!inst.AreAllArgsImmediates()) {
        return;
    }

    const s64 value = inst.GetArg(0).GetImmediateAsS64();
    inst.ReplaceUsesWith(IR::Value{static_cast<u64>(value)});
}

void FoldSub(IR::Inst& inst, bool is_32_bit) {
    if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) {
        return;
    }

    const auto lhs = inst.GetArg(0);
    const auto rhs = inst.GetArg(1);
    const auto carry = inst.GetArg(2);

    const u64 result = lhs.GetImmediateAsU64() + (~rhs.GetImmediateAsU64()) + carry.GetU1();
    ReplaceUsesWith(inst, is_32_bit, result);
}
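
// The subtraction above is evaluated as lhs + ~rhs + carry, the usual ARM
// subtract-with-carry formulation: carry = 1 means "no borrow", so a plain Sub
// with carry #1 computes lhs - rhs.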

void FoldZeroExtendXToWord(IR::Inst& inst) {
    if (!inst.AreAllArgsImmediates()) {
        return;
    }

    const u64 value = inst.GetArg(0).GetImmediateAsU64();
    inst.ReplaceUsesWith(IR::Value{static_cast<u32>(value)});
}

void FoldZeroExtendXToLong(IR::Inst& inst) {
    if (!inst.AreAllArgsImmediates()) {
        return;
    }

    const u64 value = inst.GetArg(0).GetImmediateAsU64();
    inst.ReplaceUsesWith(IR::Value{value});
}
} // Anonymous namespace

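// Performs constant folding and simple algebraic simplification over a single
// basic block: instructions whose results are statically known are replaced
// with immediates, and immediate operands are normalized and combined where
// possible.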
void ConstantPropagation(IR::Block& block) {
    for (auto& inst : block) {
        const auto opcode = inst.GetOpcode();

        switch (opcode) {
        case Op::LeastSignificantWord:
            FoldLeastSignificantWord(inst);
            break;
        case Op::MostSignificantWord:
            FoldMostSignificantWord(inst);
            break;
        case Op::LeastSignificantHalf:
            FoldLeastSignificantHalf(inst);
            break;
        case Op::LeastSignificantByte:
            FoldLeastSignificantByte(inst);
            break;
        case Op::MostSignificantBit:
            FoldMostSignificantBit(inst);
            break;
        case Op::IsZero32:
            if (inst.AreAllArgsImmediates()) {
                inst.ReplaceUsesWith(IR::Value{inst.GetArg(0).GetU32() == 0});
            }
            break;
        case Op::IsZero64:
            if (inst.AreAllArgsImmediates()) {
                inst.ReplaceUsesWith(IR::Value{inst.GetArg(0).GetU64() == 0});
            }
            break;
        case Op::LogicalShiftLeft32:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, true, Safe::LogicalShiftLeft<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::LogicalShiftLeft64:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, false, Safe::LogicalShiftLeft<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::LogicalShiftRight32:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, true, Safe::LogicalShiftRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::LogicalShiftRight64:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, false, Safe::LogicalShiftRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::ArithmeticShiftRight32:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, true, Safe::ArithmeticShiftRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::ArithmeticShiftRight64:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, false, Safe::ArithmeticShiftRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::RotateRight32:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, true, Common::RotateRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
            }
            break;
        case Op::RotateRight64:
            if (FoldShifts(inst)) {
                ReplaceUsesWith(inst, false, Common::RotateRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
            }
            break;
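        // The *Masked* shift cases below mask the shift amount to the operand width
        // (& 0x1f for 32-bit, & 0x3f for 64-bit), so the fold is a plain, well-defined
        // C++ shift and the Safe::* helpers used above are not needed.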
        case Op::LogicalShiftLeftMasked32:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, true, inst.GetArg(0).GetU32() << (inst.GetArg(1).GetU32() & 0x1f));
            }
            break;
        case Op::LogicalShiftLeftMasked64:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, false, inst.GetArg(0).GetU64() << (inst.GetArg(1).GetU64() & 0x3f));
            }
            break;
        case Op::LogicalShiftRightMasked32:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, true, inst.GetArg(0).GetU32() >> (inst.GetArg(1).GetU32() & 0x1f));
            }
            break;
        case Op::LogicalShiftRightMasked64:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, false, inst.GetArg(0).GetU64() >> (inst.GetArg(1).GetU64() & 0x3f));
            }
            break;
        case Op::ArithmeticShiftRightMasked32:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, true, static_cast<s32>(inst.GetArg(0).GetU32()) >> (inst.GetArg(1).GetU32() & 0x1f));
            }
            break;
        case Op::ArithmeticShiftRightMasked64:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, false, static_cast<s64>(inst.GetArg(0).GetU64()) >> (inst.GetArg(1).GetU64() & 0x3f));
            }
            break;
        case Op::RotateRightMasked32:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, true, Common::RotateRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU32()));
            }
            break;
        case Op::RotateRightMasked64:
            if (inst.AreAllArgsImmediates()) {
                ReplaceUsesWith(inst, false, Common::RotateRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU64()));
            }
            break;
        case Op::Add32:
        case Op::Add64:
            FoldAdd(inst, opcode == Op::Add32);
            break;
        case Op::Sub32:
        case Op::Sub64:
            FoldSub(inst, opcode == Op::Sub32);
            break;
        case Op::Mul32:
        case Op::Mul64:
            FoldMultiply(inst, opcode == Op::Mul32);
            break;
        case Op::SignedDiv32:
        case Op::SignedDiv64:
            FoldDivide(inst, opcode == Op::SignedDiv32, true);
            break;
        case Op::UnsignedDiv32:
        case Op::UnsignedDiv64:
            FoldDivide(inst, opcode == Op::UnsignedDiv32, false);
            break;
        case Op::And32:
        case Op::And64:
            FoldAND(inst, opcode == Op::And32);
            break;
        case Op::Eor32:
        case Op::Eor64:
            FoldEOR(inst, opcode == Op::Eor32);
            break;
        case Op::Or32:
        case Op::Or64:
            FoldOR(inst, opcode == Op::Or32);
            break;
        case Op::Not32:
        case Op::Not64:
            FoldNOT(inst, opcode == Op::Not32);
            break;
        case Op::SignExtendByteToWord:
        case Op::SignExtendHalfToWord:
            FoldSignExtendXToWord(inst);
            break;
        case Op::SignExtendByteToLong:
        case Op::SignExtendHalfToLong:
        case Op::SignExtendWordToLong:
            FoldSignExtendXToLong(inst);
            break;
        case Op::ZeroExtendByteToWord:
        case Op::ZeroExtendHalfToWord:
            FoldZeroExtendXToWord(inst);
            break;
        case Op::ZeroExtendByteToLong:
        case Op::ZeroExtendHalfToLong:
        case Op::ZeroExtendWordToLong:
            FoldZeroExtendXToLong(inst);
            break;
        case Op::ByteReverseWord:
        case Op::ByteReverseHalf:
        case Op::ByteReverseDual:
            FoldByteReverse(inst, opcode);
            break;
        default:
            break;
        }
    }
}

} // namespace Dynarmic::Optimization