//=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Post-legalization combines on generic MachineInstrs.
///
/// The combines here must preserve instruction legality.
///
/// Lowering combines (e.g. pseudo matching) should be handled by
/// AArch64PostLegalizerLowering.
///
/// Combines which don't rely on instruction legality should go in the
/// AArch64PreLegalizerCombiner.
///
//===----------------------------------------------------------------------===//

#include "AArch64TargetMachine.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "aarch64-postlegalizer-combiner"

using namespace llvm;
using namespace MIPatternMatch;

/// This combine tries to do what performExtractVectorEltCombine does in SDAG.
/// Rewrite for pairwise fadd pattern
///   (s32 (g_extract_vector_elt
///           (g_fadd (vXs32 Other)
///                   (g_vector_shuffle (vXs32 Other) undef <1,X,...> )) 0))
/// ->
///   (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
///                (g_extract_vector_elt (vXs32 Other) 1))
bool matchExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  Register Src1 = MI.getOperand(1).getReg();
  Register Src2 = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  auto Cst = getIConstantVRegValWithLookThrough(Src2, MRI);
  if (!Cst || Cst->Value != 0)
    return false;
  // SDAG also checks for FullFP16, but this looks to be beneficial anyway.

  // Now check for an fadd operation. TODO: expand this for integer add?
  auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
  if (!FAddMI)
    return false;

  // If we add support for integer add, must restrict these types to just s64.
  unsigned DstSize = DstTy.getSizeInBits();
  if (DstSize != 16 && DstSize != 32 && DstSize != 64)
    return false;

  Register Src1Op1 = FAddMI->getOperand(1).getReg();
  Register Src1Op2 = FAddMI->getOperand(2).getReg();
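  // The fadd is commutative, so the shuffle may feed either operand; try the
  // second operand first, then fall back to the first.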
  MachineInstr *Shuffle =
      getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
  MachineInstr *Other = MRI.getVRegDef(Src1Op1);
  if (!Shuffle) {
    Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
    Other = MRI.getVRegDef(Src1Op2);
  }

  // We're looking for a shuffle that moves the second element to index 0.
  if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
      Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
    std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
    std::get<1>(MatchInfo) = DstTy;
    std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
    return true;
  }
  return false;
}

bool applyExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  unsigned Opc = std::get<0>(MatchInfo);
  assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
  // We want to generate two extracts of elements 0 and 1, and add them.
  LLT Ty = std::get<1>(MatchInfo);
  Register Src = std::get<2>(MatchInfo);
  LLT s64 = LLT::scalar(64);
  B.setInstrAndDebugLoc(MI);
  auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
  auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
  B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
  MI.eraseFromParent();
  return true;
}

static bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  unsigned Opc = MRI.getVRegDef(R)->getOpcode();
  return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
}

static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
}

bool matchAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  assert(MI.getOpcode() == TargetOpcode::G_MUL);
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  const LLT Ty = MRI.getType(LHS);

  // The below optimizations require a constant RHS.
  auto Const = getIConstantVRegValWithLookThrough(RHS, MRI);
  if (!Const)
    return false;

  const APInt ConstValue = Const->Value.sextOrSelf(Ty.getSizeInBits());
  // The following code is ported from AArch64ISelLowering.
  // Multiplication of a power of two plus/minus one can be done more
  // cheaply as a shift+add/sub. For now, this is true unilaterally. If
  // future CPUs have a cheaper MADD instruction, this may need to be
  // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
  // 64-bit is 5 cycles, so this is always a win.
  // More aggressively, some multiplications N0 * C can be lowered to
  // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
  // e.g. 6=3*2=(2+1)*2.
  // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
  // which equals (1+2)*16-(1+2).
  // TrailingZeroes is used to test if the mul can be lowered to
  // shift+add+shift.
  unsigned TrailingZeroes = ConstValue.countTrailingZeros();
  if (TrailingZeroes) {
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into smul or umul.
    if (MRI.hasOneNonDBGUse(LHS) &&
        (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
      return false;
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into madd or msub.
    if (MRI.hasOneNonDBGUse(Dst)) {
      MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
      unsigned UseOpc = UseMI.getOpcode();
      if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD ||
          UseOpc == TargetOpcode::G_SUB)
        return false;
    }
  }
  // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
  // and shift+add+shift.
  APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);
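  // e.g. for C = 6 (0b110), TrailingZeroes = 1 and ShiftedConstValue = 3,
  // which is 2^1 + 1, so 6*x can be lowered as ((x << 1) + x) << 1.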

  unsigned ShiftAmt, AddSubOpc;
  // Is the shifted value the LHS operand of the add/sub?
  bool ShiftValUseIsLHS = true;
  // Do we need to negate the result?
  bool NegateResult = false;

  if (ConstValue.isNonNegative()) {
    // (mul x, 2^N + 1) => (add (shl x, N), x)
    // (mul x, 2^N - 1) => (sub (shl x, N), x)
    // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
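    // e.g. 5*x => (add (shl x, 2), x), 7*x => (sub (shl x, 3), x),
    //      6*x => (shl (add (shl x, 1), x), 1).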
    APInt SCVMinus1 = ShiftedConstValue - 1;
    APInt CVPlus1 = ConstValue + 1;
    if (SCVMinus1.isPowerOf2()) {
      ShiftAmt = SCVMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
    } else if (CVPlus1.isPowerOf2()) {
      ShiftAmt = CVPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
    } else
      return false;
  } else {
    // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
    // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
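    // e.g. -7*x => (sub x, (shl x, 3)), -9*x => - (add (shl x, 3), x).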
    APInt CVNegPlus1 = -ConstValue + 1;
    APInt CVNegMinus1 = -ConstValue - 1;
    if (CVNegPlus1.isPowerOf2()) {
      ShiftAmt = CVNegPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
      ShiftValUseIsLHS = false;
    } else if (CVNegMinus1.isPowerOf2()) {
      ShiftAmt = CVNegMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
      NegateResult = true;
    } else
      return false;
  }

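  // The apply function below emits either a final negate or a final shift,
  // but not both, so bail out if both would be required.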
  if (NegateResult && TrailingZeroes)
    return false;

  ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
    auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
    auto ShiftedVal = B.buildShl(Ty, LHS, Shift);

    Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
    Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
    auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
    assert(!(NegateResult && TrailingZeroes) &&
           "NegateResult and TrailingZeroes cannot both be true for now.");
    // Negate the result.
    if (NegateResult) {
      B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
      return;
    }
    // Shift the result.
    if (TrailingZeroes) {
      B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
      return;
    }
    B.buildCopy(DstReg, Res.getReg(0));
  };
  return true;
}

bool applyAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  B.setInstrAndDebugLoc(MI);
  ApplyFn(B, MI.getOperand(0).getReg());
  MI.eraseFromParent();
  return true;
}

/// Try to fold a G_MERGE_VALUES of 2 s32 sources, where the second source
/// is a zero, into a G_ZEXT of the first.
bool matchFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI) {
  auto &Merge = cast<GMerge>(MI);
  LLT SrcTy = MRI.getType(Merge.getSourceReg(0));
  if (SrcTy != LLT::scalar(32) || Merge.getNumSources() != 2)
    return false;
  return mi_match(Merge.getSourceReg(1), MRI, m_SpecificICst(0));
}

void applyFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI,
                          MachineIRBuilder &B, GISelChangeObserver &Observer) {
  // Mutate %d(s64) = G_MERGE_VALUES %a(s32), 0(s32)
  //  ->
  // %d(s64) = G_ZEXT %a(s32)
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  MI.RemoveOperand(2);
  Observer.changedInstr(MI);
}

/// \returns True if a G_ANYEXT instruction \p MI should be mutated to a G_ZEXT
/// instruction.
static bool matchMutateAnyExtToZExt(MachineInstr &MI,
                                    MachineRegisterInfo &MRI) {
  // If this is coming from a scalar compare then we can use a G_ZEXT instead
  // of a G_ANYEXT:
  //
  // %cmp:_(s32) = G_[I|F]CMP ... <-- produces 0/1.
  // %ext:_(s64) = G_ANYEXT %cmp(s32)
  //
  // By doing this, we can leverage more KnownBits combines.
  assert(MI.getOpcode() == TargetOpcode::G_ANYEXT);
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  return MRI.getType(Dst).isScalar() &&
         mi_match(Src, MRI,
                  m_any_of(m_GICmp(m_Pred(), m_Reg(), m_Reg()),
                           m_GFCmp(m_Pred(), m_Reg(), m_Reg())));
}

static void applyMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI,
                                    MachineIRBuilder &B,
                                    GISelChangeObserver &Observer) {
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  Observer.changedInstr(MI);
}

#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS

namespace {
#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H

class AArch64PostLegalizerCombinerInfo : public CombinerInfo {
  GISelKnownBits *KB;
  MachineDominatorTree *MDT;

public:
  AArch64GenPostLegalizerCombinerHelperRuleConfig GeneratedRuleCfg;

  AArch64PostLegalizerCombinerInfo(bool EnableOpt, bool OptSize, bool MinSize,
                                   GISelKnownBits *KB,
                                   MachineDominatorTree *MDT)
      : CombinerInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
                     /*LegalizerInfo*/ nullptr, EnableOpt, OptSize, MinSize),
        KB(KB), MDT(MDT) {
    if (!GeneratedRuleCfg.parseCommandLineOption())
      report_fatal_error("Invalid rule identifier");
  }

  virtual bool combine(GISelChangeObserver &Observer, MachineInstr &MI,
                       MachineIRBuilder &B) const override;
};

bool AArch64PostLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
                                               MachineInstr &MI,
                                               MachineIRBuilder &B) const {
  const auto *LI =
      MI.getParent()->getParent()->getSubtarget().getLegalizerInfo();
  CombinerHelper Helper(Observer, B, KB, MDT, LI);
  AArch64GenPostLegalizerCombinerHelper Generated(GeneratedRuleCfg);
  return Generated.tryCombineAll(Observer, MI, B, Helper);
}

#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP

class AArch64PostLegalizerCombiner : public MachineFunctionPass {
public:
  static char ID;

  AArch64PostLegalizerCombiner(bool IsOptNone = false);

  StringRef getPassName() const override {
    return "AArch64PostLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool IsOptNone;
};
} // end anonymous namespace

void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelKnownBitsAnalysis>();
  AU.addPreserved<GISelKnownBitsAnalysis>();
  if (!IsOptNone) {
    AU.addRequired<MachineDominatorTree>();
    AU.addPreserved<MachineDominatorTree>();
    AU.addRequired<GISelCSEAnalysisWrapperPass>();
    AU.addPreserved<GISelCSEAnalysisWrapperPass>();
  }
  MachineFunctionPass::getAnalysisUsage(AU);
}

AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
    : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
  initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry());
}

bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  if (MF.getProperties().hasProperty(
          MachineFunctionProperties::Property::FailedISel))
    return false;
  assert(MF.getProperties().hasProperty(
             MachineFunctionProperties::Property::Legalized) &&
         "Expected a legalized function?");
  auto *TPC = &getAnalysis<TargetPassConfig>();
  const Function &F = MF.getFunction();
  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOpt::None && !skipFunction(F);
  GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
  MachineDominatorTree *MDT =
      IsOptNone ? nullptr : &getAnalysis<MachineDominatorTree>();
  AArch64PostLegalizerCombinerInfo PCInfo(EnableOpt, F.hasOptSize(),
                                          F.hasMinSize(), KB, MDT);
  GISelCSEAnalysisWrapper &Wrapper =
      getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
  auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig());
  Combiner C(PCInfo, TPC);
  return C.combineMachineInstrs(MF, CSEInfo);
}

char AArch64PostLegalizerCombiner::ID = 0;
INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                      "Combine AArch64 MachineInstrs after legalization", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                    "Combine AArch64 MachineInstrs after legalization", false,
                    false)

namespace llvm {
FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
  return new AArch64PostLegalizerCombiner(IsOptNone);
}
} // end namespace llvm