//=== lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass does combining of machine instructions at the generic MI level,
10 // before the legalizer.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "AArch64GlobalISelUtils.h"
15 #include "AArch64TargetMachine.h"
16 #include "llvm/CodeGen/GlobalISel/CSEInfo.h"
17 #include "llvm/CodeGen/GlobalISel/Combiner.h"
18 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
19 #include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
20 #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
21 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
22 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
23 #include "llvm/CodeGen/MachineDominators.h"
24 #include "llvm/CodeGen/MachineFunction.h"
25 #include "llvm/CodeGen/MachineFunctionPass.h"
26 #include "llvm/CodeGen/MachineRegisterInfo.h"
27 #include "llvm/CodeGen/TargetPassConfig.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/Support/Debug.h"
30 
31 #define DEBUG_TYPE "aarch64-prelegalizer-combiner"
32 
33 using namespace llvm;
34 using namespace MIPatternMatch;
35 
36 /// Return true if a G_FCONSTANT instruction is known to be better-represented
37 /// as a G_CONSTANT.
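///
/// For example (illustrative MIR, not taken from an actual test), a store-only
/// constant such as
///
///   %cst:_(s32) = G_FCONSTANT float 1.000000e+00
///   G_STORE %cst(s32), %ptr(p0)
///
/// only ever needs to live in a GPR, so it is better emitted as the bitwise
/// equivalent integer constant:
///
///   %cst:_(s32) = G_CONSTANT i32 1065353216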
38 static bool matchFConstantToConstant(MachineInstr &MI,
39                                      MachineRegisterInfo &MRI) {
40   assert(MI.getOpcode() == TargetOpcode::G_FCONSTANT);
41   Register DstReg = MI.getOperand(0).getReg();
42   const unsigned DstSize = MRI.getType(DstReg).getSizeInBits();
43   if (DstSize != 32 && DstSize != 64)
44     return false;
45 
46   // When we're storing a value, it doesn't matter what register bank it's on.
  // Since not all floating-point constants can be materialized using an fmov,
  // it makes more sense to just use a GPR.
49   return all_of(MRI.use_nodbg_instructions(DstReg),
50                 [](const MachineInstr &Use) { return Use.mayStore(); });
51 }
52 
53 /// Change a G_FCONSTANT into a G_CONSTANT.
54 static void applyFConstantToConstant(MachineInstr &MI) {
55   assert(MI.getOpcode() == TargetOpcode::G_FCONSTANT);
56   MachineIRBuilder MIB(MI);
57   const APFloat &ImmValAPF = MI.getOperand(1).getFPImm()->getValueAPF();
58   MIB.buildConstant(MI.getOperand(0).getReg(), ImmValAPF.bitcastToAPInt());
59   MI.eraseFromParent();
60 }
61 
/// Try to match a G_ICMP of a G_TRUNC against zero, where the bits dropped by
/// the truncation are all sign bits. In that case, we can transform the G_ICMP
/// to compare the wide value against zero directly.
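///
/// For example (illustrative MIR, not taken from an actual test), if
/// %wide:_(s64) is known to have more than 32 sign bits (i.e. it is a
/// sign-extension of its low 32 bits), then
///
///   %trunc:_(s32) = G_TRUNC %wide(s64)
///   %cmp:_(s1) = G_ICMP intpred(eq), %trunc(s32), %zero(s32)
///
/// can instead compare the wide value against a 64-bit zero:
///
///   %cmp:_(s1) = G_ICMP intpred(eq), %wide(s64), %wide_zero(s64)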
65 static bool matchICmpRedundantTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
66                                     GISelKnownBits *KB, Register &MatchInfo) {
67   assert(MI.getOpcode() == TargetOpcode::G_ICMP && KB);
68 
69   auto Pred = (CmpInst::Predicate)MI.getOperand(1).getPredicate();
70   if (!ICmpInst::isEquality(Pred))
71     return false;
72 
73   Register LHS = MI.getOperand(2).getReg();
74   LLT LHSTy = MRI.getType(LHS);
75   if (!LHSTy.isScalar())
76     return false;
77 
78   Register RHS = MI.getOperand(3).getReg();
79   Register WideReg;
80 
81   if (!mi_match(LHS, MRI, m_GTrunc(m_Reg(WideReg))) ||
82       !mi_match(RHS, MRI, m_SpecificICst(0)))
83     return false;
84 
85   LLT WideTy = MRI.getType(WideReg);
86   if (KB->computeNumSignBits(WideReg) <=
87       WideTy.getSizeInBits() - LHSTy.getSizeInBits())
88     return false;
89 
90   MatchInfo = WideReg;
91   return true;
92 }
93 
94 static bool applyICmpRedundantTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
95                                     MachineIRBuilder &Builder,
96                                     GISelChangeObserver &Observer,
97                                     Register &WideReg) {
98   assert(MI.getOpcode() == TargetOpcode::G_ICMP);
99 
100   LLT WideTy = MRI.getType(WideReg);
  // We're going to directly use the wide register as the LHS, and then use an
  // equivalently-sized zero for the RHS.
103   Builder.setInstrAndDebugLoc(MI);
104   auto WideZero = Builder.buildConstant(WideTy, 0);
105   Observer.changingInstr(MI);
106   MI.getOperand(2).setReg(WideReg);
107   MI.getOperand(3).setReg(WideZero.getReg(0));
108   Observer.changedInstr(MI);
109   return true;
110 }
111 
112 /// \returns true if it is possible to fold a constant into a G_GLOBAL_VALUE.
113 ///
114 /// e.g.
115 ///
116 /// %g = G_GLOBAL_VALUE @x -> %g = G_GLOBAL_VALUE @x + cst
117 static bool matchFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
118                                   std::pair<uint64_t, uint64_t> &MatchInfo) {
119   assert(MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
120   MachineFunction &MF = *MI.getMF();
121   auto &GlobalOp = MI.getOperand(1);
122   auto *GV = GlobalOp.getGlobal();
123   if (GV->isThreadLocal())
124     return false;
125 
126   // Don't allow anything that could represent offsets etc.
127   if (MF.getSubtarget<AArch64Subtarget>().ClassifyGlobalReference(
128           GV, MF.getTarget()) != AArch64II::MO_NO_FLAG)
129     return false;
130 
131   // Look for a G_GLOBAL_VALUE only used by G_PTR_ADDs against constants:
132   //
133   //  %g = G_GLOBAL_VALUE @x
134   //  %ptr1 = G_PTR_ADD %g, cst1
135   //  %ptr2 = G_PTR_ADD %g, cst2
136   //  ...
137   //  %ptrN = G_PTR_ADD %g, cstN
138   //
139   // Identify the *smallest* constant. We want to be able to form this:
140   //
141   //  %offset_g = G_GLOBAL_VALUE @x + min_cst
142   //  %g = G_PTR_ADD %offset_g, -min_cst
143   //  %ptr1 = G_PTR_ADD %g, cst1
144   //  ...
145   Register Dst = MI.getOperand(0).getReg();
146   uint64_t MinOffset = -1ull;
147   for (auto &UseInstr : MRI.use_nodbg_instructions(Dst)) {
148     if (UseInstr.getOpcode() != TargetOpcode::G_PTR_ADD)
149       return false;
150     auto Cst = getIConstantVRegValWithLookThrough(
151         UseInstr.getOperand(2).getReg(), MRI);
152     if (!Cst)
153       return false;
154     MinOffset = std::min(MinOffset, Cst->Value.getZExtValue());
155   }
156 
157   // Require that the new offset is larger than the existing one to avoid
158   // infinite loops.
159   uint64_t CurrOffset = GlobalOp.getOffset();
160   uint64_t NewOffset = MinOffset + CurrOffset;
161   if (NewOffset <= CurrOffset)
162     return false;
163 
164   // Check whether folding this offset is legal. It must not go out of bounds of
165   // the referenced object to avoid violating the code model, and must be
166   // smaller than 2^20 because this is the largest offset expressible in all
167   // object formats. (The IMAGE_REL_ARM64_PAGEBASE_REL21 relocation in COFF
168   // stores an immediate signed 21 bit offset.)
169   //
170   // This check also prevents us from folding negative offsets, which will end
171   // up being treated in the same way as large positive ones. They could also
172   // cause code model violations, and aren't really common enough to matter.
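  //
  // For example (illustrative): a MinOffset of 1 << 20 (1048576) would exceed
  // what those relocations can encode and is rejected here. Likewise, a
  // "negative" MinOffset such as -8 shows up as a huge unsigned value and is
  // rejected by this same check.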
173   if (NewOffset >= (1 << 20))
174     return false;
175 
176   Type *T = GV->getValueType();
177   if (!T->isSized() ||
178       NewOffset > GV->getParent()->getDataLayout().getTypeAllocSize(T))
179     return false;
180   MatchInfo = std::make_pair(NewOffset, MinOffset);
181   return true;
182 }
183 
184 static bool applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
185                                   MachineIRBuilder &B,
186                                   GISelChangeObserver &Observer,
187                                   std::pair<uint64_t, uint64_t> &MatchInfo) {
188   // Change:
189   //
190   //  %g = G_GLOBAL_VALUE @x
191   //  %ptr1 = G_PTR_ADD %g, cst1
192   //  %ptr2 = G_PTR_ADD %g, cst2
193   //  ...
194   //  %ptrN = G_PTR_ADD %g, cstN
195   //
196   // To:
197   //
198   //  %offset_g = G_GLOBAL_VALUE @x + min_cst
199   //  %g = G_PTR_ADD %offset_g, -min_cst
200   //  %ptr1 = G_PTR_ADD %g, cst1
201   //  ...
202   //  %ptrN = G_PTR_ADD %g, cstN
203   //
204   // Then, the original G_PTR_ADDs should be folded later on so that they look
205   // like this:
206   //
207   //  %ptrN = G_PTR_ADD %offset_g, cstN - min_cst
208   uint64_t Offset, MinOffset;
209   std::tie(Offset, MinOffset) = MatchInfo;
210   B.setInstrAndDebugLoc(MI);
211   Observer.changingInstr(MI);
212   auto &GlobalOp = MI.getOperand(1);
213   auto *GV = GlobalOp.getGlobal();
214   GlobalOp.ChangeToGA(GV, Offset, GlobalOp.getTargetFlags());
215   Register Dst = MI.getOperand(0).getReg();
216   Register NewGVDst = MRI.cloneVirtualRegister(Dst);
217   MI.getOperand(0).setReg(NewGVDst);
218   Observer.changedInstr(MI);
219   B.buildPtrAdd(
220       Dst, NewGVDst,
221       B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
222   return true;
223 }
224 
225 static bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
226                                CombinerHelper &Helper,
227                                GISelChangeObserver &Observer) {
  // Try to simplify a G_UADDO with 8 or 16 bit operands to a wide G_ADD and a
  // TBNZ if the result is only used in the no-overflow case. This is
  // restricted to cases where we know that the high bits of the operands are
  // 0. If there's an overflow, then the 9th or 17th bit must be set, which can
  // be checked using TBNZ.
233   //
234   // Change (for UADDOs on 8 and 16 bits):
235   //
236   //   %z0 = G_ASSERT_ZEXT _
237   //   %op0 = G_TRUNC %z0
238   //   %z1 = G_ASSERT_ZEXT _
239   //   %op1 = G_TRUNC %z1
240   //   %val, %cond = G_UADDO %op0, %op1
241   //   G_BRCOND %cond, %error.bb
242   //
243   // error.bb:
244   //   (no successors and no uses of %val)
245   //
246   // To:
247   //
248   //   %z0 = G_ASSERT_ZEXT _
249   //   %z1 = G_ASSERT_ZEXT _
250   //   %add = G_ADD %z0, %z1
251   //   %val = G_TRUNC %add
252   //   %bit = G_AND %add, 1 << scalar-size-in-bits(%op1)
253   //   %cond = G_ICMP NE, %bit, 0
254   //   G_BRCOND %cond, %error.bb
255 
256   auto &MRI = *B.getMRI();
257 
258   MachineOperand *DefOp0 = MRI.getOneDef(MI.getOperand(2).getReg());
259   MachineOperand *DefOp1 = MRI.getOneDef(MI.getOperand(3).getReg());
260   Register Op0Wide;
261   Register Op1Wide;
262   if (!mi_match(DefOp0->getParent(), MRI, m_GTrunc(m_Reg(Op0Wide))) ||
263       !mi_match(DefOp1->getParent(), MRI, m_GTrunc(m_Reg(Op1Wide))))
264     return false;
265   LLT WideTy0 = MRI.getType(Op0Wide);
266   LLT WideTy1 = MRI.getType(Op1Wide);
267   Register ResVal = MI.getOperand(0).getReg();
268   LLT OpTy = MRI.getType(ResVal);
269   MachineInstr *Op0WideDef = MRI.getVRegDef(Op0Wide);
270   MachineInstr *Op1WideDef = MRI.getVRegDef(Op1Wide);
271 
272   unsigned OpTySize = OpTy.getScalarSizeInBits();
  // First check that the G_TRUNCs feeding the G_UADDO are no-ops, because the
  // inputs have been zero-extended.
275   if (Op0WideDef->getOpcode() != TargetOpcode::G_ASSERT_ZEXT ||
276       Op1WideDef->getOpcode() != TargetOpcode::G_ASSERT_ZEXT ||
277       OpTySize != Op0WideDef->getOperand(2).getImm() ||
278       OpTySize != Op1WideDef->getOperand(2).getImm())
279     return false;
280 
  // Only scalar G_UADDOs with either 8 or 16 bit operands are handled.
282   if (!WideTy0.isScalar() || !WideTy1.isScalar() || WideTy0 != WideTy1 ||
283       OpTySize >= WideTy0.getScalarSizeInBits() ||
284       (OpTySize != 8 && OpTySize != 16))
285     return false;
286 
287   // The overflow-status result must be used by a branch only.
288   Register ResStatus = MI.getOperand(1).getReg();
289   if (!MRI.hasOneNonDBGUse(ResStatus))
290     return false;
291   MachineInstr *CondUser = &*MRI.use_instr_nodbg_begin(ResStatus);
292   if (CondUser->getOpcode() != TargetOpcode::G_BRCOND)
293     return false;
294 
295   // Make sure the computed result is only used in the no-overflow blocks.
296   MachineBasicBlock *CurrentMBB = MI.getParent();
297   MachineBasicBlock *FailMBB = CondUser->getOperand(1).getMBB();
298   if (!FailMBB->succ_empty() || CondUser->getParent() != CurrentMBB)
299     return false;
300   if (any_of(MRI.use_nodbg_instructions(ResVal),
301              [&MI, FailMBB, CurrentMBB](MachineInstr &I) {
302                return &MI != &I &&
303                       (I.getParent() == FailMBB || I.getParent() == CurrentMBB);
304              }))
305     return false;
306 
  // Remove the G_UADDO.
308   B.setInstrAndDebugLoc(*MI.getNextNode());
309   MI.eraseFromParent();
310 
311   // Emit wide add.
312   Register AddDst = MRI.cloneVirtualRegister(Op0Wide);
313   B.buildInstr(TargetOpcode::G_ADD, {AddDst}, {Op0Wide, Op1Wide});
314 
315   // Emit check of the 9th or 17th bit and update users (the branch). This will
316   // later be folded to TBNZ.
317   Register CondBit = MRI.cloneVirtualRegister(Op0Wide);
318   B.buildAnd(
319       CondBit, AddDst,
320       B.buildConstant(LLT::scalar(32), OpTySize == 8 ? 1 << 8 : 1 << 16));
321   B.buildICmp(CmpInst::ICMP_NE, ResStatus, CondBit,
322               B.buildConstant(LLT::scalar(32), 0));
323 
  // Update G_ZEXT users of the result value. Because all uses are in the
  // no-overflow case, we know that the top bits are 0 and we can ignore ZExts.
326   B.buildZExtOrTrunc(ResVal, AddDst);
327   for (MachineOperand &U : make_early_inc_range(MRI.use_operands(ResVal))) {
328     Register WideReg;
329     if (mi_match(U.getParent(), MRI, m_GZExt(m_Reg(WideReg)))) {
330       auto OldR = U.getParent()->getOperand(0).getReg();
331       Observer.erasingInstr(*U.getParent());
332       U.getParent()->eraseFromParent();
333       Helper.replaceRegWith(MRI, OldR, AddDst);
334     }
335   }
336 
337   return true;
338 }
339 
340 class AArch64PreLegalizerCombinerHelperState {
341 protected:
342   CombinerHelper &Helper;
343 
344 public:
345   AArch64PreLegalizerCombinerHelperState(CombinerHelper &Helper)
346       : Helper(Helper) {}
347 };
348 
349 #define AARCH64PRELEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
350 #include "AArch64GenPreLegalizeGICombiner.inc"
351 #undef AARCH64PRELEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
352 
353 namespace {
354 #define AARCH64PRELEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
355 #include "AArch64GenPreLegalizeGICombiner.inc"
356 #undef AARCH64PRELEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
357 
358 class AArch64PreLegalizerCombinerInfo : public CombinerInfo {
359   GISelKnownBits *KB;
360   MachineDominatorTree *MDT;
361   AArch64GenPreLegalizerCombinerHelperRuleConfig GeneratedRuleCfg;
362 
363 public:
364   AArch64PreLegalizerCombinerInfo(bool EnableOpt, bool OptSize, bool MinSize,
365                                   GISelKnownBits *KB, MachineDominatorTree *MDT)
366       : CombinerInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
367                      /*LegalizerInfo*/ nullptr, EnableOpt, OptSize, MinSize),
368         KB(KB), MDT(MDT) {
369     if (!GeneratedRuleCfg.parseCommandLineOption())
370       report_fatal_error("Invalid rule identifier");
371   }
372 
373   bool combine(GISelChangeObserver &Observer, MachineInstr &MI,
374                MachineIRBuilder &B) const override;
375 };
376 
377 bool AArch64PreLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
378                                               MachineInstr &MI,
379                                               MachineIRBuilder &B) const {
380   CombinerHelper Helper(Observer, B, KB, MDT);
381   AArch64GenPreLegalizerCombinerHelper Generated(GeneratedRuleCfg, Helper);
382 
383   if (Generated.tryCombineAll(Observer, MI, B))
384     return true;
385 
386   unsigned Opc = MI.getOpcode();
387   switch (Opc) {
388   case TargetOpcode::G_CONCAT_VECTORS:
389     return Helper.tryCombineConcatVectors(MI);
390   case TargetOpcode::G_SHUFFLE_VECTOR:
391     return Helper.tryCombineShuffleVector(MI);
392   case TargetOpcode::G_UADDO:
393     return tryToSimplifyUADDO(MI, B, Helper, Observer);
394   case TargetOpcode::G_MEMCPY_INLINE:
395     return Helper.tryEmitMemcpyInline(MI);
396   case TargetOpcode::G_MEMCPY:
397   case TargetOpcode::G_MEMMOVE:
398   case TargetOpcode::G_MEMSET: {
    // If we're at -O0, set a MaxLen of 32 so small copies are still inlined;
    // otherwise pass 0 and let the usual inlining heuristics decide.
    unsigned MaxLen = EnableOpt ? 0 : 32;
    // Try to inline memcpy-family calls.
403     if (Helper.tryCombineMemCpyFamily(MI, MaxLen))
404       return true;
405     if (Opc == TargetOpcode::G_MEMSET)
406       return llvm::AArch64GISelUtils::tryEmitBZero(MI, B, EnableMinSize);
407     return false;
408   }
409   }
410 
411   return false;
412 }
413 
414 #define AARCH64PRELEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
415 #include "AArch64GenPreLegalizeGICombiner.inc"
416 #undef AARCH64PRELEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
417 
418 // Pass boilerplate
419 // ================
420 
421 class AArch64PreLegalizerCombiner : public MachineFunctionPass {
422 public:
423   static char ID;
424 
425   AArch64PreLegalizerCombiner();
426 
  StringRef getPassName() const override {
    return "AArch64PreLegalizerCombiner";
  }
428 
429   bool runOnMachineFunction(MachineFunction &MF) override;
430 
431   void getAnalysisUsage(AnalysisUsage &AU) const override;
432 };
433 } // end anonymous namespace
434 
435 void AArch64PreLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
436   AU.addRequired<TargetPassConfig>();
437   AU.setPreservesCFG();
438   getSelectionDAGFallbackAnalysisUsage(AU);
439   AU.addRequired<GISelKnownBitsAnalysis>();
440   AU.addPreserved<GISelKnownBitsAnalysis>();
441   AU.addRequired<MachineDominatorTree>();
442   AU.addPreserved<MachineDominatorTree>();
443   AU.addRequired<GISelCSEAnalysisWrapperPass>();
444   AU.addPreserved<GISelCSEAnalysisWrapperPass>();
445   MachineFunctionPass::getAnalysisUsage(AU);
446 }
447 
448 AArch64PreLegalizerCombiner::AArch64PreLegalizerCombiner()
449     : MachineFunctionPass(ID) {
450   initializeAArch64PreLegalizerCombinerPass(*PassRegistry::getPassRegistry());
451 }
452 
453 bool AArch64PreLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
454   if (MF.getProperties().hasProperty(
455           MachineFunctionProperties::Property::FailedISel))
456     return false;
457   auto &TPC = getAnalysis<TargetPassConfig>();
458 
459   // Enable CSE.
460   GISelCSEAnalysisWrapper &Wrapper =
461       getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
462   auto *CSEInfo = &Wrapper.get(TPC.getCSEConfig());
463 
464   const Function &F = MF.getFunction();
465   bool EnableOpt =
466       MF.getTarget().getOptLevel() != CodeGenOpt::None && !skipFunction(F);
467   GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
468   MachineDominatorTree *MDT = &getAnalysis<MachineDominatorTree>();
469   AArch64PreLegalizerCombinerInfo PCInfo(EnableOpt, F.hasOptSize(),
470                                          F.hasMinSize(), KB, MDT);
471   Combiner C(PCInfo, &TPC);
472   return C.combineMachineInstrs(MF, CSEInfo);
473 }
474 
475 char AArch64PreLegalizerCombiner::ID = 0;
476 INITIALIZE_PASS_BEGIN(AArch64PreLegalizerCombiner, DEBUG_TYPE,
477                       "Combine AArch64 machine instrs before legalization",
478                       false, false)
479 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
480 INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
481 INITIALIZE_PASS_DEPENDENCY(GISelCSEAnalysisWrapperPass)
482 INITIALIZE_PASS_END(AArch64PreLegalizerCombiner, DEBUG_TYPE,
483                     "Combine AArch64 machine instrs before legalization", false,
484                     false)
485 
486 
487 namespace llvm {
488 FunctionPass *createAArch64PreLegalizerCombiner() {
489   return new AArch64PreLegalizerCombiner();
490 }
491 } // end namespace llvm
492