1 //===- AArch64MacroFusion.cpp - AArch64 Macro Fusion ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file This file contains the AArch64 implementation of the DAG scheduling
10 ///  mutation to pair instructions back to back.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "AArch64MacroFusion.h"
15 #include "AArch64Subtarget.h"
16 #include "llvm/CodeGen/MacroFusion.h"
17 #include "llvm/CodeGen/TargetInstrInfo.h"
18 
19 using namespace llvm;
20 
21 /// CMN, CMP, TST followed by Bcc
22 static bool isArithmeticBccPair(const MachineInstr *FirstMI,
23                                 const MachineInstr &SecondMI, bool CmpOnly) {
24   if (SecondMI.getOpcode() != AArch64::Bcc)
25     return false;
26 
27   // Assume the 1st instr to be a wildcard if it is unspecified.
28   if (FirstMI == nullptr)
29     return true;
30 
31   // If we're in CmpOnly mode, we only fuse arithmetic instructions that
32   // discard their result.
33   if (CmpOnly && FirstMI->getOperand(0).isReg() &&
34       !(FirstMI->getOperand(0).getReg() == AArch64::XZR ||
35         FirstMI->getOperand(0).getReg() == AArch64::WZR)) {
36     return false;
37   }
38 
39   switch (FirstMI->getOpcode()) {
40   case AArch64::ADDSWri:
41   case AArch64::ADDSWrr:
42   case AArch64::ADDSXri:
43   case AArch64::ADDSXrr:
44   case AArch64::ANDSWri:
45   case AArch64::ANDSWrr:
46   case AArch64::ANDSXri:
47   case AArch64::ANDSXrr:
48   case AArch64::SUBSWri:
49   case AArch64::SUBSWrr:
50   case AArch64::SUBSXri:
51   case AArch64::SUBSXrr:
52   case AArch64::BICSWrr:
53   case AArch64::BICSXrr:
54     return true;
55   case AArch64::ADDSWrs:
56   case AArch64::ADDSXrs:
57   case AArch64::ANDSWrs:
58   case AArch64::ANDSXrs:
59   case AArch64::SUBSWrs:
60   case AArch64::SUBSXrs:
61   case AArch64::BICSWrs:
62   case AArch64::BICSXrs:
63     // Shift value can be 0 making these behave like the "rr" variant...
64     return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
65   }
66 
67   return false;
68 }
69 
70 /// ALU operations followed by CBZ/CBNZ.
71 static bool isArithmeticCbzPair(const MachineInstr *FirstMI,
72                                 const MachineInstr &SecondMI) {
73   if (SecondMI.getOpcode() != AArch64::CBZW &&
74       SecondMI.getOpcode() != AArch64::CBZX &&
75       SecondMI.getOpcode() != AArch64::CBNZW &&
76       SecondMI.getOpcode() != AArch64::CBNZX)
77     return false;
78 
79   // Assume the 1st instr to be a wildcard if it is unspecified.
80   if (FirstMI == nullptr)
81     return true;
82 
83   switch (FirstMI->getOpcode()) {
84   case AArch64::ADDWri:
85   case AArch64::ADDWrr:
86   case AArch64::ADDXri:
87   case AArch64::ADDXrr:
88   case AArch64::ANDWri:
89   case AArch64::ANDWrr:
90   case AArch64::ANDXri:
91   case AArch64::ANDXrr:
92   case AArch64::EORWri:
93   case AArch64::EORWrr:
94   case AArch64::EORXri:
95   case AArch64::EORXrr:
96   case AArch64::ORRWri:
97   case AArch64::ORRWrr:
98   case AArch64::ORRXri:
99   case AArch64::ORRXrr:
100   case AArch64::SUBWri:
101   case AArch64::SUBWrr:
102   case AArch64::SUBXri:
103   case AArch64::SUBXrr:
104     return true;
105   case AArch64::ADDWrs:
106   case AArch64::ADDXrs:
107   case AArch64::ANDWrs:
108   case AArch64::ANDXrs:
109   case AArch64::SUBWrs:
110   case AArch64::SUBXrs:
111   case AArch64::BICWrs:
112   case AArch64::BICXrs:
113     // Shift value can be 0 making these behave like the "rr" variant...
114     return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
115   }
116 
117   return false;
118 }
119 
120 /// AES crypto encoding or decoding.
121 static bool isAESPair(const MachineInstr *FirstMI,
122                       const MachineInstr &SecondMI) {
123   // Assume the 1st instr to be a wildcard if it is unspecified.
124   switch (SecondMI.getOpcode()) {
125   // AES encode.
126   case AArch64::AESMCrr:
127   case AArch64::AESMCrrTied:
128     return FirstMI == nullptr || FirstMI->getOpcode() == AArch64::AESErr;
129   // AES decode.
130   case AArch64::AESIMCrr:
131   case AArch64::AESIMCrrTied:
132     return FirstMI == nullptr || FirstMI->getOpcode() == AArch64::AESDrr;
133   }
134 
135   return false;
136 }
137 
138 /// AESE/AESD/PMULL + EOR.
139 static bool isCryptoEORPair(const MachineInstr *FirstMI,
140                             const MachineInstr &SecondMI) {
141   if (SecondMI.getOpcode() != AArch64::EORv16i8)
142     return false;
143 
144   // Assume the 1st instr to be a wildcard if it is unspecified.
145   if (FirstMI == nullptr)
146     return true;
147 
148   switch (FirstMI->getOpcode()) {
149   case AArch64::AESErr:
150   case AArch64::AESDrr:
151   case AArch64::PMULLv16i8:
152   case AArch64::PMULLv8i8:
153   case AArch64::PMULLv1i64:
154   case AArch64::PMULLv2i64:
155     return true;
156   }
157 
158   return false;
159 }
160 
161 static bool isAdrpAddPair(const MachineInstr *FirstMI,
162                           const MachineInstr &SecondMI) {
163   // Assume the 1st instr to be a wildcard if it is unspecified.
164   if ((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::ADRP) &&
165       SecondMI.getOpcode() == AArch64::ADDXri)
166     return true;
167   return false;
168 }
169 
170 /// Literal generation.
171 static bool isLiteralsPair(const MachineInstr *FirstMI,
172                            const MachineInstr &SecondMI) {
173   // Assume the 1st instr to be a wildcard if it is unspecified.
174   // 32 bit immediate.
175   if ((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::MOVZWi) &&
176       (SecondMI.getOpcode() == AArch64::MOVKWi &&
177        SecondMI.getOperand(3).getImm() == 16))
178     return true;
179 
180   // Lower half of 64 bit immediate.
181   if((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::MOVZXi) &&
182      (SecondMI.getOpcode() == AArch64::MOVKXi &&
183       SecondMI.getOperand(3).getImm() == 16))
184     return true;
185 
186   // Upper half of 64 bit immediate.
187   if ((FirstMI == nullptr ||
188        (FirstMI->getOpcode() == AArch64::MOVKXi &&
189         FirstMI->getOperand(3).getImm() == 32)) &&
190       (SecondMI.getOpcode() == AArch64::MOVKXi &&
191        SecondMI.getOperand(3).getImm() == 48))
192     return true;
193 
194   return false;
195 }
196 
197 /// Fuse address generation and loads or stores.
198 static bool isAddressLdStPair(const MachineInstr *FirstMI,
199                               const MachineInstr &SecondMI) {
200   switch (SecondMI.getOpcode()) {
201   case AArch64::STRBBui:
202   case AArch64::STRBui:
203   case AArch64::STRDui:
204   case AArch64::STRHHui:
205   case AArch64::STRHui:
206   case AArch64::STRQui:
207   case AArch64::STRSui:
208   case AArch64::STRWui:
209   case AArch64::STRXui:
210   case AArch64::LDRBBui:
211   case AArch64::LDRBui:
212   case AArch64::LDRDui:
213   case AArch64::LDRHHui:
214   case AArch64::LDRHui:
215   case AArch64::LDRQui:
216   case AArch64::LDRSui:
217   case AArch64::LDRWui:
218   case AArch64::LDRXui:
219   case AArch64::LDRSBWui:
220   case AArch64::LDRSBXui:
221   case AArch64::LDRSHWui:
222   case AArch64::LDRSHXui:
223   case AArch64::LDRSWui:
224     // Assume the 1st instr to be a wildcard if it is unspecified.
225     if (FirstMI == nullptr)
226       return true;
227 
228    switch (FirstMI->getOpcode()) {
229     case AArch64::ADR:
230       return SecondMI.getOperand(2).getImm() == 0;
231     case AArch64::ADRP:
232       return true;
233     }
234   }
235 
236   return false;
237 }
238 
239 /// Compare and conditional select.
240 static bool isCCSelectPair(const MachineInstr *FirstMI,
241                            const MachineInstr &SecondMI) {
242   // 32 bits
243   if (SecondMI.getOpcode() == AArch64::CSELWr) {
244     // Assume the 1st instr to be a wildcard if it is unspecified.
245     if (FirstMI == nullptr)
246       return true;
247 
248     if (FirstMI->definesRegister(AArch64::WZR))
249       switch (FirstMI->getOpcode()) {
250       case AArch64::SUBSWrs:
251         return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
252       case AArch64::SUBSWrx:
253         return !AArch64InstrInfo::hasExtendedReg(*FirstMI);
254       case AArch64::SUBSWrr:
255       case AArch64::SUBSWri:
256         return true;
257       }
258   }
259 
260   // 64 bits
261   if (SecondMI.getOpcode() == AArch64::CSELXr) {
262     // Assume the 1st instr to be a wildcard if it is unspecified.
263     if (FirstMI == nullptr)
264       return true;
265 
266     if (FirstMI->definesRegister(AArch64::XZR))
267       switch (FirstMI->getOpcode()) {
268       case AArch64::SUBSXrs:
269         return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
270       case AArch64::SUBSXrx:
271       case AArch64::SUBSXrx64:
272         return !AArch64InstrInfo::hasExtendedReg(*FirstMI);
273       case AArch64::SUBSXrr:
274       case AArch64::SUBSXri:
275         return true;
276       }
277   }
278 
279   return false;
280 }
281 
282 // Arithmetic and logic.
283 static bool isArithmeticLogicPair(const MachineInstr *FirstMI,
284                                   const MachineInstr &SecondMI) {
285   if (AArch64InstrInfo::hasShiftedReg(SecondMI))
286     return false;
287 
288   switch (SecondMI.getOpcode()) {
289   // Arithmetic
290   case AArch64::ADDWrr:
291   case AArch64::ADDXrr:
292   case AArch64::SUBWrr:
293   case AArch64::SUBXrr:
294   case AArch64::ADDWrs:
295   case AArch64::ADDXrs:
296   case AArch64::SUBWrs:
297   case AArch64::SUBXrs:
298   // Logic
299   case AArch64::ANDWrr:
300   case AArch64::ANDXrr:
301   case AArch64::BICWrr:
302   case AArch64::BICXrr:
303   case AArch64::EONWrr:
304   case AArch64::EONXrr:
305   case AArch64::EORWrr:
306   case AArch64::EORXrr:
307   case AArch64::ORNWrr:
308   case AArch64::ORNXrr:
309   case AArch64::ORRWrr:
310   case AArch64::ORRXrr:
311   case AArch64::ANDWrs:
312   case AArch64::ANDXrs:
313   case AArch64::BICWrs:
314   case AArch64::BICXrs:
315   case AArch64::EONWrs:
316   case AArch64::EONXrs:
317   case AArch64::EORWrs:
318   case AArch64::EORXrs:
319   case AArch64::ORNWrs:
320   case AArch64::ORNXrs:
321   case AArch64::ORRWrs:
322   case AArch64::ORRXrs:
323     // Assume the 1st instr to be a wildcard if it is unspecified.
324     if (FirstMI == nullptr)
325       return true;
326 
327     // Arithmetic
328     switch (FirstMI->getOpcode()) {
329     case AArch64::ADDWrr:
330     case AArch64::ADDXrr:
331     case AArch64::ADDSWrr:
332     case AArch64::ADDSXrr:
333     case AArch64::SUBWrr:
334     case AArch64::SUBXrr:
335     case AArch64::SUBSWrr:
336     case AArch64::SUBSXrr:
337       return true;
338     case AArch64::ADDWrs:
339     case AArch64::ADDXrs:
340     case AArch64::ADDSWrs:
341     case AArch64::ADDSXrs:
342     case AArch64::SUBWrs:
343     case AArch64::SUBXrs:
344     case AArch64::SUBSWrs:
345     case AArch64::SUBSXrs:
346       return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
347     }
348     break;
349 
350   // Arithmetic, setting flags.
351   case AArch64::ADDSWrr:
352   case AArch64::ADDSXrr:
353   case AArch64::SUBSWrr:
354   case AArch64::SUBSXrr:
355   case AArch64::ADDSWrs:
356   case AArch64::ADDSXrs:
357   case AArch64::SUBSWrs:
358   case AArch64::SUBSXrs:
359     // Assume the 1st instr to be a wildcard if it is unspecified.
360     if (FirstMI == nullptr)
361       return true;
362 
363     // Arithmetic, not setting flags.
364     switch (FirstMI->getOpcode()) {
365     case AArch64::ADDWrr:
366     case AArch64::ADDXrr:
367     case AArch64::SUBWrr:
368     case AArch64::SUBXrr:
369       return true;
370     case AArch64::ADDWrs:
371     case AArch64::ADDXrs:
372     case AArch64::SUBWrs:
373     case AArch64::SUBXrs:
374       return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
375     }
376     break;
377   }
378 
379   return false;
380 }
381 
382 // "(A + B) + 1" or "(A - B) - 1"
383 static bool isAddSub2RegAndConstOnePair(const MachineInstr *FirstMI,
384                                         const MachineInstr &SecondMI) {
385   bool NeedsSubtract = false;
386 
387   // The 2nd instr must be an add-immediate or subtract-immediate.
388   switch (SecondMI.getOpcode()) {
389   case AArch64::SUBWri:
390   case AArch64::SUBXri:
391     NeedsSubtract = true;
392     [[fallthrough]];
393   case AArch64::ADDWri:
394   case AArch64::ADDXri:
395     break;
396 
397   default:
398     return false;
399   }
400 
401   // The immediate in the 2nd instr must be "1".
402   if (!SecondMI.getOperand(2).isImm() || SecondMI.getOperand(2).getImm() != 1) {
403     return false;
404   }
405 
406   // Assume the 1st instr to be a wildcard if it is unspecified.
407   if (FirstMI == nullptr) {
408     return true;
409   }
410 
411   switch (FirstMI->getOpcode()) {
412   case AArch64::SUBWrs:
413   case AArch64::SUBXrs:
414     if (AArch64InstrInfo::hasShiftedReg(*FirstMI))
415       return false;
416     [[fallthrough]];
417   case AArch64::SUBWrr:
418   case AArch64::SUBXrr:
419     if (NeedsSubtract) {
420       return true;
421     }
422     break;
423 
424   case AArch64::ADDWrs:
425   case AArch64::ADDXrs:
426     if (AArch64InstrInfo::hasShiftedReg(*FirstMI))
427       return false;
428     [[fallthrough]];
429   case AArch64::ADDWrr:
430   case AArch64::ADDXrr:
431     if (!NeedsSubtract) {
432       return true;
433     }
434     break;
435   }
436 
437   return false;
438 }
439 
440 /// \brief Check if the instr pair, FirstMI and SecondMI, should be fused
441 /// together. Given SecondMI, when FirstMI is unspecified, then check if
442 /// SecondMI may be part of a fused pair at all.
443 static bool shouldScheduleAdjacent(const TargetInstrInfo &TII,
444                                    const TargetSubtargetInfo &TSI,
445                                    const MachineInstr *FirstMI,
446                                    const MachineInstr &SecondMI) {
447   const AArch64Subtarget &ST = static_cast<const AArch64Subtarget&>(TSI);
448 
449   // All checking functions assume that the 1st instr is a wildcard if it is
450   // unspecified.
451   if (ST.hasCmpBccFusion() || ST.hasArithmeticBccFusion()) {
452     bool CmpOnly = !ST.hasArithmeticBccFusion();
453     if (isArithmeticBccPair(FirstMI, SecondMI, CmpOnly))
454       return true;
455   }
456   if (ST.hasArithmeticCbzFusion() && isArithmeticCbzPair(FirstMI, SecondMI))
457     return true;
458   if (ST.hasFuseAES() && isAESPair(FirstMI, SecondMI))
459     return true;
460   if (ST.hasFuseCryptoEOR() && isCryptoEORPair(FirstMI, SecondMI))
461     return true;
462   if (ST.hasFuseAdrpAdd() && isAdrpAddPair(FirstMI, SecondMI))
463     return true;
464   if (ST.hasFuseLiterals() && isLiteralsPair(FirstMI, SecondMI))
465     return true;
466   if (ST.hasFuseAddress() && isAddressLdStPair(FirstMI, SecondMI))
467     return true;
468   if (ST.hasFuseCCSelect() && isCCSelectPair(FirstMI, SecondMI))
469     return true;
470   if (ST.hasFuseArithmeticLogic() && isArithmeticLogicPair(FirstMI, SecondMI))
471     return true;
472   if (ST.hasFuseAddSub2RegAndConstOne() &&
473       isAddSub2RegAndConstOnePair(FirstMI, SecondMI))
474     return true;
475 
476   return false;
477 }
478 
479 std::unique_ptr<ScheduleDAGMutation>
480 llvm::createAArch64MacroFusionDAGMutation() {
481   return createMacroFusionDAGMutation(shouldScheduleAdjacent);
482 }
483