1 //=== AArch64PostLegalizerLowering.cpp --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Post-legalization lowering for instructions.
11 ///
12 /// This is used to offload pattern matching from the selector.
13 ///
14 /// For example, this combiner will notice that a G_SHUFFLE_VECTOR is actually
15 /// a G_ZIP, G_UZP, etc.
16 ///
17 /// General optimization combines should be handled by either the
18 /// AArch64PostLegalizerCombiner or the AArch64PreLegalizerCombiner.
19 ///
20 //===----------------------------------------------------------------------===//
21 
22 #include "AArch64GlobalISelUtils.h"
23 #include "AArch64Subtarget.h"
24 #include "AArch64TargetMachine.h"
25 #include "GISel/AArch64LegalizerInfo.h"
26 #include "MCTargetDesc/AArch64MCTargetDesc.h"
27 #include "TargetInfo/AArch64TargetInfo.h"
28 #include "Utils/AArch64BaseInfo.h"
29 #include "llvm/CodeGen/GlobalISel/Combiner.h"
30 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
31 #include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
32 #include "llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h"
33 #include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
34 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
35 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
36 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
37 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
38 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
39 #include "llvm/CodeGen/GlobalISel/Utils.h"
40 #include "llvm/CodeGen/MachineFunctionPass.h"
41 #include "llvm/CodeGen/MachineInstrBuilder.h"
42 #include "llvm/CodeGen/MachineRegisterInfo.h"
43 #include "llvm/CodeGen/TargetOpcodes.h"
44 #include "llvm/CodeGen/TargetPassConfig.h"
45 #include "llvm/IR/InstrTypes.h"
46 #include "llvm/InitializePasses.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Support/ErrorHandling.h"
49 #include <optional>
50 
51 #define GET_GICOMBINER_DEPS
52 #include "AArch64GenPostLegalizeGILowering.inc"
53 #undef GET_GICOMBINER_DEPS
54 
55 #define DEBUG_TYPE "aarch64-postlegalizer-lowering"
56 
57 using namespace llvm;
58 using namespace MIPatternMatch;
59 using namespace AArch64GISelUtils;
60 
61 namespace {
62 
63 #define GET_GICOMBINER_TYPES
64 #include "AArch64GenPostLegalizeGILowering.inc"
65 #undef GET_GICOMBINER_TYPES
66 
67 /// Represents a pseudo instruction which replaces a G_SHUFFLE_VECTOR.
68 ///
69 /// Used for matching target-supported shuffles before codegen.
70 struct ShuffleVectorPseudo {
71   unsigned Opc;                 ///< Opcode for the instruction. (E.g. G_ZIP1)
72   Register Dst;                 ///< Destination register.
73   SmallVector<SrcOp, 2> SrcOps; ///< Source registers.
74   ShuffleVectorPseudo(unsigned Opc, Register Dst,
75                       std::initializer_list<SrcOp> SrcOps)
76       : Opc(Opc), Dst(Dst), SrcOps(SrcOps) {}
77   ShuffleVectorPseudo() = default;
78 };
79 
80 /// Check if a vector shuffle corresponds to a REV instruction with the
81 /// specified blocksize.
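/// e.g. with EltSize = 16, NumElts = 8 and BlockSize = 64, the mask
/// <3, 2, 1, 0, 7, 6, 5, 4> is a REV64 mask: it reverses the s16 elements
/// within each 64-bit block.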
82 bool isREVMask(ArrayRef<int> M, unsigned EltSize, unsigned NumElts,
83                unsigned BlockSize) {
84   assert((BlockSize == 16 || BlockSize == 32 || BlockSize == 64) &&
85          "Only possible block sizes for REV are: 16, 32, 64");
86   assert(EltSize != 64 && "EltSize cannot be 64 for REV mask.");
87 
88   unsigned BlockElts = M[0] + 1;
89 
90   // If the first shuffle index is UNDEF, be optimistic.
91   if (M[0] < 0)
92     BlockElts = BlockSize / EltSize;
93 
94   if (BlockSize <= EltSize || BlockSize != BlockElts * EltSize)
95     return false;
96 
97   for (unsigned i = 0; i < NumElts; ++i) {
98     // Ignore undef indices.
99     if (M[i] < 0)
100       continue;
101     if (static_cast<unsigned>(M[i]) !=
102         (i - i % BlockElts) + (BlockElts - 1 - i % BlockElts))
103       return false;
104   }
105 
106   return true;
107 }
108 
109 /// Determines if \p M is a shuffle vector mask for a TRN of \p NumElts.
110 /// Whether or not G_TRN1 or G_TRN2 should be used is stored in \p WhichResult.
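/// e.g. for NumElts = 4, <0, 4, 2, 6> is a TRN1 mask (WhichResult == 0) and
/// <1, 5, 3, 7> is a TRN2 mask (WhichResult == 1).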
111 bool isTRNMask(ArrayRef<int> M, unsigned NumElts, unsigned &WhichResult) {
112   if (NumElts % 2 != 0)
113     return false;
114   WhichResult = (M[0] == 0 ? 0 : 1);
115   for (unsigned i = 0; i < NumElts; i += 2) {
116     if ((M[i] >= 0 && static_cast<unsigned>(M[i]) != i + WhichResult) ||
117         (M[i + 1] >= 0 &&
118          static_cast<unsigned>(M[i + 1]) != i + NumElts + WhichResult))
119       return false;
120   }
121   return true;
122 }
123 
124 /// Check if a G_EXT instruction can handle a shuffle mask \p M when the vector
125 /// sources of the shuffle are different.
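/// e.g. for NumElts = 4, the mask <1, 2, 3, 4> takes consecutive elements of
/// the concatenated sources starting at element 1, so it corresponds to an EXT
/// with immediate 1 and no reversal of the sources.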
126 std::optional<std::pair<bool, uint64_t>> getExtMask(ArrayRef<int> M,
127                                                     unsigned NumElts) {
128   // Look for the first non-undef element.
129   auto FirstRealElt = find_if(M, [](int Elt) { return Elt >= 0; });
130   if (FirstRealElt == M.end())
131     return std::nullopt;
132 
133   // Use APInt to handle overflow when calculating expected element.
134   unsigned MaskBits = APInt(32, NumElts * 2).logBase2();
135   APInt ExpectedElt = APInt(MaskBits, *FirstRealElt + 1);
136 
137   // The following shuffle indices must be the successive elements after the
138   // first real element.
139   if (any_of(
140           make_range(std::next(FirstRealElt), M.end()),
141           [&ExpectedElt](int Elt) { return Elt != ExpectedElt++ && Elt >= 0; }))
142     return std::nullopt;
143 
144   // The index of an EXT is the first element if it is not UNDEF.
145   // Watch out for the beginning UNDEFs. The EXT index should be the expected
146   // value of the first element.  E.g.
147   // <-1, -1, 3, ...> is treated as <1, 2, 3, ...>.
148   // <-1, -1, 0, 1, ...> is treated as <2*NumElts-2, 2*NumElts-1, 0, 1, ...>.
149   // ExpectedElt is the last mask index plus 1.
150   uint64_t Imm = ExpectedElt.getZExtValue();
151   bool ReverseExt = false;
152 
153   // There are two different cases that require reversing the input vectors.
154   // For example, for vector <4 x i32> we have the following cases,
155   // Case 1: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, -1, 0>)
156   // Case 2: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, 7, 0>)
157   // For both cases, we finally use mask <5, 6, 7, 0>, which requires
158   // reversing the two input vectors.
159   if (Imm < NumElts)
160     ReverseExt = true;
161   else
162     Imm -= NumElts;
163   return std::make_pair(ReverseExt, Imm);
164 }
165 
166 /// Determines if \p M is a shuffle vector mask for a UZP of \p NumElts.
167 /// Whether or not G_UZP1 or G_UZP2 should be used is stored in \p WhichResult.
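/// e.g. for NumElts = 4, <0, 2, 4, 6> is a UZP1 mask (the even-numbered
/// elements of both sources) and <1, 3, 5, 7> is a UZP2 mask (the odd ones).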
168 bool isUZPMask(ArrayRef<int> M, unsigned NumElts, unsigned &WhichResult) {
169   WhichResult = (M[0] == 0 ? 0 : 1);
170   for (unsigned i = 0; i != NumElts; ++i) {
171     // Skip undef indices.
172     if (M[i] < 0)
173       continue;
174     if (static_cast<unsigned>(M[i]) != 2 * i + WhichResult)
175       return false;
176   }
177   return true;
178 }
179 
180 /// \return true if \p M is a zip mask for a shuffle vector of \p NumElts.
181 /// Whether or not G_ZIP1 or G_ZIP2 should be used is stored in \p WhichResult.
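/// e.g. for NumElts = 4, <0, 4, 1, 5> is a ZIP1 mask (interleaving the low
/// halves of the two sources) and <2, 6, 3, 7> is a ZIP2 mask (the high
/// halves).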
182 bool isZipMask(ArrayRef<int> M, unsigned NumElts, unsigned &WhichResult) {
183   if (NumElts % 2 != 0)
184     return false;
185 
186   // 0 means use ZIP1, 1 means use ZIP2.
187   WhichResult = (M[0] == 0 ? 0 : 1);
188   unsigned Idx = WhichResult * NumElts / 2;
189   for (unsigned i = 0; i != NumElts; i += 2) {
190     if ((M[i] >= 0 && static_cast<unsigned>(M[i]) != Idx) ||
191         (M[i + 1] >= 0 && static_cast<unsigned>(M[i + 1]) != Idx + NumElts))
192       return false;
193     Idx += 1;
194   }
195   return true;
196 }
197 
198 /// Helper function for matchINS.
199 ///
200 /// \returns a value when \p M is an ins mask for \p NumInputElements.
201 ///
202 /// First element of the returned pair is true when the produced
203 /// G_INSERT_VECTOR_ELT destination should be the LHS of the G_SHUFFLE_VECTOR.
204 ///
205 /// Second element is the destination lane for the G_INSERT_VECTOR_ELT.
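///
/// e.g. for \p NumInputElements = 4, the mask <0, 1, 6, 3> matches the LHS in
/// every lane except lane 2, so the result is (true, 2).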
206 std::optional<std::pair<bool, int>> isINSMask(ArrayRef<int> M,
207                                               int NumInputElements) {
208   if (M.size() != static_cast<size_t>(NumInputElements))
209     return std::nullopt;
210   int NumLHSMatch = 0, NumRHSMatch = 0;
211   int LastLHSMismatch = -1, LastRHSMismatch = -1;
212   for (int Idx = 0; Idx < NumInputElements; ++Idx) {
213     if (M[Idx] == -1) {
214       ++NumLHSMatch;
215       ++NumRHSMatch;
216       continue;
217     }
218     M[Idx] == Idx ? ++NumLHSMatch : LastLHSMismatch = Idx;
219     M[Idx] == Idx + NumInputElements ? ++NumRHSMatch : LastRHSMismatch = Idx;
220   }
221   const int NumNeededToMatch = NumInputElements - 1;
222   if (NumLHSMatch == NumNeededToMatch)
223     return std::make_pair(true, LastLHSMismatch);
224   if (NumRHSMatch == NumNeededToMatch)
225     return std::make_pair(false, LastRHSMismatch);
226   return std::nullopt;
227 }
228 
229 /// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with a
230 /// G_REV instruction. Returns the appropriate G_REV opcode in \p MatchInfo.
231 bool matchREV(MachineInstr &MI, MachineRegisterInfo &MRI,
232               ShuffleVectorPseudo &MatchInfo) {
233   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
234   ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
235   Register Dst = MI.getOperand(0).getReg();
236   Register Src = MI.getOperand(1).getReg();
237   LLT Ty = MRI.getType(Dst);
238   unsigned EltSize = Ty.getScalarSizeInBits();
239 
240   // Element size for a rev cannot be 64.
241   if (EltSize == 64)
242     return false;
243 
244   unsigned NumElts = Ty.getNumElements();
245 
246   // Try to produce G_REV64
247   if (isREVMask(ShuffleMask, EltSize, NumElts, 64)) {
248     MatchInfo = ShuffleVectorPseudo(AArch64::G_REV64, Dst, {Src});
249     return true;
250   }
251 
252   // TODO: Produce G_REV32 and G_REV16 once we have proper legalization support.
253   // This should be identical to above, but with a constant 32 and constant
254   // 16.
255   return false;
256 }
257 
258 /// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with
259 /// a G_TRN1 or G_TRN2 instruction.
260 bool matchTRN(MachineInstr &MI, MachineRegisterInfo &MRI,
261               ShuffleVectorPseudo &MatchInfo) {
262   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
263   unsigned WhichResult;
264   ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
265   Register Dst = MI.getOperand(0).getReg();
266   unsigned NumElts = MRI.getType(Dst).getNumElements();
267   if (!isTRNMask(ShuffleMask, NumElts, WhichResult))
268     return false;
269   unsigned Opc = (WhichResult == 0) ? AArch64::G_TRN1 : AArch64::G_TRN2;
270   Register V1 = MI.getOperand(1).getReg();
271   Register V2 = MI.getOperand(2).getReg();
272   MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
273   return true;
274 }
275 
276 /// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with
277 /// a G_UZP1 or G_UZP2 instruction.
278 ///
279 /// \param [in] MI - The shuffle vector instruction.
280 /// \param [out] MatchInfo - Either G_UZP1 or G_UZP2 on success.
281 bool matchUZP(MachineInstr &MI, MachineRegisterInfo &MRI,
282               ShuffleVectorPseudo &MatchInfo) {
283   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
284   unsigned WhichResult;
285   ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
286   Register Dst = MI.getOperand(0).getReg();
287   unsigned NumElts = MRI.getType(Dst).getNumElements();
288   if (!isUZPMask(ShuffleMask, NumElts, WhichResult))
289     return false;
290   unsigned Opc = (WhichResult == 0) ? AArch64::G_UZP1 : AArch64::G_UZP2;
291   Register V1 = MI.getOperand(1).getReg();
292   Register V2 = MI.getOperand(2).getReg();
293   MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
294   return true;
295 }
296 
297 bool matchZip(MachineInstr &MI, MachineRegisterInfo &MRI,
298               ShuffleVectorPseudo &MatchInfo) {
299   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
300   unsigned WhichResult;
301   ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
302   Register Dst = MI.getOperand(0).getReg();
303   unsigned NumElts = MRI.getType(Dst).getNumElements();
304   if (!isZipMask(ShuffleMask, NumElts, WhichResult))
305     return false;
306   unsigned Opc = (WhichResult == 0) ? AArch64::G_ZIP1 : AArch64::G_ZIP2;
307   Register V1 = MI.getOperand(1).getReg();
308   Register V2 = MI.getOperand(2).getReg();
309   MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
310   return true;
311 }
312 
313 /// Helper function for matchDup.
314 bool matchDupFromInsertVectorElt(int Lane, MachineInstr &MI,
315                                  MachineRegisterInfo &MRI,
316                                  ShuffleVectorPseudo &MatchInfo) {
317   if (Lane != 0)
318     return false;
319 
320   // Try to match a vector splat operation into a dup instruction.
321   // We're looking for this pattern:
322   //
323   // %scalar:gpr(s64) = COPY $x0
324   // %undef:fpr(<2 x s64>) = G_IMPLICIT_DEF
325   // %cst0:gpr(s32) = G_CONSTANT i32 0
326   // %zerovec:fpr(<2 x s32>) = G_BUILD_VECTOR %cst0(s32), %cst0(s32)
327   // %ins:fpr(<2 x s64>) = G_INSERT_VECTOR_ELT %undef, %scalar(s64), %cst0(s32)
328   // %splat:fpr(<2 x s64>) = G_SHUFFLE_VECTOR %ins(<2 x s64>), %undef,
329   // %zerovec(<2 x s32>)
330   //
331   // ...into:
332   // %splat = G_DUP %scalar
333 
334   // Begin matching the insert.
335   auto *InsMI = getOpcodeDef(TargetOpcode::G_INSERT_VECTOR_ELT,
336                              MI.getOperand(1).getReg(), MRI);
337   if (!InsMI)
338     return false;
339   // Match the undef vector operand.
340   if (!getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, InsMI->getOperand(1).getReg(),
341                     MRI))
342     return false;
343 
344   // Match the index constant 0.
345   if (!mi_match(InsMI->getOperand(3).getReg(), MRI, m_ZeroInt()))
346     return false;
347 
348   MatchInfo = ShuffleVectorPseudo(AArch64::G_DUP, MI.getOperand(0).getReg(),
349                                   {InsMI->getOperand(2).getReg()});
350   return true;
351 }
352 
353 /// Helper function for matchDup.
354 bool matchDupFromBuildVector(int Lane, MachineInstr &MI,
355                              MachineRegisterInfo &MRI,
356                              ShuffleVectorPseudo &MatchInfo) {
357   assert(Lane >= 0 && "Expected non-negative lane?");
358   // Test if the LHS is a BUILD_VECTOR. If it is, then we can just reference the
359   // lane's definition directly.
360   auto *BuildVecMI = getOpcodeDef(TargetOpcode::G_BUILD_VECTOR,
361                                   MI.getOperand(1).getReg(), MRI);
362   if (!BuildVecMI)
363     return false;
364   Register Reg = BuildVecMI->getOperand(Lane + 1).getReg();
365   MatchInfo =
366       ShuffleVectorPseudo(AArch64::G_DUP, MI.getOperand(0).getReg(), {Reg});
367   return true;
368 }
369 
370 bool matchDup(MachineInstr &MI, MachineRegisterInfo &MRI,
371               ShuffleVectorPseudo &MatchInfo) {
372   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
373   auto MaybeLane = getSplatIndex(MI);
374   if (!MaybeLane)
375     return false;
376   int Lane = *MaybeLane;
377   // If this is an undef splat, generate it via "just" vdup, if possible.
378   if (Lane < 0)
379     Lane = 0;
380   if (matchDupFromInsertVectorElt(Lane, MI, MRI, MatchInfo))
381     return true;
382   if (matchDupFromBuildVector(Lane, MI, MRI, MatchInfo))
383     return true;
384   return false;
385 }
386 
387 // Check if an EXT instruction can handle the shuffle mask when the vector
388 // sources of the shuffle are the same.
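// e.g. for a <4 x s32> shuffle, <2, 3, 0, 1> is a singleton EXT mask: the
// elements of the single source rotated to start at element 2.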
389 bool isSingletonExtMask(ArrayRef<int> M, LLT Ty) {
390   unsigned NumElts = Ty.getNumElements();
391 
392   // Assume that the first shuffle index is not UNDEF.  Fail if it is.
393   if (M[0] < 0)
394     return false;
395 
396   // If this is a VEXT shuffle, the immediate value is the index of the first
397   // element.  The other shuffle indices must be the successive elements after
398   // the first one.
399   unsigned ExpectedElt = M[0];
400   for (unsigned I = 1; I < NumElts; ++I) {
401     // Increment the expected index.  If it wraps around, just follow it
402     // back to index zero and keep going.
403     ++ExpectedElt;
404     if (ExpectedElt == NumElts)
405       ExpectedElt = 0;
406 
407     if (M[I] < 0)
408       continue; // Ignore UNDEF indices.
409     if (ExpectedElt != static_cast<unsigned>(M[I]))
410       return false;
411   }
412 
413   return true;
414 }
415 
416 bool matchEXT(MachineInstr &MI, MachineRegisterInfo &MRI,
417               ShuffleVectorPseudo &MatchInfo) {
418   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
419   Register Dst = MI.getOperand(0).getReg();
420   LLT DstTy = MRI.getType(Dst);
421   Register V1 = MI.getOperand(1).getReg();
422   Register V2 = MI.getOperand(2).getReg();
423   auto Mask = MI.getOperand(3).getShuffleMask();
424   uint64_t Imm;
425   auto ExtInfo = getExtMask(Mask, DstTy.getNumElements());
426   uint64_t ExtFactor = MRI.getType(V1).getScalarSizeInBits() / 8;
427 
428   if (!ExtInfo) {
429     if (!getOpcodeDef<GImplicitDef>(V2, MRI) ||
430         !isSingletonExtMask(Mask, DstTy))
431       return false;
432 
433     Imm = Mask[0] * ExtFactor;
434     MatchInfo = ShuffleVectorPseudo(AArch64::G_EXT, Dst, {V1, V1, Imm});
435     return true;
436   }
437   bool ReverseExt;
438   std::tie(ReverseExt, Imm) = *ExtInfo;
439   if (ReverseExt)
440     std::swap(V1, V2);
441   Imm *= ExtFactor;
442   MatchInfo = ShuffleVectorPseudo(AArch64::G_EXT, Dst, {V1, V2, Imm});
443   return true;
444 }
445 
446 /// Replace the G_SHUFFLE_VECTOR instruction \p MI with a pseudo.
447 /// \p MatchInfo provides the opcode, destination, and source operands to use.
448 void applyShuffleVectorPseudo(MachineInstr &MI,
449                               ShuffleVectorPseudo &MatchInfo) {
450   MachineIRBuilder MIRBuilder(MI);
451   MIRBuilder.buildInstr(MatchInfo.Opc, {MatchInfo.Dst}, MatchInfo.SrcOps);
452   MI.eraseFromParent();
453 }
454 
455 /// Replace a G_SHUFFLE_VECTOR instruction with G_EXT.
456 /// Special-cased because the constant operand must be emitted as a G_CONSTANT
457 /// for the imported tablegen patterns to work.
458 void applyEXT(MachineInstr &MI, ShuffleVectorPseudo &MatchInfo) {
459   MachineIRBuilder MIRBuilder(MI);
460   // Tablegen patterns expect an i32 G_CONSTANT as the final op.
461   auto Cst =
462       MIRBuilder.buildConstant(LLT::scalar(32), MatchInfo.SrcOps[2].getImm());
463   MIRBuilder.buildInstr(MatchInfo.Opc, {MatchInfo.Dst},
464                         {MatchInfo.SrcOps[0], MatchInfo.SrcOps[1], Cst});
465   MI.eraseFromParent();
466 }
467 
468 /// Match a G_SHUFFLE_VECTOR with a mask which corresponds to a
469 /// G_INSERT_VECTOR_ELT and G_EXTRACT_VECTOR_ELT pair.
470 ///
471 /// e.g.
472 ///   %shuf = G_SHUFFLE_VECTOR %left, %right, shufflemask(0, 0)
473 ///
474 /// Can be represented as
475 ///
476 ///   %extract = G_EXTRACT_VECTOR_ELT %left, 0
477 ///   %ins = G_INSERT_VECTOR_ELT %left, %extract, 1
478 ///
479 bool matchINS(MachineInstr &MI, MachineRegisterInfo &MRI,
480               std::tuple<Register, int, Register, int> &MatchInfo) {
481   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
482   ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
483   Register Dst = MI.getOperand(0).getReg();
484   int NumElts = MRI.getType(Dst).getNumElements();
485   auto DstIsLeftAndDstLane = isINSMask(ShuffleMask, NumElts);
486   if (!DstIsLeftAndDstLane)
487     return false;
488   bool DstIsLeft;
489   int DstLane;
490   std::tie(DstIsLeft, DstLane) = *DstIsLeftAndDstLane;
491   Register Left = MI.getOperand(1).getReg();
492   Register Right = MI.getOperand(2).getReg();
493   Register DstVec = DstIsLeft ? Left : Right;
494   Register SrcVec = Left;
495 
496   int SrcLane = ShuffleMask[DstLane];
497   if (SrcLane >= NumElts) {
498     SrcVec = Right;
499     SrcLane -= NumElts;
500   }
501 
502   MatchInfo = std::make_tuple(DstVec, DstLane, SrcVec, SrcLane);
503   return true;
504 }
505 
506 void applyINS(MachineInstr &MI, MachineRegisterInfo &MRI,
507               MachineIRBuilder &Builder,
508               std::tuple<Register, int, Register, int> &MatchInfo) {
509   Builder.setInstrAndDebugLoc(MI);
510   Register Dst = MI.getOperand(0).getReg();
511   auto ScalarTy = MRI.getType(Dst).getElementType();
512   Register DstVec, SrcVec;
513   int DstLane, SrcLane;
514   std::tie(DstVec, DstLane, SrcVec, SrcLane) = MatchInfo;
515   auto SrcCst = Builder.buildConstant(LLT::scalar(64), SrcLane);
516   auto Extract = Builder.buildExtractVectorElement(ScalarTy, SrcVec, SrcCst);
517   auto DstCst = Builder.buildConstant(LLT::scalar(64), DstLane);
518   Builder.buildInsertVectorElement(Dst, DstVec, Extract, DstCst);
519   MI.eraseFromParent();
520 }
521 
522 /// isVShiftRImm - Check if \p Reg is a constant splat vector that is a valid
523 /// immediate operand for a vector shift right operation. The value must be in
524 /// the range: 1 <= Value <= ElementBits for a right shift.
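/// e.g. an <8 x s16> G_LSHR whose shift amount is a splat of the constant 3
/// gives Cnt = 3, which is valid since 1 <= 3 <= 16.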
525 bool isVShiftRImm(Register Reg, MachineRegisterInfo &MRI, LLT Ty,
526                   int64_t &Cnt) {
527   assert(Ty.isVector() && "vector shift count is not a vector type");
528   MachineInstr *MI = MRI.getVRegDef(Reg);
529   auto Cst = getAArch64VectorSplatScalar(*MI, MRI);
530   if (!Cst)
531     return false;
532   Cnt = *Cst;
533   int64_t ElementBits = Ty.getScalarSizeInBits();
534   return Cnt >= 1 && Cnt <= ElementBits;
535 }
536 
537 /// Match a vector G_ASHR or G_LSHR with a valid immediate shift.
538 bool matchVAshrLshrImm(MachineInstr &MI, MachineRegisterInfo &MRI,
539                        int64_t &Imm) {
540   assert(MI.getOpcode() == TargetOpcode::G_ASHR ||
541          MI.getOpcode() == TargetOpcode::G_LSHR);
542   LLT Ty = MRI.getType(MI.getOperand(1).getReg());
543   if (!Ty.isVector())
544     return false;
545   return isVShiftRImm(MI.getOperand(2).getReg(), MRI, Ty, Imm);
546 }
547 
548 void applyVAshrLshrImm(MachineInstr &MI, MachineRegisterInfo &MRI,
549                        int64_t &Imm) {
550   unsigned Opc = MI.getOpcode();
551   assert(Opc == TargetOpcode::G_ASHR || Opc == TargetOpcode::G_LSHR);
552   unsigned NewOpc =
553       Opc == TargetOpcode::G_ASHR ? AArch64::G_VASHR : AArch64::G_VLSHR;
554   MachineIRBuilder MIB(MI);
555   auto ImmDef = MIB.buildConstant(LLT::scalar(32), Imm);
556   MIB.buildInstr(NewOpc, {MI.getOperand(0)}, {MI.getOperand(1), ImmDef});
557   MI.eraseFromParent();
558 }
559 
560 /// Determine if it is possible to modify the \p RHS and predicate \p P of a
561 /// G_ICMP instruction such that the right-hand side is an arithmetic immediate.
562 ///
563 /// \returns A pair containing the updated immediate and predicate which may
564 /// be used to optimize the instruction.
565 ///
566 /// \note This assumes that the comparison has been legalized.
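///
/// e.g. (x ult 0xfff001) can become (x ule 0xfff000): 0xfff001 is not a legal
/// arithmetic immediate, but 0xfff000 is encodable as 0xfff shifted left by 12.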
567 std::optional<std::pair<uint64_t, CmpInst::Predicate>>
568 tryAdjustICmpImmAndPred(Register RHS, CmpInst::Predicate P,
569                         const MachineRegisterInfo &MRI) {
570   const auto &Ty = MRI.getType(RHS);
571   if (Ty.isVector())
572     return std::nullopt;
573   unsigned Size = Ty.getSizeInBits();
574   assert((Size == 32 || Size == 64) && "Expected 32 or 64 bit compare only?");
575 
576   // If the RHS is not a constant, or the RHS is already a valid arithmetic
577   // immediate, then there is nothing to change.
578   auto ValAndVReg = getIConstantVRegValWithLookThrough(RHS, MRI);
579   if (!ValAndVReg)
580     return std::nullopt;
581   uint64_t C = ValAndVReg->Value.getZExtValue();
582   if (isLegalArithImmed(C))
583     return std::nullopt;
584 
585   // We have a non-arithmetic immediate. Check if adjusting the immediate and
586   // adjusting the predicate will result in a legal arithmetic immediate.
587   switch (P) {
588   default:
589     return std::nullopt;
590   case CmpInst::ICMP_SLT:
591   case CmpInst::ICMP_SGE:
592     // Check for
593     //
594     // x slt c => x sle c - 1
595     // x sge c => x sgt c - 1
596     //
597     // When c is not the smallest possible negative number.
598     if ((Size == 64 && static_cast<int64_t>(C) == INT64_MIN) ||
599         (Size == 32 && static_cast<int32_t>(C) == INT32_MIN))
600       return std::nullopt;
601     P = (P == CmpInst::ICMP_SLT) ? CmpInst::ICMP_SLE : CmpInst::ICMP_SGT;
602     C -= 1;
603     break;
604   case CmpInst::ICMP_ULT:
605   case CmpInst::ICMP_UGE:
606     // Check for
607     //
608     // x ult c => x ule c - 1
609     // x uge c => x ugt c - 1
610     //
611     // When c is not zero.
612     if (C == 0)
613       return std::nullopt;
614     P = (P == CmpInst::ICMP_ULT) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
615     C -= 1;
616     break;
617   case CmpInst::ICMP_SLE:
618   case CmpInst::ICMP_SGT:
619     // Check for
620     //
621     // x sle c => x slt c + 1
622     // x sgt c => x sge c + 1
623     //
624     // When c is not the largest possible signed integer.
625     if ((Size == 32 && static_cast<int32_t>(C) == INT32_MAX) ||
626         (Size == 64 && static_cast<int64_t>(C) == INT64_MAX))
627       return std::nullopt;
628     P = (P == CmpInst::ICMP_SLE) ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGE;
629     C += 1;
630     break;
631   case CmpInst::ICMP_ULE:
632   case CmpInst::ICMP_UGT:
633     // Check for
634     //
635     // x ule c => x ult c + 1
636     // x ugt c => x uge c + 1
637     //
638     // When c is not the largest possible unsigned integer.
639     if ((Size == 32 && static_cast<uint32_t>(C) == UINT32_MAX) ||
640         (Size == 64 && C == UINT64_MAX))
641       return std::nullopt;
642     P = (P == CmpInst::ICMP_ULE) ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE;
643     C += 1;
644     break;
645   }
646 
647   // Check if the new constant is valid, and return the updated constant and
648   // predicate if it is.
649   if (Size == 32)
650     C = static_cast<uint32_t>(C);
651   if (!isLegalArithImmed(C))
652     return std::nullopt;
653   return {{C, P}};
654 }
655 
656 /// Determine whether or not it is possible to update the RHS and predicate of
657 /// a G_ICMP instruction such that the RHS will be selected as an arithmetic
658 /// immediate.
659 ///
660 /// \p MI - The G_ICMP instruction
661 /// \p MatchInfo - The new RHS immediate and predicate on success
662 ///
663 /// See tryAdjustICmpImmAndPred for valid transformations.
664 bool matchAdjustICmpImmAndPred(
665     MachineInstr &MI, const MachineRegisterInfo &MRI,
666     std::pair<uint64_t, CmpInst::Predicate> &MatchInfo) {
667   assert(MI.getOpcode() == TargetOpcode::G_ICMP);
668   Register RHS = MI.getOperand(3).getReg();
669   auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
670   if (auto MaybeNewImmAndPred = tryAdjustICmpImmAndPred(RHS, Pred, MRI)) {
671     MatchInfo = *MaybeNewImmAndPred;
672     return true;
673   }
674   return false;
675 }
676 
677 void applyAdjustICmpImmAndPred(
678     MachineInstr &MI, std::pair<uint64_t, CmpInst::Predicate> &MatchInfo,
679     MachineIRBuilder &MIB, GISelChangeObserver &Observer) {
680   MIB.setInstrAndDebugLoc(MI);
681   MachineOperand &RHS = MI.getOperand(3);
682   MachineRegisterInfo &MRI = *MIB.getMRI();
683   auto Cst = MIB.buildConstant(MRI.cloneVirtualRegister(RHS.getReg()),
684                                MatchInfo.first);
685   Observer.changingInstr(MI);
686   RHS.setReg(Cst->getOperand(0).getReg());
687   MI.getOperand(1).setPredicate(MatchInfo.second);
688   Observer.changedInstr(MI);
689 }
690 
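/// Match a G_SHUFFLE_VECTOR which broadcasts one lane of its first source,
/// e.g. a <4 x s32> shuffle with mask <1, 1, 1, 1>, so that it can be lowered
/// to a G_DUPLANE* pseudo of that lane.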
691 bool matchDupLane(MachineInstr &MI, MachineRegisterInfo &MRI,
692                   std::pair<unsigned, int> &MatchInfo) {
693   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
694   Register Src1Reg = MI.getOperand(1).getReg();
695   const LLT SrcTy = MRI.getType(Src1Reg);
696   const LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
697 
698   auto LaneIdx = getSplatIndex(MI);
699   if (!LaneIdx)
700     return false;
701 
702   // The lane idx should be within the first source vector.
703   if (*LaneIdx >= SrcTy.getNumElements())
704     return false;
705 
706   if (DstTy != SrcTy)
707     return false;
708 
709   LLT ScalarTy = SrcTy.getElementType();
710   unsigned ScalarSize = ScalarTy.getSizeInBits();
711 
712   unsigned Opc = 0;
713   switch (SrcTy.getNumElements()) {
714   case 2:
715     if (ScalarSize == 64)
716       Opc = AArch64::G_DUPLANE64;
717     else if (ScalarSize == 32)
718       Opc = AArch64::G_DUPLANE32;
719     break;
720   case 4:
721     if (ScalarSize == 32)
722       Opc = AArch64::G_DUPLANE32;
723     break;
724   case 8:
725     if (ScalarSize == 16)
726       Opc = AArch64::G_DUPLANE16;
727     break;
728   case 16:
729     if (ScalarSize == 8)
730       Opc = AArch64::G_DUPLANE8;
731     break;
732   default:
733     break;
734   }
735   if (!Opc)
736     return false;
737 
738   MatchInfo.first = Opc;
739   MatchInfo.second = *LaneIdx;
740   return true;
741 }
742 
743 void applyDupLane(MachineInstr &MI, MachineRegisterInfo &MRI,
744                   MachineIRBuilder &B, std::pair<unsigned, int> &MatchInfo) {
745   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
746   Register Src1Reg = MI.getOperand(1).getReg();
747   const LLT SrcTy = MRI.getType(Src1Reg);
748 
749   B.setInstrAndDebugLoc(MI);
750   auto Lane = B.buildConstant(LLT::scalar(64), MatchInfo.second);
751 
752   Register DupSrc = MI.getOperand(1).getReg();
753   // For types like <2 x s32>, we can use G_DUPLANE32, with a <4 x s32> source.
754   // To do this, we can use a G_CONCAT_VECTORS to do the widening.
755   if (SrcTy == LLT::fixed_vector(2, LLT::scalar(32))) {
756     assert(MRI.getType(MI.getOperand(0).getReg()).getNumElements() == 2 &&
757            "Unexpected dest elements");
758     auto Undef = B.buildUndef(SrcTy);
759     DupSrc = B.buildConcatVectors(
760                   SrcTy.changeElementCount(ElementCount::getFixed(4)),
761                   {Src1Reg, Undef.getReg(0)})
762                  .getReg(0);
763   }
764   B.buildInstr(MatchInfo.first, {MI.getOperand(0).getReg()}, {DupSrc, Lane});
765   MI.eraseFromParent();
766 }
767 
768 bool matchBuildVectorToDup(MachineInstr &MI, MachineRegisterInfo &MRI) {
769   assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
770   auto Splat = getAArch64VectorSplat(MI, MRI);
771   if (!Splat)
772     return false;
773   if (Splat->isReg())
774     return true;
775   // Later, during selection, we'll try to match imported patterns using
776   // immAllOnesV and immAllZerosV. These require G_BUILD_VECTOR. Don't lower
777   // G_BUILD_VECTORs which could match those patterns.
778   int64_t Cst = Splat->getCst();
779   return (Cst != 0 && Cst != -1);
780 }
781 
782 void applyBuildVectorToDup(MachineInstr &MI, MachineRegisterInfo &MRI,
783                            MachineIRBuilder &B) {
784   B.setInstrAndDebugLoc(MI);
785   B.buildInstr(AArch64::G_DUP, {MI.getOperand(0).getReg()},
786                {MI.getOperand(1).getReg()});
787   MI.eraseFromParent();
788 }
789 
790 /// \returns how many instructions would be saved by folding a G_ICMP's shift
791 /// and/or extension operations.
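/// e.g. if \p CmpOp is a single-use (G_SHL %x, 1), the shift can be folded
/// into the compare as "cmp wN, wM, lsl #1", saving one instruction.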
792 unsigned getCmpOperandFoldingProfit(Register CmpOp, MachineRegisterInfo &MRI) {
793   // No instructions to save if there's more than one use or no uses.
794   if (!MRI.hasOneNonDBGUse(CmpOp))
795     return 0;
796 
797   // FIXME: This is duplicated with the selector. (See: selectShiftedRegister)
798   auto IsSupportedExtend = [&](const MachineInstr &MI) {
799     if (MI.getOpcode() == TargetOpcode::G_SEXT_INREG)
800       return true;
801     if (MI.getOpcode() != TargetOpcode::G_AND)
802       return false;
803     auto ValAndVReg =
804         getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
805     if (!ValAndVReg)
806       return false;
807     uint64_t Mask = ValAndVReg->Value.getZExtValue();
808     return (Mask == 0xFF || Mask == 0xFFFF || Mask == 0xFFFFFFFF);
809   };
810 
811   MachineInstr *Def = getDefIgnoringCopies(CmpOp, MRI);
812   if (IsSupportedExtend(*Def))
813     return 1;
814 
815   unsigned Opc = Def->getOpcode();
816   if (Opc != TargetOpcode::G_SHL && Opc != TargetOpcode::G_ASHR &&
817       Opc != TargetOpcode::G_LSHR)
818     return 0;
819 
820   auto MaybeShiftAmt =
821       getIConstantVRegValWithLookThrough(Def->getOperand(2).getReg(), MRI);
822   if (!MaybeShiftAmt)
823     return 0;
824   uint64_t ShiftAmt = MaybeShiftAmt->Value.getZExtValue();
825   MachineInstr *ShiftLHS =
826       getDefIgnoringCopies(Def->getOperand(1).getReg(), MRI);
827 
828   // Check if we can fold an extend and a shift.
829   // FIXME: This is duplicated with the selector. (See:
830   // selectArithExtendedRegister)
831   if (IsSupportedExtend(*ShiftLHS))
832     return (ShiftAmt <= 4) ? 2 : 1;
833 
834   LLT Ty = MRI.getType(Def->getOperand(0).getReg());
835   if (Ty.isVector())
836     return 0;
837   unsigned ShiftSize = Ty.getSizeInBits();
838   if ((ShiftSize == 32 && ShiftAmt <= 31) ||
839       (ShiftSize == 64 && ShiftAmt <= 63))
840     return 1;
841   return 0;
842 }
843 
844 /// \returns true if it would be profitable to swap the LHS and RHS of a G_ICMP
845 /// instruction \p MI.
846 bool trySwapICmpOperands(MachineInstr &MI, MachineRegisterInfo &MRI) {
847   assert(MI.getOpcode() == TargetOpcode::G_ICMP);
848   // Swap the operands if it would introduce a profitable folding opportunity.
849   // (e.g. a shift + extend).
850   //
851   //  For example:
852   //    lsl     w13, w11, #1
853   //    cmp     w13, w12
854   // can be turned into:
855   //    cmp     w12, w11, lsl #1
856 
857   // Don't swap if there's a constant on the RHS, because we know we can fold
858   // that.
859   Register RHS = MI.getOperand(3).getReg();
860   auto RHSCst = getIConstantVRegValWithLookThrough(RHS, MRI);
861   if (RHSCst && isLegalArithImmed(RHSCst->Value.getSExtValue()))
862     return false;
863 
864   Register LHS = MI.getOperand(2).getReg();
865   auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
866   auto GetRegForProfit = [&](Register Reg) {
867     MachineInstr *Def = getDefIgnoringCopies(Reg, MRI);
868     return isCMN(Def, Pred, MRI) ? Def->getOperand(2).getReg() : Reg;
869   };
870 
871   // Don't have a constant on the RHS. If we swap the LHS and RHS of the
872   // compare, would we be able to fold more instructions?
873   Register TheLHS = GetRegForProfit(LHS);
874   Register TheRHS = GetRegForProfit(RHS);
875 
876   // If the LHS is more likely to give us a folding opportunity, then swap the
877   // LHS and RHS.
878   return (getCmpOperandFoldingProfit(TheLHS, MRI) >
879           getCmpOperandFoldingProfit(TheRHS, MRI));
880 }
881 
882 void applySwapICmpOperands(MachineInstr &MI, GISelChangeObserver &Observer) {
883   auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
884   Register LHS = MI.getOperand(2).getReg();
885   Register RHS = MI.getOperand(3).getReg();
886   Observer.changingInstr(MI);
887   MI.getOperand(1).setPredicate(CmpInst::getSwappedPredicate(Pred));
888   MI.getOperand(2).setReg(RHS);
889   MI.getOperand(3).setReg(LHS);
890   Observer.changedInstr(MI);
891 }
892 
893 /// \returns a function which builds a vector floating point compare instruction
894 /// for a condition code \p CC.
895 /// \param [in] IsZero - True if the comparison is against 0.
896 /// \param [in] NoNans - True if the target has NoNansFPMath.
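/// e.g. for AArch64CC::GE against a non-zero RHS this returns a function that
/// builds a single G_FCMGE; the LS and MI cases are handled by swapping the
/// operands of G_FCMGE and G_FCMGT respectively.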
897 std::function<Register(MachineIRBuilder &)>
898 getVectorFCMP(AArch64CC::CondCode CC, Register LHS, Register RHS, bool IsZero,
899               bool NoNans, MachineRegisterInfo &MRI) {
900   LLT DstTy = MRI.getType(LHS);
901   assert(DstTy.isVector() && "Expected vector types only?");
902   assert(DstTy == MRI.getType(RHS) && "Src and Dst types must match!");
903   switch (CC) {
904   default:
905     llvm_unreachable("Unexpected condition code!");
906   case AArch64CC::NE:
907     return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
908       auto FCmp = IsZero
909                       ? MIB.buildInstr(AArch64::G_FCMEQZ, {DstTy}, {LHS})
910                       : MIB.buildInstr(AArch64::G_FCMEQ, {DstTy}, {LHS, RHS});
911       return MIB.buildNot(DstTy, FCmp).getReg(0);
912     };
913   case AArch64CC::EQ:
914     return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
915       return IsZero
916                  ? MIB.buildInstr(AArch64::G_FCMEQZ, {DstTy}, {LHS}).getReg(0)
917                  : MIB.buildInstr(AArch64::G_FCMEQ, {DstTy}, {LHS, RHS})
918                        .getReg(0);
919     };
920   case AArch64CC::GE:
921     return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
922       return IsZero
923                  ? MIB.buildInstr(AArch64::G_FCMGEZ, {DstTy}, {LHS}).getReg(0)
924                  : MIB.buildInstr(AArch64::G_FCMGE, {DstTy}, {LHS, RHS})
925                        .getReg(0);
926     };
927   case AArch64CC::GT:
928     return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
929       return IsZero
930                  ? MIB.buildInstr(AArch64::G_FCMGTZ, {DstTy}, {LHS}).getReg(0)
931                  : MIB.buildInstr(AArch64::G_FCMGT, {DstTy}, {LHS, RHS})
932                        .getReg(0);
933     };
934   case AArch64CC::LS:
935     return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
936       return IsZero
937                  ? MIB.buildInstr(AArch64::G_FCMLEZ, {DstTy}, {LHS}).getReg(0)
938                  : MIB.buildInstr(AArch64::G_FCMGE, {DstTy}, {RHS, LHS})
939                        .getReg(0);
940     };
941   case AArch64CC::MI:
942     return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
943       return IsZero
944                  ? MIB.buildInstr(AArch64::G_FCMLTZ, {DstTy}, {LHS}).getReg(0)
945                  : MIB.buildInstr(AArch64::G_FCMGT, {DstTy}, {RHS, LHS})
946                        .getReg(0);
947     };
948   }
949 }
950 
951 /// Check whether a vector G_FCMP \p MI can be lowered to an AArch64 pseudo.
952 bool matchLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
953                           MachineIRBuilder &MIB) {
954   assert(MI.getOpcode() == TargetOpcode::G_FCMP);
955   const auto &ST = MI.getMF()->getSubtarget<AArch64Subtarget>();
956 
957   Register Dst = MI.getOperand(0).getReg();
958   LLT DstTy = MRI.getType(Dst);
959   if (!DstTy.isVector() || !ST.hasNEON())
960     return false;
961   Register LHS = MI.getOperand(2).getReg();
962   unsigned EltSize = MRI.getType(LHS).getScalarSizeInBits();
963   if (EltSize == 16 && !ST.hasFullFP16())
964     return false;
965   if (EltSize != 16 && EltSize != 32 && EltSize != 64)
966     return false;
967 
968   return true;
969 }
970 
971 /// Try to lower a vector G_FCMP \p MI into an AArch64-specific pseudo.
972 void applyLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
973                           MachineIRBuilder &MIB) {
974   assert(MI.getOpcode() == TargetOpcode::G_FCMP);
975   const auto &ST = MI.getMF()->getSubtarget<AArch64Subtarget>();
976 
977   const auto &CmpMI = cast<GFCmp>(MI);
978 
979   Register Dst = CmpMI.getReg(0);
980   CmpInst::Predicate Pred = CmpMI.getCond();
981   Register LHS = CmpMI.getLHSReg();
982   Register RHS = CmpMI.getRHSReg();
983 
984   LLT DstTy = MRI.getType(Dst);
985 
986   auto Splat = getAArch64VectorSplat(*MRI.getVRegDef(RHS), MRI);
987 
988   // Compares against 0 have special target-specific pseudos.
989   bool IsZero = Splat && Splat->isCst() && Splat->getCst() == 0;
990 
991   bool Invert = false;
992   AArch64CC::CondCode CC, CC2 = AArch64CC::AL;
993   if (Pred == CmpInst::Predicate::FCMP_ORD && IsZero) {
994     // The special case "fcmp ord %a, 0" is the canonical check that LHS isn't
995     // NaN, so equivalent to a == a and doesn't need the two comparisons an
996     // "ord" normally would.
997     RHS = LHS;
998     IsZero = false;
999     CC = AArch64CC::EQ;
1000   } else
1001     changeVectorFCMPPredToAArch64CC(Pred, CC, CC2, Invert);
1002 
1003   // Build the replacement sequence directly here to keep things simple.
1004   MIB.setInstrAndDebugLoc(MI);
1005 
1006   const bool NoNans =
1007       ST.getTargetLowering()->getTargetMachine().Options.NoNaNsFPMath;
1008 
1009   auto Cmp = getVectorFCMP(CC, LHS, RHS, IsZero, NoNans, MRI);
1010   Register CmpRes;
1011   if (CC2 == AArch64CC::AL)
1012     CmpRes = Cmp(MIB);
1013   else {
1014     auto Cmp2 = getVectorFCMP(CC2, LHS, RHS, IsZero, NoNans, MRI);
1015     auto Cmp2Dst = Cmp2(MIB);
1016     auto Cmp1Dst = Cmp(MIB);
1017     CmpRes = MIB.buildOr(DstTy, Cmp1Dst, Cmp2Dst).getReg(0);
1018   }
1019   if (Invert)
1020     CmpRes = MIB.buildNot(DstTy, CmpRes).getReg(0);
1021   MRI.replaceRegWith(Dst, CmpRes);
1022   MI.eraseFromParent();
1023 }
1024 
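/// Match a G_STORE of a G_TRUNC so the store can be rewritten to store the
/// wider source register directly; the original (narrower) memory operand is
/// kept, turning it into a truncating store.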
1025 bool matchFormTruncstore(MachineInstr &MI, MachineRegisterInfo &MRI,
1026                          Register &SrcReg) {
1027   assert(MI.getOpcode() == TargetOpcode::G_STORE);
1028   Register DstReg = MI.getOperand(0).getReg();
1029   if (MRI.getType(DstReg).isVector())
1030     return false;
1031   // Match a store of a truncate.
1032   if (!mi_match(DstReg, MRI, m_GTrunc(m_Reg(SrcReg))))
1033     return false;
1034   // Only form truncstores for value types of max 64b.
1035   return MRI.getType(SrcReg).getSizeInBits() <= 64;
1036 }
1037 
1038 void applyFormTruncstore(MachineInstr &MI, MachineRegisterInfo &MRI,
1039                          MachineIRBuilder &B, GISelChangeObserver &Observer,
1040                          Register &SrcReg) {
1041   assert(MI.getOpcode() == TargetOpcode::G_STORE);
1042   Observer.changingInstr(MI);
1043   MI.getOperand(0).setReg(SrcReg);
1044   Observer.changedInstr(MI);
1045 }
1046 
1047 // Lower vector G_SEXT_INREG back to shifts for selection. We allowed them to
1048 // form in the first place for combine opportunities, so any remaining ones
1049 // at this stage need to be lowered back.
1050 bool matchVectorSextInReg(MachineInstr &MI, MachineRegisterInfo &MRI) {
1051   assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1052   Register DstReg = MI.getOperand(0).getReg();
1053   LLT DstTy = MRI.getType(DstReg);
1054   return DstTy.isVector();
1055 }
1056 
1057 void applyVectorSextInReg(MachineInstr &MI, MachineRegisterInfo &MRI,
1058                           MachineIRBuilder &B, GISelChangeObserver &Observer) {
1059   assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1060   B.setInstrAndDebugLoc(MI);
1061   LegalizerHelper Helper(*MI.getMF(), Observer, B);
1062   Helper.lower(MI, 0, /* Unused hint type */ LLT());
1063 }
1064 
1065 class AArch64PostLegalizerLoweringImpl : public GIMatchTableExecutor {
1066 protected:
1067   CombinerHelper &Helper;
1068   const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig;
1069 
1070   const AArch64Subtarget &STI;
1071   GISelChangeObserver &Observer;
1072   MachineIRBuilder &B;
1073   MachineFunction &MF;
1074 
1075   MachineRegisterInfo &MRI;
1076 
1077 public:
1078   AArch64PostLegalizerLoweringImpl(
1079       const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig,
1080       const AArch64Subtarget &STI, GISelChangeObserver &Observer,
1081       MachineIRBuilder &B, CombinerHelper &Helper);
1082 
1083   static const char *getName() { return "AArch64PostLegalizerLowering"; }
1084 
1085   bool tryCombineAll(MachineInstr &I) const;
1086 
1087 private:
1088 #define GET_GICOMBINER_CLASS_MEMBERS
1089 #include "AArch64GenPostLegalizeGILowering.inc"
1090 #undef GET_GICOMBINER_CLASS_MEMBERS
1091 };
1092 
1093 #define GET_GICOMBINER_IMPL
1094 #include "AArch64GenPostLegalizeGILowering.inc"
1095 #undef GET_GICOMBINER_IMPL
1096 
1097 AArch64PostLegalizerLoweringImpl::AArch64PostLegalizerLoweringImpl(
1098     const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig,
1099     const AArch64Subtarget &STI, GISelChangeObserver &Observer,
1100     MachineIRBuilder &B, CombinerHelper &Helper)
1101     : Helper(Helper), RuleConfig(RuleConfig), STI(STI), Observer(Observer),
1102       B(B), MF(B.getMF()), MRI(*B.getMRI()),
1103 #define GET_GICOMBINER_CONSTRUCTOR_INITS
1104 #include "AArch64GenPostLegalizeGILowering.inc"
1105 #undef GET_GICOMBINER_CONSTRUCTOR_INITS
1106 {
1107 }
1108 
1109 class AArch64PostLegalizerLoweringInfo : public CombinerInfo {
1110 public:
1111   AArch64PostLegalizerLoweringImplRuleConfig RuleConfig;
1112 
1113   AArch64PostLegalizerLoweringInfo(bool OptSize, bool MinSize)
1114       : CombinerInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
1115                      /*LegalizerInfo*/ nullptr, /*OptEnabled = */ true, OptSize,
1116                      MinSize) {
1117     if (!RuleConfig.parseCommandLineOption())
1118       report_fatal_error("Invalid rule identifier");
1119   }
1120 
1121   bool combine(GISelChangeObserver &Observer, MachineInstr &MI,
1122                MachineIRBuilder &B) const override;
1123 };
1124 
1125 bool AArch64PostLegalizerLoweringInfo::combine(GISelChangeObserver &Observer,
1126                                                MachineInstr &MI,
1127                                                MachineIRBuilder &B) const {
1128   const auto &STI = MI.getMF()->getSubtarget<AArch64Subtarget>();
1129   CombinerHelper Helper(Observer, B, /* IsPreLegalize*/ false);
1130   AArch64PostLegalizerLoweringImpl Impl(RuleConfig, STI, Observer, B, Helper);
1131   Impl.setupMF(*MI.getMF(), Helper.getKnownBits());
1132   return Impl.tryCombineAll(MI);
1133 }
1134 class AArch64PostLegalizerLowering : public MachineFunctionPass {
1135 public:
1136   static char ID;
1137 
1138   AArch64PostLegalizerLowering();
1139 
1140   StringRef getPassName() const override {
1141     return "AArch64PostLegalizerLowering";
1142   }
1143 
1144   bool runOnMachineFunction(MachineFunction &MF) override;
1145   void getAnalysisUsage(AnalysisUsage &AU) const override;
1146 };
1147 } // end anonymous namespace
1148 
1149 void AArch64PostLegalizerLowering::getAnalysisUsage(AnalysisUsage &AU) const {
1150   AU.addRequired<TargetPassConfig>();
1151   AU.setPreservesCFG();
1152   getSelectionDAGFallbackAnalysisUsage(AU);
1153   MachineFunctionPass::getAnalysisUsage(AU);
1154 }
1155 
1156 AArch64PostLegalizerLowering::AArch64PostLegalizerLowering()
1157     : MachineFunctionPass(ID) {
1158   initializeAArch64PostLegalizerLoweringPass(*PassRegistry::getPassRegistry());
1159 }
1160 
1161 bool AArch64PostLegalizerLowering::runOnMachineFunction(MachineFunction &MF) {
1162   if (MF.getProperties().hasProperty(
1163           MachineFunctionProperties::Property::FailedISel))
1164     return false;
1165   assert(MF.getProperties().hasProperty(
1166              MachineFunctionProperties::Property::Legalized) &&
1167          "Expected a legalized function?");
1168   auto *TPC = &getAnalysis<TargetPassConfig>();
1169   const Function &F = MF.getFunction();
1170   AArch64PostLegalizerLoweringInfo PCInfo(F.hasOptSize(), F.hasMinSize());
1171   Combiner C(PCInfo, TPC);
1172   return C.combineMachineInstrs(MF, /*CSEInfo*/ nullptr);
1173 }
1174 
1175 char AArch64PostLegalizerLowering::ID = 0;
1176 INITIALIZE_PASS_BEGIN(AArch64PostLegalizerLowering, DEBUG_TYPE,
1177                       "Lower AArch64 MachineInstrs after legalization", false,
1178                       false)
1179 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1180 INITIALIZE_PASS_END(AArch64PostLegalizerLowering, DEBUG_TYPE,
1181                     "Lower AArch64 MachineInstrs after legalization", false,
1182                     false)
1183 
1184 namespace llvm {
1185 FunctionPass *createAArch64PostLegalizerLowering() {
1186   return new AArch64PostLegalizerLowering();
1187 }
1188 } // end namespace llvm
1189