//===-- llvm/CodeGen/GlobalISel/LegalizerHelper.cpp -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This file implements the LegalizerHelper class to legalize
/// individual instructions and the LegalizeMachineIR wrapper pass for the
/// primary legalization.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/CallLowering.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
#include "llvm/CodeGen/GlobalISel/LostDebugLocObserver.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineConstantPool.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RuntimeLibcalls.h"
#include "llvm/CodeGen/TargetFrameLowering.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include <numeric>
#include <optional>

#define DEBUG_TYPE "legalizer"

using namespace llvm;
using namespace LegalizeActions;
using namespace MIPatternMatch;

/// Try to break down \p OrigTy into \p NarrowTy sized pieces.
///
/// Returns the number of \p NarrowTy elements needed to reconstruct \p OrigTy,
/// with any leftover piece as type \p LeftoverTy
///
/// Returns -1 in the first element of the pair if the breakdown is not
/// satisfiable.
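///
/// For example (illustrative, following the computation below): breaking an
/// s64 \p OrigTy into s32 pieces yields {2, 0}, while breaking a v3s32
/// \p OrigTy into v2s32 pieces yields {1, 1} with \p LeftoverTy set to s32.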
static std::pair<int, int>
getNarrowTypeBreakDown(LLT OrigTy, LLT NarrowTy, LLT &LeftoverTy) {
  assert(!LeftoverTy.isValid() && "this is an out argument");

  unsigned Size = OrigTy.getSizeInBits();
  unsigned NarrowSize = NarrowTy.getSizeInBits();
  unsigned NumParts = Size / NarrowSize;
  unsigned LeftoverSize = Size - NumParts * NarrowSize;
  assert(Size > NarrowSize);

  if (LeftoverSize == 0)
    return {NumParts, 0};

  if (NarrowTy.isVector()) {
    unsigned EltSize = OrigTy.getScalarSizeInBits();
    if (LeftoverSize % EltSize != 0)
      return {-1, -1};
    LeftoverTy = LLT::scalarOrVector(
        ElementCount::getFixed(LeftoverSize / EltSize), EltSize);
  } else {
    LeftoverTy = LLT::scalar(LeftoverSize);
  }

  int NumLeftover = LeftoverSize / LeftoverTy.getSizeInBits();
  return std::make_pair(NumParts, NumLeftover);
}

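/// Map a scalar LLT to the IR floating-point type of the same width
/// (s16 -> half, s32 -> float, s64 -> double, s80 -> x86_fp80,
/// s128 -> fp128), or return nullptr if there is no such type.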
static Type *getFloatTypeForLLT(LLVMContext &Ctx, LLT Ty) {
  if (!Ty.isScalar())
    return nullptr;

  switch (Ty.getSizeInBits()) {
  case 16:
    return Type::getHalfTy(Ctx);
  case 32:
    return Type::getFloatTy(Ctx);
  case 64:
    return Type::getDoubleTy(Ctx);
  case 80:
    return Type::getX86_FP80Ty(Ctx);
  case 128:
    return Type::getFP128Ty(Ctx);
  default:
    return nullptr;
  }
}

LegalizerHelper::LegalizerHelper(MachineFunction &MF,
                                 GISelChangeObserver &Observer,
                                 MachineIRBuilder &Builder)
    : MIRBuilder(Builder), Observer(Observer), MRI(MF.getRegInfo()),
      LI(*MF.getSubtarget().getLegalizerInfo()),
      TLI(*MF.getSubtarget().getTargetLowering()), KB(nullptr) {}

LegalizerHelper::LegalizerHelper(MachineFunction &MF, const LegalizerInfo &LI,
                                 GISelChangeObserver &Observer,
                                 MachineIRBuilder &B, GISelKnownBits *KB)
    : MIRBuilder(B), Observer(Observer), MRI(MF.getRegInfo()), LI(LI),
      TLI(*MF.getSubtarget().getTargetLowering()), KB(KB) {}

LegalizerHelper::LegalizeResult
LegalizerHelper::legalizeInstrStep(MachineInstr &MI,
                                   LostDebugLocObserver &LocObserver) {
  LLVM_DEBUG(dbgs() << "Legalizing: " << MI);

  MIRBuilder.setInstrAndDebugLoc(MI);

  if (isa<GIntrinsic>(MI))
    return LI.legalizeIntrinsic(*this, MI) ? Legalized : UnableToLegalize;
  auto Step = LI.getAction(MI, MRI);
  switch (Step.Action) {
  case Legal:
    LLVM_DEBUG(dbgs() << ".. Already legal\n");
    return AlreadyLegal;
  case Libcall:
    LLVM_DEBUG(dbgs() << ".. Convert to libcall\n");
    return libcall(MI, LocObserver);
  case NarrowScalar:
    LLVM_DEBUG(dbgs() << ".. Narrow scalar\n");
    return narrowScalar(MI, Step.TypeIdx, Step.NewType);
  case WidenScalar:
    LLVM_DEBUG(dbgs() << ".. Widen scalar\n");
    return widenScalar(MI, Step.TypeIdx, Step.NewType);
  case Bitcast:
    LLVM_DEBUG(dbgs() << ".. Bitcast type\n");
    return bitcast(MI, Step.TypeIdx, Step.NewType);
  case Lower:
    LLVM_DEBUG(dbgs() << ".. Lower\n");
    return lower(MI, Step.TypeIdx, Step.NewType);
  case FewerElements:
    LLVM_DEBUG(dbgs() << ".. Reduce number of elements\n");
    return fewerElementsVector(MI, Step.TypeIdx, Step.NewType);
  case MoreElements:
    LLVM_DEBUG(dbgs() << ".. Increase number of elements\n");
    return moreElementsVector(MI, Step.TypeIdx, Step.NewType);
  case Custom:
    LLVM_DEBUG(dbgs() << ".. Custom legalization\n");
    return LI.legalizeCustom(*this, MI, LocObserver) ? Legalized
                                                     : UnableToLegalize;
  default:
    LLVM_DEBUG(dbgs() << ".. Unable to legalize\n");
    return UnableToLegalize;
  }
}

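// Illustrative example: inserting PartRegs = {%a:s64} with LeftoverRegs =
// {%b:s32} into an s96 result reduces to splitting both values into s32 GCD
// pieces and re-merging three s32 values into the destination.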
void LegalizerHelper::insertParts(Register DstReg,
                                  LLT ResultTy, LLT PartTy,
                                  ArrayRef<Register> PartRegs,
                                  LLT LeftoverTy,
                                  ArrayRef<Register> LeftoverRegs) {
  if (!LeftoverTy.isValid()) {
    assert(LeftoverRegs.empty());

    if (!ResultTy.isVector()) {
      MIRBuilder.buildMergeLikeInstr(DstReg, PartRegs);
      return;
    }

    if (PartTy.isVector())
      MIRBuilder.buildConcatVectors(DstReg, PartRegs);
    else
      MIRBuilder.buildBuildVector(DstReg, PartRegs);
    return;
  }

  // Merge sub-vectors with different numbers of elements and insert into
  // DstReg.
  if (ResultTy.isVector()) {
    assert(LeftoverRegs.size() == 1 && "Expected one leftover register");
    SmallVector<Register, 8> AllRegs;
    for (auto Reg : concat<const Register>(PartRegs, LeftoverRegs))
      AllRegs.push_back(Reg);
    return mergeMixedSubvectors(DstReg, AllRegs);
  }

  SmallVector<Register> GCDRegs;
  LLT GCDTy = getGCDType(getGCDType(ResultTy, LeftoverTy), PartTy);
  for (auto PartReg : concat<const Register>(PartRegs, LeftoverRegs))
    extractGCDType(GCDRegs, GCDTy, PartReg);
  LLT ResultLCMTy = buildLCMMergePieces(ResultTy, LeftoverTy, GCDTy, GCDRegs);
  buildWidenedRemergeToDst(DstReg, ResultLCMTy, GCDRegs);
}

void LegalizerHelper::appendVectorElts(SmallVectorImpl<Register> &Elts,
                                       Register Reg) {
  LLT Ty = MRI.getType(Reg);
  SmallVector<Register, 8> RegElts;
  extractParts(Reg, Ty.getScalarType(), Ty.getNumElements(), RegElts,
               MIRBuilder, MRI);
  Elts.append(RegElts);
}

/// Merge \p PartRegs with different types into \p DstReg.
void LegalizerHelper::mergeMixedSubvectors(Register DstReg,
                                           ArrayRef<Register> PartRegs) {
  SmallVector<Register, 8> AllElts;
  for (unsigned i = 0; i < PartRegs.size() - 1; ++i)
    appendVectorElts(AllElts, PartRegs[i]);

  Register Leftover = PartRegs[PartRegs.size() - 1];
  if (MRI.getType(Leftover).isScalar())
    AllElts.push_back(Leftover);
  else
    appendVectorElts(AllElts, Leftover);

  MIRBuilder.buildMergeLikeInstr(DstReg, AllElts);
}

/// Append the result registers of G_UNMERGE_VALUES \p MI to \p Regs.
static void getUnmergeResults(SmallVectorImpl<Register> &Regs,
                              const MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES);

  const int StartIdx = Regs.size();
  const int NumResults = MI.getNumOperands() - 1;
  Regs.resize(Regs.size() + NumResults);
  for (int I = 0; I != NumResults; ++I)
    Regs[StartIdx + I] = MI.getOperand(I).getReg();
}

void LegalizerHelper::extractGCDType(SmallVectorImpl<Register> &Parts,
                                     LLT GCDTy, Register SrcReg) {
  LLT SrcTy = MRI.getType(SrcReg);
  if (SrcTy == GCDTy) {
    // If the source already evenly divides the result type, we don't need to do
    // anything.
    Parts.push_back(SrcReg);
  } else {
    // Need to split into common type sized pieces.
    auto Unmerge = MIRBuilder.buildUnmerge(GCDTy, SrcReg);
    getUnmergeResults(Parts, *Unmerge);
  }
}

LLT LegalizerHelper::extractGCDType(SmallVectorImpl<Register> &Parts, LLT DstTy,
                                    LLT NarrowTy, Register SrcReg) {
  LLT SrcTy = MRI.getType(SrcReg);
  LLT GCDTy = getGCDType(getGCDType(SrcTy, NarrowTy), DstTy);
  extractGCDType(Parts, GCDTy, SrcReg);
  return GCDTy;
}

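// Worked example (illustrative): with DstTy = s64, NarrowTy = s48 and
// GCDTy = s16, VRegs holds four s16 pieces. LCMTy is s192, so NumParts = 4
// and NumSubParts = 3; the twelve required s16 sub-pieces are the four
// sources plus eight pad values, producing four s48 merges, with the
// all-padding piece materialized once and then reused.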
LLT LegalizerHelper::buildLCMMergePieces(LLT DstTy, LLT NarrowTy, LLT GCDTy,
                                         SmallVectorImpl<Register> &VRegs,
                                         unsigned PadStrategy) {
  LLT LCMTy = getLCMType(DstTy, NarrowTy);

  int NumParts = LCMTy.getSizeInBits() / NarrowTy.getSizeInBits();
  int NumSubParts = NarrowTy.getSizeInBits() / GCDTy.getSizeInBits();
  int NumOrigSrc = VRegs.size();

  Register PadReg;

  // Get a value we can use to pad the source value if the sources won't evenly
  // cover the result type.
  if (NumOrigSrc < NumParts * NumSubParts) {
    if (PadStrategy == TargetOpcode::G_ZEXT)
      PadReg = MIRBuilder.buildConstant(GCDTy, 0).getReg(0);
    else if (PadStrategy == TargetOpcode::G_ANYEXT)
      PadReg = MIRBuilder.buildUndef(GCDTy).getReg(0);
    else {
      assert(PadStrategy == TargetOpcode::G_SEXT);

      // Shift the sign bit of the low register through the high register.
      auto ShiftAmt =
        MIRBuilder.buildConstant(LLT::scalar(64), GCDTy.getSizeInBits() - 1);
      PadReg = MIRBuilder.buildAShr(GCDTy, VRegs.back(), ShiftAmt).getReg(0);
    }
  }

  // Registers for the final merge to be produced.
  SmallVector<Register, 4> Remerge(NumParts);

  // Registers needed for intermediate merges, which will be merged into a
  // source for Remerge.
  SmallVector<Register, 4> SubMerge(NumSubParts);

  // Once we've fully read off the end of the original source bits, we can reuse
  // the same high bits for remaining padding elements.
  Register AllPadReg;

  // Build merges to the LCM type to cover the original result type.
  for (int I = 0; I != NumParts; ++I) {
    bool AllMergePartsArePadding = true;

    // Build the requested merges to the requested type.
    for (int J = 0; J != NumSubParts; ++J) {
      int Idx = I * NumSubParts + J;
      if (Idx >= NumOrigSrc) {
        SubMerge[J] = PadReg;
        continue;
      }

      SubMerge[J] = VRegs[Idx];

      // There are meaningful bits here we can't reuse later.
      AllMergePartsArePadding = false;
    }

    // If we've filled up a complete piece with padding bits, we can directly
    // emit the natural sized constant if applicable, rather than a merge of
    // smaller constants.
    if (AllMergePartsArePadding && !AllPadReg) {
      if (PadStrategy == TargetOpcode::G_ANYEXT)
        AllPadReg = MIRBuilder.buildUndef(NarrowTy).getReg(0);
      else if (PadStrategy == TargetOpcode::G_ZEXT)
        AllPadReg = MIRBuilder.buildConstant(NarrowTy, 0).getReg(0);

      // If this is a sign extension, we can't materialize a trivial constant
      // with the right type and have to produce a merge.
    }

    if (AllPadReg) {
      // Avoid creating additional instructions if we're just adding additional
      // copies of padding bits.
      Remerge[I] = AllPadReg;
      continue;
    }

    if (NumSubParts == 1)
      Remerge[I] = SubMerge[0];
    else
      Remerge[I] = MIRBuilder.buildMergeLikeInstr(NarrowTy, SubMerge).getReg(0);

    // In the sign extend padding case, re-use the first all-signbit merge.
    if (AllMergePartsArePadding && !AllPadReg)
      AllPadReg = Remerge[I];
  }

  VRegs = std::move(Remerge);
  return LCMTy;
}

void LegalizerHelper::buildWidenedRemergeToDst(Register DstReg, LLT LCMTy,
                                               ArrayRef<Register> RemergeRegs) {
  LLT DstTy = MRI.getType(DstReg);

  // Create the merge to the widened source, and extract the relevant bits into
  // the result.

  if (DstTy == LCMTy) {
    MIRBuilder.buildMergeLikeInstr(DstReg, RemergeRegs);
    return;
  }

  auto Remerge = MIRBuilder.buildMergeLikeInstr(LCMTy, RemergeRegs);
  if (DstTy.isScalar() && LCMTy.isScalar()) {
    MIRBuilder.buildTrunc(DstReg, Remerge);
    return;
  }

  if (LCMTy.isVector()) {
    unsigned NumDefs = LCMTy.getSizeInBits() / DstTy.getSizeInBits();
    SmallVector<Register, 8> UnmergeDefs(NumDefs);
    UnmergeDefs[0] = DstReg;
    for (unsigned I = 1; I != NumDefs; ++I)
      UnmergeDefs[I] = MRI.createGenericVirtualRegister(DstTy);

    MIRBuilder.buildUnmerge(UnmergeDefs,
                            MIRBuilder.buildMergeLikeInstr(LCMTy, RemergeRegs));
    return;
  }

  llvm_unreachable("unhandled case");
}

static RTLIB::Libcall getRTLibDesc(unsigned Opcode, unsigned Size) {
#define RTLIBCASE_INT(LibcallPrefix)                                           \
  do {                                                                         \
    switch (Size) {                                                            \
    case 32:                                                                   \
      return RTLIB::LibcallPrefix##32;                                         \
    case 64:                                                                   \
      return RTLIB::LibcallPrefix##64;                                         \
    case 128:                                                                  \
      return RTLIB::LibcallPrefix##128;                                        \
    default:                                                                   \
      llvm_unreachable("unexpected size");                                     \
    }                                                                          \
  } while (0)

#define RTLIBCASE(LibcallPrefix)                                               \
  do {                                                                         \
    switch (Size) {                                                            \
    case 32:                                                                   \
      return RTLIB::LibcallPrefix##32;                                         \
    case 64:                                                                   \
      return RTLIB::LibcallPrefix##64;                                         \
    case 80:                                                                   \
      return RTLIB::LibcallPrefix##80;                                         \
    case 128:                                                                  \
      return RTLIB::LibcallPrefix##128;                                        \
    default:                                                                   \
      llvm_unreachable("unexpected size");                                     \
    }                                                                          \
  } while (0)

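  // For instance, RTLIBCASE(ADD_F) below expands to a switch over Size that
  // returns RTLIB::ADD_F32, ADD_F64, ADD_F80 or ADD_F128; the _INT variant
  // omits the 80-bit case.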
  switch (Opcode) {
  case TargetOpcode::G_MUL:
    RTLIBCASE_INT(MUL_I);
  case TargetOpcode::G_SDIV:
    RTLIBCASE_INT(SDIV_I);
  case TargetOpcode::G_UDIV:
    RTLIBCASE_INT(UDIV_I);
  case TargetOpcode::G_SREM:
    RTLIBCASE_INT(SREM_I);
  case TargetOpcode::G_UREM:
    RTLIBCASE_INT(UREM_I);
  case TargetOpcode::G_CTLZ_ZERO_UNDEF:
    RTLIBCASE_INT(CTLZ_I);
  case TargetOpcode::G_FADD:
    RTLIBCASE(ADD_F);
  case TargetOpcode::G_FSUB:
    RTLIBCASE(SUB_F);
  case TargetOpcode::G_FMUL:
    RTLIBCASE(MUL_F);
  case TargetOpcode::G_FDIV:
    RTLIBCASE(DIV_F);
  case TargetOpcode::G_FEXP:
    RTLIBCASE(EXP_F);
  case TargetOpcode::G_FEXP2:
    RTLIBCASE(EXP2_F);
  case TargetOpcode::G_FEXP10:
    RTLIBCASE(EXP10_F);
  case TargetOpcode::G_FREM:
    RTLIBCASE(REM_F);
  case TargetOpcode::G_FPOW:
    RTLIBCASE(POW_F);
  case TargetOpcode::G_FPOWI:
    RTLIBCASE(POWI_F);
  case TargetOpcode::G_FMA:
    RTLIBCASE(FMA_F);
  case TargetOpcode::G_FSIN:
    RTLIBCASE(SIN_F);
  case TargetOpcode::G_FCOS:
    RTLIBCASE(COS_F);
  case TargetOpcode::G_FLOG10:
    RTLIBCASE(LOG10_F);
  case TargetOpcode::G_FLOG:
    RTLIBCASE(LOG_F);
  case TargetOpcode::G_FLOG2:
    RTLIBCASE(LOG2_F);
  case TargetOpcode::G_FLDEXP:
    RTLIBCASE(LDEXP_F);
  case TargetOpcode::G_FCEIL:
    RTLIBCASE(CEIL_F);
  case TargetOpcode::G_FFLOOR:
    RTLIBCASE(FLOOR_F);
  case TargetOpcode::G_FMINNUM:
    RTLIBCASE(FMIN_F);
  case TargetOpcode::G_FMAXNUM:
    RTLIBCASE(FMAX_F);
  case TargetOpcode::G_FSQRT:
    RTLIBCASE(SQRT_F);
  case TargetOpcode::G_FRINT:
    RTLIBCASE(RINT_F);
  case TargetOpcode::G_FNEARBYINT:
    RTLIBCASE(NEARBYINT_F);
  case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
    RTLIBCASE(ROUNDEVEN_F);
  }
  llvm_unreachable("Unknown libcall function");
}

/// True if an instruction is in tail position in its caller. Intended for
/// legalizing libcalls as tail calls when possible.
static bool isLibCallInTailPosition(const CallLowering::ArgInfo &Result,
                                    MachineInstr &MI,
                                    const TargetInstrInfo &TII,
                                    MachineRegisterInfo &MRI) {
  MachineBasicBlock &MBB = *MI.getParent();
  const Function &F = MBB.getParent()->getFunction();

  // Conservatively require the attributes of the call to match those of
  // the return. Ignore NoAlias and NonNull because they don't affect the
  // call sequence.
  AttributeList CallerAttrs = F.getAttributes();
  if (AttrBuilder(F.getContext(), CallerAttrs.getRetAttrs())
          .removeAttribute(Attribute::NoAlias)
          .removeAttribute(Attribute::NonNull)
          .hasAttributes())
    return false;

  // It's not safe to eliminate the sign / zero extension of the return value.
  if (CallerAttrs.hasRetAttr(Attribute::ZExt) ||
      CallerAttrs.hasRetAttr(Attribute::SExt))
    return false;

  // Only tail call if the following instruction is a standard return or if we
  // have a `thisreturn` callee, and a sequence like:
  //
  //   G_MEMCPY %0, %1, %2
  //   $x0 = COPY %0
  //   RET_ReallyLR implicit $x0
  auto Next = next_nodbg(MI.getIterator(), MBB.instr_end());
  if (Next != MBB.instr_end() && Next->isCopy()) {
    if (MI.getOpcode() == TargetOpcode::G_BZERO)
      return false;

    // For MEMCPY/MEMMOVE/MEMSET this will be the first use (the dst), as the
    // memcpy/etc routines return the same parameter. For others it will be
    // the returned value.
    Register VReg = MI.getOperand(0).getReg();
    if (!VReg.isVirtual() || VReg != Next->getOperand(1).getReg())
      return false;

    Register PReg = Next->getOperand(0).getReg();
    if (!PReg.isPhysical())
      return false;

    auto Ret = next_nodbg(Next, MBB.instr_end());
    if (Ret == MBB.instr_end() || !Ret->isReturn())
      return false;

    if (Ret->getNumImplicitOperands() != 1)
      return false;

    if (!Ret->getOperand(0).isReg() || PReg != Ret->getOperand(0).getReg())
      return false;

    // Skip over the COPY that we just validated.
    Next = Ret;
  }

  if (Next == MBB.instr_end() || TII.isTailCall(*Next) || !Next->isReturn())
    return false;

  return true;
}

LegalizerHelper::LegalizeResult
llvm::createLibcall(MachineIRBuilder &MIRBuilder, const char *Name,
                    const CallLowering::ArgInfo &Result,
                    ArrayRef<CallLowering::ArgInfo> Args,
                    const CallingConv::ID CC, LostDebugLocObserver &LocObserver,
                    MachineInstr *MI) {
  auto &CLI = *MIRBuilder.getMF().getSubtarget().getCallLowering();

  CallLowering::CallLoweringInfo Info;
  Info.CallConv = CC;
  Info.Callee = MachineOperand::CreateES(Name);
  Info.OrigRet = Result;
  if (MI)
    Info.IsTailCall =
        (Result.Ty->isVoidTy() ||
         Result.Ty == MIRBuilder.getMF().getFunction().getReturnType()) &&
        isLibCallInTailPosition(Result, *MI, MIRBuilder.getTII(),
                                *MIRBuilder.getMRI());

  std::copy(Args.begin(), Args.end(), std::back_inserter(Info.OrigArgs));
  if (!CLI.lowerCall(MIRBuilder, Info))
    return LegalizerHelper::UnableToLegalize;

  if (MI && Info.LoweredTailCall) {
    assert(Info.IsTailCall && "Lowered tail call when it wasn't a tail call?");

    // Check debug locations before removing the return.
    LocObserver.checkpoint(true);

    // We must have a return following the call (or debug insts) to get past
    // isLibCallInTailPosition.
    do {
      MachineInstr *Next = MI->getNextNode();
      assert(Next &&
             (Next->isCopy() || Next->isReturn() || Next->isDebugInstr()) &&
             "Expected instr following MI to be return or debug inst?");
      // We lowered a tail call, so the call is now the return from the block.
      // Delete the old return.
      Next->eraseFromParent();
    } while (MI->getNextNode());

    // We expect to lose the debug location from the return.
    LocObserver.checkpoint(false);
  }
  return LegalizerHelper::Legalized;
}

LegalizerHelper::LegalizeResult
llvm::createLibcall(MachineIRBuilder &MIRBuilder, RTLIB::Libcall Libcall,
                    const CallLowering::ArgInfo &Result,
                    ArrayRef<CallLowering::ArgInfo> Args,
                    LostDebugLocObserver &LocObserver, MachineInstr *MI) {
  auto &TLI = *MIRBuilder.getMF().getSubtarget().getTargetLowering();
  const char *Name = TLI.getLibcallName(Libcall);
  if (!Name)
    return LegalizerHelper::UnableToLegalize;
  const CallingConv::ID CC = TLI.getLibcallCallingConv(Libcall);
  return createLibcall(MIRBuilder, Name, Result, Args, CC, LocObserver, MI);
}

// Useful for libcalls where all operands have the same type.
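// For example, an s64 G_FADD maps to RTLIB::ADD_F64, which on many targets
// resolves to the compiler-rt/libgcc routine __adddf3 taking two doubles and
// returning a double.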
static LegalizerHelper::LegalizeResult
simpleLibcall(MachineInstr &MI, MachineIRBuilder &MIRBuilder, unsigned Size,
              Type *OpType, LostDebugLocObserver &LocObserver) {
  auto Libcall = getRTLibDesc(MI.getOpcode(), Size);

  // FIXME: What does the original arg index mean here?
  SmallVector<CallLowering::ArgInfo, 3> Args;
  for (const MachineOperand &MO : llvm::drop_begin(MI.operands()))
    Args.push_back({MO.getReg(), OpType, 0});
  return createLibcall(MIRBuilder, Libcall,
                       {MI.getOperand(0).getReg(), OpType, 0}, Args,
                       LocObserver, &MI);
}

LegalizerHelper::LegalizeResult
llvm::createMemLibcall(MachineIRBuilder &MIRBuilder, MachineRegisterInfo &MRI,
                       MachineInstr &MI, LostDebugLocObserver &LocObserver) {
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();

  SmallVector<CallLowering::ArgInfo, 3> Args;
  // Add all the args, except for the last which is an imm denoting 'tail'.
  for (unsigned i = 0; i < MI.getNumOperands() - 1; ++i) {
    Register Reg = MI.getOperand(i).getReg();

    // Need to derive an IR type for call lowering.
    LLT OpLLT = MRI.getType(Reg);
    Type *OpTy = nullptr;
    if (OpLLT.isPointer())
      OpTy = PointerType::get(Ctx, OpLLT.getAddressSpace());
    else
      OpTy = IntegerType::get(Ctx, OpLLT.getSizeInBits());
    Args.push_back({Reg, OpTy, 0});
  }

  auto &CLI = *MIRBuilder.getMF().getSubtarget().getCallLowering();
  auto &TLI = *MIRBuilder.getMF().getSubtarget().getTargetLowering();
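  // Pick the runtime routine for the opcode. memcpy/memmove/memset return
  // their destination pointer, so the first argument is marked 'returned'
  // below, which lets the G_MEM* + COPY + RET sequence fold into a tail call.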
  RTLIB::Libcall RTLibcall;
  unsigned Opc = MI.getOpcode();
  switch (Opc) {
  case TargetOpcode::G_BZERO:
    RTLibcall = RTLIB::BZERO;
    break;
  case TargetOpcode::G_MEMCPY:
    RTLibcall = RTLIB::MEMCPY;
    Args[0].Flags[0].setReturned();
    break;
  case TargetOpcode::G_MEMMOVE:
    RTLibcall = RTLIB::MEMMOVE;
    Args[0].Flags[0].setReturned();
    break;
  case TargetOpcode::G_MEMSET:
    RTLibcall = RTLIB::MEMSET;
    Args[0].Flags[0].setReturned();
    break;
  default:
    llvm_unreachable("unsupported opcode");
  }
  const char *Name = TLI.getLibcallName(RTLibcall);

  // Unsupported libcall on the target.
  if (!Name) {
    LLVM_DEBUG(dbgs() << ".. .. Could not find libcall name for "
                      << MIRBuilder.getTII().getName(Opc) << "\n");
    return LegalizerHelper::UnableToLegalize;
  }

  CallLowering::CallLoweringInfo Info;
  Info.CallConv = TLI.getLibcallCallingConv(RTLibcall);
  Info.Callee = MachineOperand::CreateES(Name);
  Info.OrigRet = CallLowering::ArgInfo({0}, Type::getVoidTy(Ctx), 0);
  Info.IsTailCall =
      MI.getOperand(MI.getNumOperands() - 1).getImm() &&
      isLibCallInTailPosition(Info.OrigRet, MI, MIRBuilder.getTII(), MRI);

  std::copy(Args.begin(), Args.end(), std::back_inserter(Info.OrigArgs));
  if (!CLI.lowerCall(MIRBuilder, Info))
    return LegalizerHelper::UnableToLegalize;

  if (Info.LoweredTailCall) {
    assert(Info.IsTailCall && "Lowered tail call when it wasn't a tail call?");

    // Check debug locations before removing the return.
    LocObserver.checkpoint(true);

    // We must have a return following the call (or debug insts) to get past
    // isLibCallInTailPosition.
    do {
      MachineInstr *Next = MI.getNextNode();
      assert(Next &&
             (Next->isCopy() || Next->isReturn() || Next->isDebugInstr()) &&
             "Expected instr following MI to be return or debug inst?");
      // We lowered a tail call, so the call is now the return from the block.
      // Delete the old return.
      Next->eraseFromParent();
    } while (MI.getNextNode());

    // We expect to lose the debug location from the return.
    LocObserver.checkpoint(false);
  }

  return LegalizerHelper::Legalized;
}

static RTLIB::Libcall getOutlineAtomicLibcall(MachineInstr &MI) {
  unsigned Opc = MI.getOpcode();
  auto &AtomicMI = cast<GMemOperation>(MI);
  auto &MMO = AtomicMI.getMMO();
  auto Ordering = MMO.getMergedOrdering();
  LLT MemType = MMO.getMemoryType();
  uint64_t MemSize = MemType.getSizeInBytes();
  if (MemType.isVector())
    return RTLIB::UNKNOWN_LIBCALL;

#define LCALLS(A, B)                                                           \
  { A##B##_RELAX, A##B##_ACQ, A##B##_REL, A##B##_ACQ_REL }
#define LCALL5(A)                                                              \
  LCALLS(A, 1), LCALLS(A, 2), LCALLS(A, 4), LCALLS(A, 8), LCALLS(A, 16)
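  // Each LC table below is indexed first by access size (1, 2, 4, 8 or 16
  // bytes) and then by memory ordering (relaxed, acquire, release, acq_rel),
  // matching the layout getOutlineAtomicHelper expects.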
  switch (Opc) {
  case TargetOpcode::G_ATOMIC_CMPXCHG:
  case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS: {
    const RTLIB::Libcall LC[5][4] = {LCALL5(RTLIB::OUTLINE_ATOMIC_CAS)};
    return getOutlineAtomicHelper(LC, Ordering, MemSize);
  }
  case TargetOpcode::G_ATOMICRMW_XCHG: {
    const RTLIB::Libcall LC[5][4] = {LCALL5(RTLIB::OUTLINE_ATOMIC_SWP)};
    return getOutlineAtomicHelper(LC, Ordering, MemSize);
  }
  case TargetOpcode::G_ATOMICRMW_ADD:
  case TargetOpcode::G_ATOMICRMW_SUB: {
    const RTLIB::Libcall LC[5][4] = {LCALL5(RTLIB::OUTLINE_ATOMIC_LDADD)};
    return getOutlineAtomicHelper(LC, Ordering, MemSize);
  }
  case TargetOpcode::G_ATOMICRMW_AND: {
    const RTLIB::Libcall LC[5][4] = {LCALL5(RTLIB::OUTLINE_ATOMIC_LDCLR)};
    return getOutlineAtomicHelper(LC, Ordering, MemSize);
  }
  case TargetOpcode::G_ATOMICRMW_OR: {
    const RTLIB::Libcall LC[5][4] = {LCALL5(RTLIB::OUTLINE_ATOMIC_LDSET)};
    return getOutlineAtomicHelper(LC, Ordering, MemSize);
  }
  case TargetOpcode::G_ATOMICRMW_XOR: {
    const RTLIB::Libcall LC[5][4] = {LCALL5(RTLIB::OUTLINE_ATOMIC_LDEOR)};
    return getOutlineAtomicHelper(LC, Ordering, MemSize);
  }
  default:
    return RTLIB::UNKNOWN_LIBCALL;
  }
#undef LCALLS
#undef LCALL5
}

static LegalizerHelper::LegalizeResult
createAtomicLibcall(MachineIRBuilder &MIRBuilder, MachineInstr &MI) {
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();

  Type *RetTy;
  SmallVector<Register> RetRegs;
  SmallVector<CallLowering::ArgInfo, 3> Args;
  unsigned Opc = MI.getOpcode();
  switch (Opc) {
  case TargetOpcode::G_ATOMIC_CMPXCHG:
  case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS: {
    Register Success;
    LLT SuccessLLT;
    auto [Ret, RetLLT, Mem, MemLLT, Cmp, CmpLLT, New, NewLLT] =
        MI.getFirst4RegLLTs();
    RetRegs.push_back(Ret);
    RetTy = IntegerType::get(Ctx, RetLLT.getSizeInBits());
    if (Opc == TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS) {
      std::tie(Ret, RetLLT, Success, SuccessLLT, Mem, MemLLT, Cmp, CmpLLT, New,
               NewLLT) = MI.getFirst5RegLLTs();
      RetRegs.push_back(Success);
      RetTy = StructType::get(
          Ctx, {RetTy, IntegerType::get(Ctx, SuccessLLT.getSizeInBits())});
    }
    Args.push_back({Cmp, IntegerType::get(Ctx, CmpLLT.getSizeInBits()), 0});
    Args.push_back({New, IntegerType::get(Ctx, NewLLT.getSizeInBits()), 0});
    Args.push_back({Mem, PointerType::get(Ctx, MemLLT.getAddressSpace()), 0});
    break;
  }
  case TargetOpcode::G_ATOMICRMW_XCHG:
  case TargetOpcode::G_ATOMICRMW_ADD:
  case TargetOpcode::G_ATOMICRMW_SUB:
  case TargetOpcode::G_ATOMICRMW_AND:
  case TargetOpcode::G_ATOMICRMW_OR:
  case TargetOpcode::G_ATOMICRMW_XOR: {
    auto [Ret, RetLLT, Mem, MemLLT, Val, ValLLT] = MI.getFirst3RegLLTs();
    RetRegs.push_back(Ret);
    RetTy = IntegerType::get(Ctx, RetLLT.getSizeInBits());
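    // The outline atomic table maps AND to LDCLR (bit clear) and SUB to
    // LDADD, so pass the complemented mask for AND and the negated value for
    // SUB to preserve the original semantics.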
    if (Opc == TargetOpcode::G_ATOMICRMW_AND)
      Val =
          MIRBuilder.buildXor(ValLLT, MIRBuilder.buildConstant(ValLLT, -1), Val)
              .getReg(0);
    else if (Opc == TargetOpcode::G_ATOMICRMW_SUB)
      Val =
          MIRBuilder.buildSub(ValLLT, MIRBuilder.buildConstant(ValLLT, 0), Val)
              .getReg(0);
    Args.push_back({Val, IntegerType::get(Ctx, ValLLT.getSizeInBits()), 0});
    Args.push_back({Mem, PointerType::get(Ctx, MemLLT.getAddressSpace()), 0});
    break;
  }
  default:
    llvm_unreachable("unsupported opcode");
  }

  auto &CLI = *MIRBuilder.getMF().getSubtarget().getCallLowering();
  auto &TLI = *MIRBuilder.getMF().getSubtarget().getTargetLowering();
  RTLIB::Libcall RTLibcall = getOutlineAtomicLibcall(MI);
  const char *Name = TLI.getLibcallName(RTLibcall);

  // Unsupported libcall on the target.
  if (!Name) {
    LLVM_DEBUG(dbgs() << ".. .. Could not find libcall name for "
                      << MIRBuilder.getTII().getName(Opc) << "\n");
    return LegalizerHelper::UnableToLegalize;
  }

  CallLowering::CallLoweringInfo Info;
  Info.CallConv = TLI.getLibcallCallingConv(RTLibcall);
  Info.Callee = MachineOperand::CreateES(Name);
  Info.OrigRet = CallLowering::ArgInfo(RetRegs, RetTy, 0);

  std::copy(Args.begin(), Args.end(), std::back_inserter(Info.OrigArgs));
  if (!CLI.lowerCall(MIRBuilder, Info))
    return LegalizerHelper::UnableToLegalize;

  return LegalizerHelper::Legalized;
}

static RTLIB::Libcall getConvRTLibDesc(unsigned Opcode, Type *ToType,
                                       Type *FromType) {
  auto ToMVT = MVT::getVT(ToType);
  auto FromMVT = MVT::getVT(FromType);

  switch (Opcode) {
  case TargetOpcode::G_FPEXT:
    return RTLIB::getFPEXT(FromMVT, ToMVT);
  case TargetOpcode::G_FPTRUNC:
    return RTLIB::getFPROUND(FromMVT, ToMVT);
  case TargetOpcode::G_FPTOSI:
    return RTLIB::getFPTOSINT(FromMVT, ToMVT);
  case TargetOpcode::G_FPTOUI:
    return RTLIB::getFPTOUINT(FromMVT, ToMVT);
  case TargetOpcode::G_SITOFP:
    return RTLIB::getSINTTOFP(FromMVT, ToMVT);
  case TargetOpcode::G_UITOFP:
    return RTLIB::getUINTTOFP(FromMVT, ToMVT);
  }
  llvm_unreachable("Unsupported libcall function");
}

static LegalizerHelper::LegalizeResult
conversionLibcall(MachineInstr &MI, MachineIRBuilder &MIRBuilder, Type *ToType,
                  Type *FromType, LostDebugLocObserver &LocObserver) {
  RTLIB::Libcall Libcall = getConvRTLibDesc(MI.getOpcode(), ToType, FromType);
  return createLibcall(
      MIRBuilder, Libcall, {MI.getOperand(0).getReg(), ToType, 0},
      {{MI.getOperand(1).getReg(), FromType, 0}}, LocObserver, &MI);
}

static RTLIB::Libcall
getStateLibraryFunctionFor(MachineInstr &MI, const TargetLowering &TLI) {
  RTLIB::Libcall RTLibcall;
  switch (MI.getOpcode()) {
  case TargetOpcode::G_GET_FPENV:
    RTLibcall = RTLIB::FEGETENV;
    break;
  case TargetOpcode::G_SET_FPENV:
  case TargetOpcode::G_RESET_FPENV:
    RTLibcall = RTLIB::FESETENV;
    break;
  case TargetOpcode::G_GET_FPMODE:
    RTLibcall = RTLIB::FEGETMODE;
    break;
  case TargetOpcode::G_SET_FPMODE:
  case TargetOpcode::G_RESET_FPMODE:
    RTLibcall = RTLIB::FESETMODE;
    break;
  default:
    llvm_unreachable("Unexpected opcode");
  }
  return RTLibcall;
}

// Some library functions that read FP state (fegetmode, fegetenv) write the
// state into a region in memory. IR intrinsics that do the same operations
// (get_fpmode, get_fpenv) return the state as an integer value. To implement
// these intrinsics via the library functions, we need a temporary variable,
// for example:
//
//     %0:_(s32) = G_GET_FPMODE
//
// is transformed to:
//
//     %1:_(p0) = G_FRAME_INDEX %stack.0
//     BL &fegetmode
//     %0:_(s32) = G_LOAD %1
//
LegalizerHelper::LegalizeResult
LegalizerHelper::createGetStateLibcall(MachineIRBuilder &MIRBuilder,
                                       MachineInstr &MI,
                                       LostDebugLocObserver &LocObserver) {
  const DataLayout &DL = MIRBuilder.getDataLayout();
  auto &MF = MIRBuilder.getMF();
  auto &MRI = *MIRBuilder.getMRI();
  auto &Ctx = MF.getFunction().getContext();

  // Create a temporary where the library function will put the read state.
  Register Dst = MI.getOperand(0).getReg();
  LLT StateTy = MRI.getType(Dst);
  TypeSize StateSize = StateTy.getSizeInBytes();
  Align TempAlign = getStackTemporaryAlignment(StateTy);
  MachinePointerInfo TempPtrInfo;
  auto Temp = createStackTemporary(StateSize, TempAlign, TempPtrInfo);

  // Create a call to the library function, with the temporary as an argument.
  unsigned TempAddrSpace = DL.getAllocaAddrSpace();
  Type *StatePtrTy = PointerType::get(Ctx, TempAddrSpace);
  RTLIB::Libcall RTLibcall = getStateLibraryFunctionFor(MI, TLI);
  auto Res =
      createLibcall(MIRBuilder, RTLibcall,
                    CallLowering::ArgInfo({0}, Type::getVoidTy(Ctx), 0),
                    CallLowering::ArgInfo({Temp.getReg(0), StatePtrTy, 0}),
                    LocObserver, nullptr);
  if (Res != LegalizerHelper::Legalized)
    return Res;

  // Create a load from the temporary.
  MachineMemOperand *MMO = MF.getMachineMemOperand(
      TempPtrInfo, MachineMemOperand::MOLoad, StateTy, TempAlign);
  MIRBuilder.buildLoadInstr(TargetOpcode::G_LOAD, Dst, Temp, *MMO);

  return LegalizerHelper::Legalized;
}

// Similar to `createGetStateLibcall` the function calls a library function
// using transient space on the stack. In this case the library function reads
// the content of the memory region.
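// For example:
//
//     G_SET_FPMODE %0:_(s32)
//
// conceptually becomes:
//
//     %1:_(p0) = G_FRAME_INDEX %stack.0
//     G_STORE %0:_(s32), %1:_(p0)
//     BL &fesetmode
//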
LegalizerHelper::LegalizeResult
LegalizerHelper::createSetStateLibcall(MachineIRBuilder &MIRBuilder,
                                       MachineInstr &MI,
                                       LostDebugLocObserver &LocObserver) {
  const DataLayout &DL = MIRBuilder.getDataLayout();
  auto &MF = MIRBuilder.getMF();
  auto &MRI = *MIRBuilder.getMRI();
  auto &Ctx = MF.getFunction().getContext();

  // Create a temporary where the library function will get the new state.
  Register Src = MI.getOperand(0).getReg();
  LLT StateTy = MRI.getType(Src);
  TypeSize StateSize = StateTy.getSizeInBytes();
  Align TempAlign = getStackTemporaryAlignment(StateTy);
  MachinePointerInfo TempPtrInfo;
  auto Temp = createStackTemporary(StateSize, TempAlign, TempPtrInfo);

  // Put the new state into the temporary.
  MachineMemOperand *MMO = MF.getMachineMemOperand(
      TempPtrInfo, MachineMemOperand::MOStore, StateTy, TempAlign);
  MIRBuilder.buildStore(Src, Temp, *MMO);

  // Create a call to the library function, with the temporary as an argument.
  unsigned TempAddrSpace = DL.getAllocaAddrSpace();
  Type *StatePtrTy = PointerType::get(Ctx, TempAddrSpace);
  RTLIB::Libcall RTLibcall = getStateLibraryFunctionFor(MI, TLI);
  return createLibcall(MIRBuilder, RTLibcall,
                       CallLowering::ArgInfo({0}, Type::getVoidTy(Ctx), 0),
                       CallLowering::ArgInfo({Temp.getReg(0), StatePtrTy, 0}),
                       LocObserver, nullptr);
}

// The function is used to legalize operations that set the default
// environment state. In the C library a call like `fesetmode(FE_DFL_MODE)`
// is used for that. On most targets supported by glibc, FE_DFL_MODE is
// defined as `((const femode_t *) -1)`. That assumption is used here. If it
// does not hold for some target, the target must provide custom lowering.
LegalizerHelper::LegalizeResult
LegalizerHelper::createResetStateLibcall(MachineIRBuilder &MIRBuilder,
                                         MachineInstr &MI,
                                         LostDebugLocObserver &LocObserver) {
  const DataLayout &DL = MIRBuilder.getDataLayout();
  auto &MF = MIRBuilder.getMF();
  auto &Ctx = MF.getFunction().getContext();

  // Create an argument for the library function.
  unsigned AddrSpace = DL.getDefaultGlobalsAddressSpace();
  Type *StatePtrTy = PointerType::get(Ctx, AddrSpace);
  unsigned PtrSize = DL.getPointerSizeInBits(AddrSpace);
  LLT MemTy = LLT::pointer(AddrSpace, PtrSize);
  auto DefValue = MIRBuilder.buildConstant(LLT::scalar(PtrSize), -1LL);
  DstOp Dest(MRI.createGenericVirtualRegister(MemTy));
  MIRBuilder.buildIntToPtr(Dest, DefValue);

  RTLIB::Libcall RTLibcall = getStateLibraryFunctionFor(MI, TLI);
  return createLibcall(MIRBuilder, RTLibcall,
                       CallLowering::ArgInfo({0}, Type::getVoidTy(Ctx), 0),
                       CallLowering::ArgInfo({Dest.getReg(), StatePtrTy, 0}),
                       LocObserver, &MI);
}

LegalizerHelper::LegalizeResult
LegalizerHelper::libcall(MachineInstr &MI, LostDebugLocObserver &LocObserver) {
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();

  switch (MI.getOpcode()) {
  default:
    return UnableToLegalize;
  case TargetOpcode::G_MUL:
  case TargetOpcode::G_SDIV:
  case TargetOpcode::G_UDIV:
  case TargetOpcode::G_SREM:
  case TargetOpcode::G_UREM:
  case TargetOpcode::G_CTLZ_ZERO_UNDEF: {
    LLT LLTy = MRI.getType(MI.getOperand(0).getReg());
    unsigned Size = LLTy.getSizeInBits();
    Type *HLTy = IntegerType::get(Ctx, Size);
    auto Status = simpleLibcall(MI, MIRBuilder, Size, HLTy, LocObserver);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_FADD:
  case TargetOpcode::G_FSUB:
  case TargetOpcode::G_FMUL:
  case TargetOpcode::G_FDIV:
  case TargetOpcode::G_FMA:
  case TargetOpcode::G_FPOW:
  case TargetOpcode::G_FREM:
  case TargetOpcode::G_FCOS:
  case TargetOpcode::G_FSIN:
  case TargetOpcode::G_FLOG10:
  case TargetOpcode::G_FLOG:
  case TargetOpcode::G_FLOG2:
  case TargetOpcode::G_FLDEXP:
  case TargetOpcode::G_FEXP:
  case TargetOpcode::G_FEXP2:
  case TargetOpcode::G_FEXP10:
  case TargetOpcode::G_FCEIL:
  case TargetOpcode::G_FFLOOR:
  case TargetOpcode::G_FMINNUM:
  case TargetOpcode::G_FMAXNUM:
  case TargetOpcode::G_FSQRT:
  case TargetOpcode::G_FRINT:
  case TargetOpcode::G_FNEARBYINT:
  case TargetOpcode::G_INTRINSIC_ROUNDEVEN: {
    LLT LLTy = MRI.getType(MI.getOperand(0).getReg());
    unsigned Size = LLTy.getSizeInBits();
    Type *HLTy = getFloatTypeForLLT(Ctx, LLTy);
    if (!HLTy || (Size != 32 && Size != 64 && Size != 80 && Size != 128)) {
      LLVM_DEBUG(dbgs() << "No libcall available for type " << LLTy << ".\n");
      return UnableToLegalize;
    }
    auto Status = simpleLibcall(MI, MIRBuilder, Size, HLTy, LocObserver);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_FPOWI: {
    LLT LLTy = MRI.getType(MI.getOperand(0).getReg());
    unsigned Size = LLTy.getSizeInBits();
    Type *HLTy = getFloatTypeForLLT(Ctx, LLTy);
    Type *ITy = IntegerType::get(
        Ctx, MRI.getType(MI.getOperand(2).getReg()).getSizeInBits());
    if (!HLTy || (Size != 32 && Size != 64 && Size != 80 && Size != 128)) {
      LLVM_DEBUG(dbgs() << "No libcall available for type " << LLTy << ".\n");
      return UnableToLegalize;
    }
    auto Libcall = getRTLibDesc(MI.getOpcode(), Size);
    std::initializer_list<CallLowering::ArgInfo> Args = {
        {MI.getOperand(1).getReg(), HLTy, 0},
        {MI.getOperand(2).getReg(), ITy, 1}};
    LegalizeResult Status =
        createLibcall(MIRBuilder, Libcall, {MI.getOperand(0).getReg(), HLTy, 0},
                      Args, LocObserver, &MI);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_FPEXT:
  case TargetOpcode::G_FPTRUNC: {
    Type *FromTy =
        getFloatTypeForLLT(Ctx, MRI.getType(MI.getOperand(1).getReg()));
    Type *ToTy =
        getFloatTypeForLLT(Ctx, MRI.getType(MI.getOperand(0).getReg()));
    if (!FromTy || !ToTy)
      return UnableToLegalize;
    LegalizeResult Status =
        conversionLibcall(MI, MIRBuilder, ToTy, FromTy, LocObserver);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_FPTOSI:
  case TargetOpcode::G_FPTOUI: {
    // FIXME: Support other types
    unsigned FromSize = MRI.getType(MI.getOperand(1).getReg()).getSizeInBits();
    unsigned ToSize = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
    if ((ToSize != 32 && ToSize != 64) || (FromSize != 32 && FromSize != 64))
      return UnableToLegalize;
    LegalizeResult Status = conversionLibcall(
        MI, MIRBuilder,
        ToSize == 32 ? Type::getInt32Ty(Ctx) : Type::getInt64Ty(Ctx),
        FromSize == 64 ? Type::getDoubleTy(Ctx) : Type::getFloatTy(Ctx),
        LocObserver);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_SITOFP:
  case TargetOpcode::G_UITOFP: {
    // FIXME: Support other types
    unsigned FromSize = MRI.getType(MI.getOperand(1).getReg()).getSizeInBits();
    unsigned ToSize = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
    if ((FromSize != 32 && FromSize != 64) || (ToSize != 32 && ToSize != 64))
      return UnableToLegalize;
    LegalizeResult Status = conversionLibcall(
        MI, MIRBuilder,
        ToSize == 64 ? Type::getDoubleTy(Ctx) : Type::getFloatTy(Ctx),
        FromSize == 32 ? Type::getInt32Ty(Ctx) : Type::getInt64Ty(Ctx),
        LocObserver);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_ATOMICRMW_XCHG:
  case TargetOpcode::G_ATOMICRMW_ADD:
  case TargetOpcode::G_ATOMICRMW_SUB:
  case TargetOpcode::G_ATOMICRMW_AND:
  case TargetOpcode::G_ATOMICRMW_OR:
  case TargetOpcode::G_ATOMICRMW_XOR:
  case TargetOpcode::G_ATOMIC_CMPXCHG:
  case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS: {
    auto Status = createAtomicLibcall(MIRBuilder, MI);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_BZERO:
  case TargetOpcode::G_MEMCPY:
  case TargetOpcode::G_MEMMOVE:
  case TargetOpcode::G_MEMSET: {
    LegalizeResult Result =
        createMemLibcall(MIRBuilder, *MIRBuilder.getMRI(), MI, LocObserver);
    if (Result != Legalized)
      return Result;
    MI.eraseFromParent();
    return Result;
  }
  case TargetOpcode::G_GET_FPENV:
  case TargetOpcode::G_GET_FPMODE: {
    LegalizeResult Result = createGetStateLibcall(MIRBuilder, MI, LocObserver);
    if (Result != Legalized)
      return Result;
    break;
  }
  case TargetOpcode::G_SET_FPENV:
  case TargetOpcode::G_SET_FPMODE: {
    LegalizeResult Result = createSetStateLibcall(MIRBuilder, MI, LocObserver);
    if (Result != Legalized)
      return Result;
    break;
  }
  case TargetOpcode::G_RESET_FPENV:
  case TargetOpcode::G_RESET_FPMODE: {
    LegalizeResult Result =
        createResetStateLibcall(MIRBuilder, MI, LocObserver);
    if (Result != Legalized)
      return Result;
    break;
  }
  }

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::narrowScalar(MachineInstr &MI,
                                                              unsigned TypeIdx,
                                                              LLT NarrowTy) {
  uint64_t SizeOp0 = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
  uint64_t NarrowSize = NarrowTy.getSizeInBits();

  switch (MI.getOpcode()) {
  default:
    return UnableToLegalize;
  case TargetOpcode::G_IMPLICIT_DEF: {
    Register DstReg = MI.getOperand(0).getReg();
    LLT DstTy = MRI.getType(DstReg);

    // If SizeOp0 is not an exact multiple of NarrowSize, emit
    // G_ANYEXT(G_IMPLICIT_DEF). Cast result to vector if needed.
    // FIXME: Although this would also be legal for the general case, it causes
    //  a lot of regressions in the emitted code (superfluous COPYs, artifact
    //  combines not being hit). This seems to be a problem related to the
    //  artifact combiner.
    if (SizeOp0 % NarrowSize != 0) {
      LLT ImplicitTy = NarrowTy;
      if (DstTy.isVector())
        ImplicitTy = LLT::vector(DstTy.getElementCount(), ImplicitTy);

      Register ImplicitReg = MIRBuilder.buildUndef(ImplicitTy).getReg(0);
      MIRBuilder.buildAnyExt(DstReg, ImplicitReg);

      MI.eraseFromParent();
      return Legalized;
    }

    int NumParts = SizeOp0 / NarrowSize;

    SmallVector<Register, 2> DstRegs;
    for (int i = 0; i < NumParts; ++i)
      DstRegs.push_back(MIRBuilder.buildUndef(NarrowTy).getReg(0));

    if (DstTy.isVector())
      MIRBuilder.buildBuildVector(DstReg, DstRegs);
    else
      MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_CONSTANT: {
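    // For instance (illustrative), narrowing an s96 G_CONSTANT to s32 emits
    // three s32 G_CONSTANTs covering bits [0,32), [32,64) and [64,96); a
    // leftover piece is only needed when NarrowSize doesn't divide the width.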
    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
    const APInt &Val = MI.getOperand(1).getCImm()->getValue();
    unsigned TotalSize = Ty.getSizeInBits();
    unsigned NarrowSize = NarrowTy.getSizeInBits();
    int NumParts = TotalSize / NarrowSize;

    SmallVector<Register, 4> PartRegs;
    for (int I = 0; I != NumParts; ++I) {
      unsigned Offset = I * NarrowSize;
      auto K = MIRBuilder.buildConstant(NarrowTy,
                                        Val.lshr(Offset).trunc(NarrowSize));
      PartRegs.push_back(K.getReg(0));
    }

    LLT LeftoverTy;
    unsigned LeftoverBits = TotalSize - NumParts * NarrowSize;
    SmallVector<Register, 1> LeftoverRegs;
    if (LeftoverBits != 0) {
      LeftoverTy = LLT::scalar(LeftoverBits);
      auto K = MIRBuilder.buildConstant(
        LeftoverTy,
        Val.lshr(NumParts * NarrowSize).trunc(LeftoverBits));
      LeftoverRegs.push_back(K.getReg(0));
    }

    insertParts(MI.getOperand(0).getReg(),
                Ty, NarrowTy, PartRegs, LeftoverTy, LeftoverRegs);

    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_SEXT:
  case TargetOpcode::G_ZEXT:
  case TargetOpcode::G_ANYEXT:
    return narrowScalarExt(MI, TypeIdx, NarrowTy);
  case TargetOpcode::G_TRUNC: {
    if (TypeIdx != 1)
      return UnableToLegalize;

    uint64_t SizeOp1 = MRI.getType(MI.getOperand(1).getReg()).getSizeInBits();
    if (NarrowTy.getSizeInBits() * 2 != SizeOp1) {
      LLVM_DEBUG(dbgs() << "Can't narrow trunc to type " << NarrowTy << "\n");
      return UnableToLegalize;
    }

    auto Unmerge = MIRBuilder.buildUnmerge(NarrowTy, MI.getOperand(1));
    MIRBuilder.buildCopy(MI.getOperand(0), Unmerge.getReg(0));
    MI.eraseFromParent();
    return Legalized;
  }

  case TargetOpcode::G_FREEZE: {
    if (TypeIdx != 0)
      return UnableToLegalize;

    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
    // Should widen scalar first
    if (Ty.getSizeInBits() % NarrowTy.getSizeInBits() != 0)
      return UnableToLegalize;

    auto Unmerge = MIRBuilder.buildUnmerge(NarrowTy, MI.getOperand(1).getReg());
    SmallVector<Register, 8> Parts;
    for (unsigned i = 0; i < Unmerge->getNumDefs(); ++i) {
      Parts.push_back(
          MIRBuilder.buildFreeze(NarrowTy, Unmerge.getReg(i)).getReg(0));
    }

    MIRBuilder.buildMergeLikeInstr(MI.getOperand(0).getReg(), Parts);
    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_ADD:
  case TargetOpcode::G_SUB:
  case TargetOpcode::G_SADDO:
  case TargetOpcode::G_SSUBO:
  case TargetOpcode::G_SADDE:
  case TargetOpcode::G_SSUBE:
  case TargetOpcode::G_UADDO:
  case TargetOpcode::G_USUBO:
  case TargetOpcode::G_UADDE:
  case TargetOpcode::G_USUBE:
    return narrowScalarAddSub(MI, TypeIdx, NarrowTy);
  case TargetOpcode::G_MUL:
  case TargetOpcode::G_UMULH:
    return narrowScalarMul(MI, NarrowTy);
  case TargetOpcode::G_EXTRACT:
    return narrowScalarExtract(MI, TypeIdx, NarrowTy);
  case TargetOpcode::G_INSERT:
    return narrowScalarInsert(MI, TypeIdx, NarrowTy);
  case TargetOpcode::G_LOAD: {
    auto &LoadMI = cast<GLoad>(MI);
    Register DstReg = LoadMI.getDstReg();
    LLT DstTy = MRI.getType(DstReg);
    if (DstTy.isVector())
      return UnableToLegalize;

    if (8 * LoadMI.getMemSize() != DstTy.getSizeInBits()) {
      Register TmpReg = MRI.createGenericVirtualRegister(NarrowTy);
      MIRBuilder.buildLoad(TmpReg, LoadMI.getPointerReg(), LoadMI.getMMO());
      MIRBuilder.buildAnyExt(DstReg, TmpReg);
      LoadMI.eraseFromParent();
      return Legalized;
    }

    return reduceLoadStoreWidth(LoadMI, TypeIdx, NarrowTy);
  }
  case TargetOpcode::G_ZEXTLOAD:
  case TargetOpcode::G_SEXTLOAD: {
    auto &LoadMI = cast<GExtLoad>(MI);
    Register DstReg = LoadMI.getDstReg();
    Register PtrReg = LoadMI.getPointerReg();

    Register TmpReg = MRI.createGenericVirtualRegister(NarrowTy);
    auto &MMO = LoadMI.getMMO();
    unsigned MemSize = MMO.getSizeInBits();

1340     if (MemSize == NarrowSize) {
1341       MIRBuilder.buildLoad(TmpReg, PtrReg, MMO);
1342     } else if (MemSize < NarrowSize) {
1343       MIRBuilder.buildLoadInstr(LoadMI.getOpcode(), TmpReg, PtrReg, MMO);
1344     } else if (MemSize > NarrowSize) {
1345       // FIXME: Need to split the load.
1346       return UnableToLegalize;
1347     }
1348 
1349     if (isa<GZExtLoad>(LoadMI))
1350       MIRBuilder.buildZExt(DstReg, TmpReg);
1351     else
1352       MIRBuilder.buildSExt(DstReg, TmpReg);
1353 
1354     LoadMI.eraseFromParent();
1355     return Legalized;
1356   }
1357   case TargetOpcode::G_STORE: {
1358     auto &StoreMI = cast<GStore>(MI);
1359 
1360     Register SrcReg = StoreMI.getValueReg();
1361     LLT SrcTy = MRI.getType(SrcReg);
1362     if (SrcTy.isVector())
1363       return UnableToLegalize;
1364 
1365     int NumParts = SizeOp0 / NarrowSize;
1366     unsigned HandledSize = NumParts * NarrowTy.getSizeInBits();
1367     unsigned LeftoverBits = SrcTy.getSizeInBits() - HandledSize;
1368     if (SrcTy.isVector() && LeftoverBits != 0)
1369       return UnableToLegalize;
1370 
1371     if (8 * StoreMI.getMemSize() != SrcTy.getSizeInBits()) {
1372       Register TmpReg = MRI.createGenericVirtualRegister(NarrowTy);
1373       MIRBuilder.buildTrunc(TmpReg, SrcReg);
1374       MIRBuilder.buildStore(TmpReg, StoreMI.getPointerReg(), StoreMI.getMMO());
1375       StoreMI.eraseFromParent();
1376       return Legalized;
1377     }
1378 
1379     return reduceLoadStoreWidth(StoreMI, 0, NarrowTy);
1380   }
1381   case TargetOpcode::G_SELECT:
1382     return narrowScalarSelect(MI, TypeIdx, NarrowTy);
1383   case TargetOpcode::G_AND:
1384   case TargetOpcode::G_OR:
1385   case TargetOpcode::G_XOR: {
1386     // Legalize bitwise operation:
1387     // A = BinOp<Ty> B, C
1388     // into:
1389     // B1, ..., BN = G_UNMERGE_VALUES B
1390     // C1, ..., CN = G_UNMERGE_VALUES C
1391 // A1 = BinOp<Ty/N> B1, C1
1392     // ...
1393     // AN = BinOp<Ty/N> BN, CN
1394     // A = G_MERGE_VALUES A1, ..., AN
1395     return narrowScalarBasic(MI, TypeIdx, NarrowTy);
1396   }
1397   case TargetOpcode::G_SHL:
1398   case TargetOpcode::G_LSHR:
1399   case TargetOpcode::G_ASHR:
1400     return narrowScalarShift(MI, TypeIdx, NarrowTy);
1401   case TargetOpcode::G_CTLZ:
1402   case TargetOpcode::G_CTLZ_ZERO_UNDEF:
1403   case TargetOpcode::G_CTTZ:
1404   case TargetOpcode::G_CTTZ_ZERO_UNDEF:
1405   case TargetOpcode::G_CTPOP:
1406     if (TypeIdx == 1)
1407       switch (MI.getOpcode()) {
1408       case TargetOpcode::G_CTLZ:
1409       case TargetOpcode::G_CTLZ_ZERO_UNDEF:
1410         return narrowScalarCTLZ(MI, TypeIdx, NarrowTy);
1411       case TargetOpcode::G_CTTZ:
1412       case TargetOpcode::G_CTTZ_ZERO_UNDEF:
1413         return narrowScalarCTTZ(MI, TypeIdx, NarrowTy);
1414       case TargetOpcode::G_CTPOP:
1415         return narrowScalarCTPOP(MI, TypeIdx, NarrowTy);
1416       default:
1417         return UnableToLegalize;
1418       }
1419 
1420     Observer.changingInstr(MI);
1421     narrowScalarDst(MI, NarrowTy, 0, TargetOpcode::G_ZEXT);
1422     Observer.changedInstr(MI);
1423     return Legalized;
1424   case TargetOpcode::G_INTTOPTR:
1425     if (TypeIdx != 1)
1426       return UnableToLegalize;
1427 
1428     Observer.changingInstr(MI);
1429     narrowScalarSrc(MI, NarrowTy, 1);
1430     Observer.changedInstr(MI);
1431     return Legalized;
1432   case TargetOpcode::G_PTRTOINT:
1433     if (TypeIdx != 0)
1434       return UnableToLegalize;
1435 
1436     Observer.changingInstr(MI);
1437     narrowScalarDst(MI, NarrowTy, 0, TargetOpcode::G_ZEXT);
1438     Observer.changedInstr(MI);
1439     return Legalized;
1440   case TargetOpcode::G_PHI: {
1441     // FIXME: add support for when SizeOp0 isn't an exact multiple of
1442     // NarrowSize.
1443     if (SizeOp0 % NarrowSize != 0)
1444       return UnableToLegalize;
1445 
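    // Split each incoming value into NumParts pieces in its predecessor
    // block, build one narrow G_PHI per piece, and remerge the pieces after
    // the PHIs of the current block.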
1446     unsigned NumParts = SizeOp0 / NarrowSize;
1447     SmallVector<Register, 2> DstRegs(NumParts);
1448     SmallVector<SmallVector<Register, 2>, 2> SrcRegs(MI.getNumOperands() / 2);
1449     Observer.changingInstr(MI);
1450     for (unsigned i = 1; i < MI.getNumOperands(); i += 2) {
1451       MachineBasicBlock &OpMBB = *MI.getOperand(i + 1).getMBB();
1452       MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminatorForward());
1453       extractParts(MI.getOperand(i).getReg(), NarrowTy, NumParts,
1454                    SrcRegs[i / 2], MIRBuilder, MRI);
1455     }
1456     MachineBasicBlock &MBB = *MI.getParent();
1457     MIRBuilder.setInsertPt(MBB, MI);
1458     for (unsigned i = 0; i < NumParts; ++i) {
1459       DstRegs[i] = MRI.createGenericVirtualRegister(NarrowTy);
1460       MachineInstrBuilder MIB =
1461           MIRBuilder.buildInstr(TargetOpcode::G_PHI).addDef(DstRegs[i]);
1462       for (unsigned j = 1; j < MI.getNumOperands(); j += 2)
1463         MIB.addUse(SrcRegs[j / 2][i]).add(MI.getOperand(j + 1));
1464     }
1465     MIRBuilder.setInsertPt(MBB, MBB.getFirstNonPHI());
1466     MIRBuilder.buildMergeLikeInstr(MI.getOperand(0), DstRegs);
1467     Observer.changedInstr(MI);
1468     MI.eraseFromParent();
1469     return Legalized;
1470   }
1471   case TargetOpcode::G_EXTRACT_VECTOR_ELT:
1472   case TargetOpcode::G_INSERT_VECTOR_ELT: {
1473     if (TypeIdx != 2)
1474       return UnableToLegalize;
1475 
1476     int OpIdx = MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3;
1477     Observer.changingInstr(MI);
1478     narrowScalarSrc(MI, NarrowTy, OpIdx);
1479     Observer.changedInstr(MI);
1480     return Legalized;
1481   }
1482   case TargetOpcode::G_ICMP: {
1483     Register LHS = MI.getOperand(2).getReg();
1484     LLT SrcTy = MRI.getType(LHS);
1485     uint64_t SrcSize = SrcTy.getSizeInBits();
1486     CmpInst::Predicate Pred =
1487         static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
1488 
1489     // TODO: Handle the non-equality case for weird sizes.
1490     if (NarrowSize * 2 != SrcSize && !ICmpInst::isEquality(Pred))
1491       return UnableToLegalize;
1492 
1493     LLT LeftoverTy; // Example: s88 -> s64 (NarrowTy) + s24 (leftover)
1494     SmallVector<Register, 4> LHSPartRegs, LHSLeftoverRegs;
1495     if (!extractParts(LHS, SrcTy, NarrowTy, LeftoverTy, LHSPartRegs,
1496                       LHSLeftoverRegs, MIRBuilder, MRI))
1497       return UnableToLegalize;
1498 
1499     LLT Unused; // Matches LeftoverTy; G_ICMP LHS and RHS are the same type.
1500     SmallVector<Register, 4> RHSPartRegs, RHSLeftoverRegs;
1501     if (!extractParts(MI.getOperand(3).getReg(), SrcTy, NarrowTy, Unused,
1502                       RHSPartRegs, RHSLeftoverRegs, MIRBuilder, MRI))
1503       return UnableToLegalize;
1504 
1505     // We now have the LHS and RHS of the compare split into narrow-type
1506     // registers, plus potentially some leftover type.
1507     Register Dst = MI.getOperand(0).getReg();
1508     LLT ResTy = MRI.getType(Dst);
1509     if (ICmpInst::isEquality(Pred)) {
1510       // For each part on the LHS and RHS, keep track of the result of XOR-ing
1511       // them together. For each equal part, the result should be all 0s. For
1512       // each non-equal part, we'll get at least one 1.
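      // e.g. (illustrative) an s64 equality compare with NarrowTy s32: XOR
      // the two s32 halves of each operand pairwise, OR the XOR results
      // together, and compare the accumulated OR against zero.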
1513       auto Zero = MIRBuilder.buildConstant(NarrowTy, 0);
1514       SmallVector<Register, 4> Xors;
1515       for (auto LHSAndRHS : zip(LHSPartRegs, RHSPartRegs)) {
1516         auto LHS = std::get<0>(LHSAndRHS);
1517         auto RHS = std::get<1>(LHSAndRHS);
1518         auto Xor = MIRBuilder.buildXor(NarrowTy, LHS, RHS).getReg(0);
1519         Xors.push_back(Xor);
1520       }
1521 
1522       // Build a G_XOR for each leftover register. Each G_XOR must be widened
1523       // to the desired narrow type so that we can OR them together later.
1524       SmallVector<Register, 4> WidenedXors;
1525       for (auto LHSAndRHS : zip(LHSLeftoverRegs, RHSLeftoverRegs)) {
1526         auto LHS = std::get<0>(LHSAndRHS);
1527         auto RHS = std::get<1>(LHSAndRHS);
1528         auto Xor = MIRBuilder.buildXor(LeftoverTy, LHS, RHS).getReg(0);
1529         LLT GCDTy = extractGCDType(WidenedXors, NarrowTy, LeftoverTy, Xor);
1530         buildLCMMergePieces(LeftoverTy, NarrowTy, GCDTy, WidenedXors,
1531                             /* PadStrategy = */ TargetOpcode::G_ZEXT);
1532         Xors.insert(Xors.end(), WidenedXors.begin(), WidenedXors.end());
1533       }
1534 
1535       // Now, for each part we broke up, we know if they are equal/not equal
1536       // based on the G_XOR. We can OR these all together and compare against
1537       // 0 to get the result.
1538       assert(Xors.size() >= 2 && "Should have gotten at least two Xors?");
1539       auto Or = MIRBuilder.buildOr(NarrowTy, Xors[0], Xors[1]);
1540       for (unsigned I = 2, E = Xors.size(); I < E; ++I)
1541         Or = MIRBuilder.buildOr(NarrowTy, Or, Xors[I]);
1542       MIRBuilder.buildICmp(Pred, Dst, Or, Zero);
1543     } else {
1544       // TODO: Handle non-power-of-two types.
1545       assert(LHSPartRegs.size() == 2 && "Expected exactly 2 LHS part regs?");
1546       assert(RHSPartRegs.size() == 2 && "Expected exactly 2 RHS part regs?");
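      // Compare the high parts with the original predicate and the low parts
      // with its unsigned form; when the high parts are equal, the low-part
      // comparison decides the result.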
1547       Register LHSL = LHSPartRegs[0];
1548       Register LHSH = LHSPartRegs[1];
1549       Register RHSL = RHSPartRegs[0];
1550       Register RHSH = RHSPartRegs[1];
1551       MachineInstrBuilder CmpH = MIRBuilder.buildICmp(Pred, ResTy, LHSH, RHSH);
1552       MachineInstrBuilder CmpHEQ =
1553           MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, ResTy, LHSH, RHSH);
1554       MachineInstrBuilder CmpLU = MIRBuilder.buildICmp(
1555           ICmpInst::getUnsignedPredicate(Pred), ResTy, LHSL, RHSL);
1556       MIRBuilder.buildSelect(Dst, CmpHEQ, CmpLU, CmpH);
1557     }
1558     MI.eraseFromParent();
1559     return Legalized;
1560   }
1561   case TargetOpcode::G_SEXT_INREG: {
1562     if (TypeIdx != 0)
1563       return UnableToLegalize;
1564 
1565     int64_t SizeInBits = MI.getOperand(2).getImm();
1566 
1567     // So long as the new type has more bits than the bits we're extending, we
1568     // don't need to break it apart.
1569     if (NarrowTy.getScalarSizeInBits() > SizeInBits) {
1570       Observer.changingInstr(MI);
1571       // We don't lose any non-extension bits by truncating the src and
1572       // sign-extending the dst.
1573       MachineOperand &MO1 = MI.getOperand(1);
1574       auto TruncMIB = MIRBuilder.buildTrunc(NarrowTy, MO1);
1575       MO1.setReg(TruncMIB.getReg(0));
1576 
1577       MachineOperand &MO2 = MI.getOperand(0);
1578       Register DstExt = MRI.createGenericVirtualRegister(NarrowTy);
1579       MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
1580       MIRBuilder.buildSExt(MO2, DstExt);
1581       MO2.setReg(DstExt);
1582       Observer.changedInstr(MI);
1583       return Legalized;
1584     }
1585 
1586     // Break it apart. Components below the extension point are unmodified. The
1587     // component containing the extension point becomes a narrower SEXT_INREG.
1588     // Components above it are ashr'd from the component containing the
1589     // extension point.
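    // e.g. (illustrative) s64 = G_SEXT_INREG %x, 40 with NarrowTy s32: part 0
    // (bits 0-31) passes through unchanged, while part 1, which contains the
    // extension point at bit 39, becomes a G_SEXT_INREG with width 8.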
1590     if (SizeOp0 % NarrowSize != 0)
1591       return UnableToLegalize;
1592     int NumParts = SizeOp0 / NarrowSize;
1593 
1594     // List the registers where the destination will be scattered.
1595     SmallVector<Register, 2> DstRegs;
1596     // List the registers where the source will be split.
1597     SmallVector<Register, 2> SrcRegs;
1598 
1599     // Create all the temporary registers.
1600     for (int i = 0; i < NumParts; ++i) {
1601       Register SrcReg = MRI.createGenericVirtualRegister(NarrowTy);
1602 
1603       SrcRegs.push_back(SrcReg);
1604     }
1605 
1606     // Explode the big arguments into smaller chunks.
1607     MIRBuilder.buildUnmerge(SrcRegs, MI.getOperand(1));
1608 
1609     Register AshrCstReg =
1610         MIRBuilder.buildConstant(NarrowTy, NarrowTy.getScalarSizeInBits() - 1)
1611             .getReg(0);
1612     Register FullExtensionReg;
1613     Register PartialExtensionReg;
1614 
1615     // Do the operation on each small part.
1616     for (int i = 0; i < NumParts; ++i) {
1617       if ((i + 1) * NarrowTy.getScalarSizeInBits() <= SizeInBits) {
1618         DstRegs.push_back(SrcRegs[i]);
1619         PartialExtensionReg = DstRegs.back();
1620       } else if (i * NarrowTy.getScalarSizeInBits() >= SizeInBits) {
1621         assert(PartialExtensionReg &&
1622                "Expected to visit partial extension before full");
1623         if (FullExtensionReg) {
1624           DstRegs.push_back(FullExtensionReg);
1625           continue;
1626         }
1627         DstRegs.push_back(
1628             MIRBuilder.buildAShr(NarrowTy, PartialExtensionReg, AshrCstReg)
1629                 .getReg(0));
1630         FullExtensionReg = DstRegs.back();
1631       } else {
1632         DstRegs.push_back(
1633             MIRBuilder
1634                 .buildInstr(
1635                     TargetOpcode::G_SEXT_INREG, {NarrowTy},
1636                     {SrcRegs[i], SizeInBits % NarrowTy.getScalarSizeInBits()})
1637                 .getReg(0));
1638         PartialExtensionReg = DstRegs.back();
1639       }
1640     }
1641 
1642     // Gather the destination registers into the final destination.
1643     Register DstReg = MI.getOperand(0).getReg();
1644     MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
1645     MI.eraseFromParent();
1646     return Legalized;
1647   }
1648   case TargetOpcode::G_BSWAP:
1649   case TargetOpcode::G_BITREVERSE: {
1650     if (SizeOp0 % NarrowSize != 0)
1651       return UnableToLegalize;
1652 
1653     Observer.changingInstr(MI);
1654     SmallVector<Register, 2> SrcRegs, DstRegs;
1655     unsigned NumParts = SizeOp0 / NarrowSize;
1656     extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs,
1657                  MIRBuilder, MRI);
1658 
1659     for (unsigned i = 0; i < NumParts; ++i) {
1660       auto DstPart = MIRBuilder.buildInstr(MI.getOpcode(), {NarrowTy},
1661                                            {SrcRegs[NumParts - 1 - i]});
1662       DstRegs.push_back(DstPart.getReg(0));
1663     }
1664 
1665     MIRBuilder.buildMergeLikeInstr(MI.getOperand(0), DstRegs);
1666 
1667     Observer.changedInstr(MI);
1668     MI.eraseFromParent();
1669     return Legalized;
1670   }
1671   case TargetOpcode::G_PTR_ADD:
1672   case TargetOpcode::G_PTRMASK: {
1673     if (TypeIdx != 1)
1674       return UnableToLegalize;
1675     Observer.changingInstr(MI);
1676     narrowScalarSrc(MI, NarrowTy, 2);
1677     Observer.changedInstr(MI);
1678     return Legalized;
1679   }
1680   case TargetOpcode::G_FPTOUI:
1681   case TargetOpcode::G_FPTOSI:
1682     return narrowScalarFPTOI(MI, TypeIdx, NarrowTy);
1683   case TargetOpcode::G_FPEXT:
1684     if (TypeIdx != 0)
1685       return UnableToLegalize;
1686     Observer.changingInstr(MI);
1687     narrowScalarDst(MI, NarrowTy, 0, TargetOpcode::G_FPEXT);
1688     Observer.changedInstr(MI);
1689     return Legalized;
1690   case TargetOpcode::G_FLDEXP:
1691   case TargetOpcode::G_STRICT_FLDEXP:
1692     return narrowScalarFLDEXP(MI, TypeIdx, NarrowTy);
1693   }
1694 }
1695 
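/// Coerce \p Val to a plain scalar of the same total width: pointers are
/// converted with G_PTRTOINT (integral address spaces only) and vectors are
/// bitcast, converting pointer elements to integers first. Returns a null
/// Register if the coercion is not possible.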
1696 Register LegalizerHelper::coerceToScalar(Register Val) {
1697   LLT Ty = MRI.getType(Val);
1698   if (Ty.isScalar())
1699     return Val;
1700 
1701   const DataLayout &DL = MIRBuilder.getDataLayout();
1702   LLT NewTy = LLT::scalar(Ty.getSizeInBits());
1703   if (Ty.isPointer()) {
1704     if (DL.isNonIntegralAddressSpace(Ty.getAddressSpace()))
1705       return Register();
1706     return MIRBuilder.buildPtrToInt(NewTy, Val).getReg(0);
1707   }
1708 
1709   Register NewVal = Val;
1710 
1711   assert(Ty.isVector());
1712   LLT EltTy = Ty.getElementType();
1713   if (EltTy.isPointer())
1714     NewVal = MIRBuilder.buildPtrToInt(NewTy, NewVal).getReg(0);
1715   return MIRBuilder.buildBitcast(NewTy, NewVal).getReg(0);
1716 }
1717 
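/// Replace source operand \p OpIdx of \p MI with its extension to \p WideTy
/// using \p ExtOpcode.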
1718 void LegalizerHelper::widenScalarSrc(MachineInstr &MI, LLT WideTy,
1719                                      unsigned OpIdx, unsigned ExtOpcode) {
1720   MachineOperand &MO = MI.getOperand(OpIdx);
1721   auto ExtB = MIRBuilder.buildInstr(ExtOpcode, {WideTy}, {MO});
1722   MO.setReg(ExtB.getReg(0));
1723 }
1724 
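/// Replace source operand \p OpIdx of \p MI with its truncation to
/// \p NarrowTy.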
1725 void LegalizerHelper::narrowScalarSrc(MachineInstr &MI, LLT NarrowTy,
1726                                       unsigned OpIdx) {
1727   MachineOperand &MO = MI.getOperand(OpIdx);
1728   auto ExtB = MIRBuilder.buildTrunc(NarrowTy, MO);
1729   MO.setReg(ExtB.getReg(0));
1730 }
1731 
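/// Retarget def operand \p OpIdx of \p MI at a fresh \p WideTy register, then
/// emit \p TruncOpcode right after \p MI to narrow the wide result back into
/// the original destination register.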
1732 void LegalizerHelper::widenScalarDst(MachineInstr &MI, LLT WideTy,
1733                                      unsigned OpIdx, unsigned TruncOpcode) {
1734   MachineOperand &MO = MI.getOperand(OpIdx);
1735   Register DstExt = MRI.createGenericVirtualRegister(WideTy);
1736   MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
1737   MIRBuilder.buildInstr(TruncOpcode, {MO}, {DstExt});
1738   MO.setReg(DstExt);
1739 }
1740 
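/// Retarget def operand \p OpIdx of \p MI at a fresh \p NarrowTy register,
/// then emit \p ExtOpcode right after \p MI to extend the narrow result back
/// into the original destination register.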
1741 void LegalizerHelper::narrowScalarDst(MachineInstr &MI, LLT NarrowTy,
1742                                       unsigned OpIdx, unsigned ExtOpcode) {
1743   MachineOperand &MO = MI.getOperand(OpIdx);
1744   Register DstTrunc = MRI.createGenericVirtualRegister(NarrowTy);
1745   MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
1746   MIRBuilder.buildInstr(ExtOpcode, {MO}, {DstTrunc});
1747   MO.setReg(DstTrunc);
1748 }
1749 
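/// Retarget def operand \p OpIdx of \p MI at a fresh \p WideTy register, then
/// drop the extra trailing elements right after \p MI to rebuild the original
/// destination vector.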
1750 void LegalizerHelper::moreElementsVectorDst(MachineInstr &MI, LLT WideTy,
1751                                             unsigned OpIdx) {
1752   MachineOperand &MO = MI.getOperand(OpIdx);
1753   MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
1754   Register Dst = MO.getReg();
1755   Register DstExt = MRI.createGenericVirtualRegister(WideTy);
1756   MO.setReg(DstExt);
1757   MIRBuilder.buildDeleteTrailingVectorElements(Dst, DstExt);
1758 }
1759 
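/// Replace source operand \p OpIdx of \p MI with a copy of the vector padded
/// out to \p MoreTy with undef elements.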
1760 void LegalizerHelper::moreElementsVectorSrc(MachineInstr &MI, LLT MoreTy,
1761                                             unsigned OpIdx) {
1762   MachineOperand &MO = MI.getOperand(OpIdx);
1764   MO.setReg(MIRBuilder.buildPadVectorWithUndefElements(MoreTy, MO).getReg(0));
1765 }
1766 
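/// Replace source operand \p OpIdx of \p MI with its bitcast to \p CastTy.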
1767 void LegalizerHelper::bitcastSrc(MachineInstr &MI, LLT CastTy, unsigned OpIdx) {
1768   MachineOperand &Op = MI.getOperand(OpIdx);
1769   Op.setReg(MIRBuilder.buildBitcast(CastTy, Op).getReg(0));
1770 }
1771 
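/// Retarget def operand \p OpIdx of \p MI at a fresh \p CastTy register, then
/// emit a G_BITCAST right after \p MI back into the original destination.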
1772 void LegalizerHelper::bitcastDst(MachineInstr &MI, LLT CastTy, unsigned OpIdx) {
1773   MachineOperand &MO = MI.getOperand(OpIdx);
1774   Register CastDst = MRI.createGenericVirtualRegister(CastTy);
1775   MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
1776   MIRBuilder.buildBitcast(MO, CastDst);
1777   MO.setReg(CastDst);
1778 }
1779 
1780 LegalizerHelper::LegalizeResult
1781 LegalizerHelper::widenScalarMergeValues(MachineInstr &MI, unsigned TypeIdx,
1782                                         LLT WideTy) {
1783   if (TypeIdx != 1)
1784     return UnableToLegalize;
1785 
1786   auto [DstReg, DstTy, Src1Reg, Src1Ty] = MI.getFirst2RegLLTs();
1787   if (DstTy.isVector())
1788     return UnableToLegalize;
1789 
1790   LLT SrcTy = MRI.getType(Src1Reg);
1791   const int DstSize = DstTy.getSizeInBits();
1792   const int SrcSize = SrcTy.getSizeInBits();
1793   const int WideSize = WideTy.getSizeInBits();
1794   const int NumMerge = (DstSize + WideSize - 1) / WideSize;
1795 
1796   unsigned NumOps = MI.getNumOperands();
1797   unsigned NumSrc = MI.getNumOperands() - 1;
1798   unsigned PartSize = DstTy.getSizeInBits() / NumSrc;
1799 
1800   if (WideSize >= DstSize) {
1801     // Directly pack the bits in the target type.
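    // e.g. (illustrative) s32 = G_MERGE_VALUES of four s8 parts with WideTy
    // s64: zero-extend each part to s64, shift part I left by 8*I, OR the
    // shifted parts together, and truncate the accumulated s64 down to s32.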
1802     Register ResultReg = MIRBuilder.buildZExt(WideTy, Src1Reg).getReg(0);
1803 
1804     for (unsigned I = 2; I != NumOps; ++I) {
1805       const unsigned Offset = (I - 1) * PartSize;
1806 
1807       Register SrcReg = MI.getOperand(I).getReg();
1808       assert(MRI.getType(SrcReg) == LLT::scalar(PartSize));
1809 
1810       auto ZextInput = MIRBuilder.buildZExt(WideTy, SrcReg);
1811 
1812       Register NextResult = I + 1 == NumOps && WideTy == DstTy ? DstReg :
1813         MRI.createGenericVirtualRegister(WideTy);
1814 
1815       auto ShiftAmt = MIRBuilder.buildConstant(WideTy, Offset);
1816       auto Shl = MIRBuilder.buildShl(WideTy, ZextInput, ShiftAmt);
1817       MIRBuilder.buildOr(NextResult, ResultReg, Shl);
1818       ResultReg = NextResult;
1819     }
1820 
1821     if (WideSize > DstSize)
1822       MIRBuilder.buildTrunc(DstReg, ResultReg);
1823     else if (DstTy.isPointer())
1824       MIRBuilder.buildIntToPtr(DstReg, ResultReg);
1825 
1826     MI.eraseFromParent();
1827     return Legalized;
1828   }
1829 
1830   // Unmerge the original values to the GCD type, and recombine to the next
1831   // multiple greater than the original type.
1832   //
1833   // %3:_(s12) = G_MERGE_VALUES %0:_(s4), %1:_(s4), %2:_(s4) -> s6
1834   // %4:_(s2), %5:_(s2) = G_UNMERGE_VALUES %0
1835   // %6:_(s2), %7:_(s2) = G_UNMERGE_VALUES %1
1836   // %8:_(s2), %9:_(s2) = G_UNMERGE_VALUES %2
1837   // %10:_(s6) = G_MERGE_VALUES %4, %5, %6
1838   // %11:_(s6) = G_MERGE_VALUES %7, %8, %9
1839   // %12:_(s12) = G_MERGE_VALUES %10, %11
1840   //
1841   // Padding with undef if necessary:
1842   //
1843   // %2:_(s8) = G_MERGE_VALUES %0:_(s4), %1:_(s4) -> s6
1844   // %3:_(s2), %4:_(s2) = G_UNMERGE_VALUES %0
1845   // %5:_(s2), %6:_(s2) = G_UNMERGE_VALUES %1
1846   // %7:_(s2) = G_IMPLICIT_DEF
1847   // %8:_(s6) = G_MERGE_VALUES %3, %4, %5
1848   // %9:_(s6) = G_MERGE_VALUES %6, %7, %7
1849   // %10:_(s12) = G_MERGE_VALUES %8, %9
1850 
1851   const int GCD = std::gcd(SrcSize, WideSize);
1852   LLT GCDTy = LLT::scalar(GCD);
1853 
1855   SmallVector<Register, 8> NewMergeRegs;
1856   SmallVector<Register, 8> Unmerges;
1857   LLT WideDstTy = LLT::scalar(NumMerge * WideSize);
1858 
1859   // Decompose the original operands if they don't evenly divide.
1860   for (const MachineOperand &MO : llvm::drop_begin(MI.operands())) {
1861     Register SrcReg = MO.getReg();
1862     if (GCD == SrcSize) {
1863       Unmerges.push_back(SrcReg);
1864     } else {
1865       auto Unmerge = MIRBuilder.buildUnmerge(GCDTy, SrcReg);
1866       for (int J = 0, JE = Unmerge->getNumOperands() - 1; J != JE; ++J)
1867         Unmerges.push_back(Unmerge.getReg(J));
1868     }
1869   }
1870 
1871   // Pad with undef to the next size that is a multiple of the requested size.
1872   if (static_cast<int>(Unmerges.size()) != NumMerge * WideSize) {
1873     Register UndefReg = MIRBuilder.buildUndef(GCDTy).getReg(0);
1874     for (int I = Unmerges.size(); I != NumMerge * WideSize; ++I)
1875       Unmerges.push_back(UndefReg);
1876   }
1877 
1878   const int PartsPerGCD = WideSize / GCD;
1879 
1880   // Build merges of each piece.
1881   ArrayRef<Register> Slicer(Unmerges);
1882   for (int I = 0; I != NumMerge; ++I, Slicer = Slicer.drop_front(PartsPerGCD)) {
1883     auto Merge =
1884         MIRBuilder.buildMergeLikeInstr(WideTy, Slicer.take_front(PartsPerGCD));
1885     NewMergeRegs.push_back(Merge.getReg(0));
1886   }
1887 
1888   // A truncate may be necessary if the requested type doesn't evenly divide the
1889   // original result type.
1890   if (DstTy.getSizeInBits() == WideDstTy.getSizeInBits()) {
1891     MIRBuilder.buildMergeLikeInstr(DstReg, NewMergeRegs);
1892   } else {
1893     auto FinalMerge = MIRBuilder.buildMergeLikeInstr(WideDstTy, NewMergeRegs);
1894     MIRBuilder.buildTrunc(DstReg, FinalMerge.getReg(0));
1895   }
1896 
1897   MI.eraseFromParent();
1898   return Legalized;
1899 }
1900 
1901 LegalizerHelper::LegalizeResult
1902 LegalizerHelper::widenScalarUnmergeValues(MachineInstr &MI, unsigned TypeIdx,
1903                                           LLT WideTy) {
1904   if (TypeIdx != 0)
1905     return UnableToLegalize;
1906 
1907   int NumDst = MI.getNumOperands() - 1;
1908   Register SrcReg = MI.getOperand(NumDst).getReg();
1909   LLT SrcTy = MRI.getType(SrcReg);
1910   if (SrcTy.isVector())
1911     return UnableToLegalize;
1912 
1913   Register Dst0Reg = MI.getOperand(0).getReg();
1914   LLT DstTy = MRI.getType(Dst0Reg);
1915   if (!DstTy.isScalar())
1916     return UnableToLegalize;
1917 
1918   if (WideTy.getSizeInBits() >= SrcTy.getSizeInBits()) {
1919     if (SrcTy.isPointer()) {
1920       const DataLayout &DL = MIRBuilder.getDataLayout();
1921       if (DL.isNonIntegralAddressSpace(SrcTy.getAddressSpace())) {
1922         LLVM_DEBUG(
1923             dbgs() << "Not casting non-integral address space integer\n");
1924         return UnableToLegalize;
1925       }
1926 
1927       SrcTy = LLT::scalar(SrcTy.getSizeInBits());
1928       SrcReg = MIRBuilder.buildPtrToInt(SrcTy, SrcReg).getReg(0);
1929     }
1930 
1931     // Widen SrcTy to WideTy. This does not affect the result, but since the
1932     // user requested this size, it is probably better handled than SrcTy and
1933     // should reduce the total number of legalization artifacts.
1934     if (WideTy.getSizeInBits() > SrcTy.getSizeInBits()) {
1935       SrcTy = WideTy;
1936       SrcReg = MIRBuilder.buildAnyExt(WideTy, SrcReg).getReg(0);
1937     }
1938 
1939     // There's no unmerge type to target. Directly extract the bits from the
1940     // source type.
1941     unsigned DstSize = DstTy.getSizeInBits();
1942 
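    // e.g. (illustrative) s8 %a, %b = G_UNMERGE_VALUES %x:_(s16) with WideTy
    // s32: anyext %x to s32, then %a = trunc %x and %b = trunc (lshr %x, 8).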
1943     MIRBuilder.buildTrunc(Dst0Reg, SrcReg);
1944     for (int I = 1; I != NumDst; ++I) {
1945       auto ShiftAmt = MIRBuilder.buildConstant(SrcTy, DstSize * I);
1946       auto Shr = MIRBuilder.buildLShr(SrcTy, SrcReg, ShiftAmt);
1947       MIRBuilder.buildTrunc(MI.getOperand(I), Shr);
1948     }
1949 
1950     MI.eraseFromParent();
1951     return Legalized;
1952   }
1953 
1954   // Extend the source to a wider type.
1955   LLT LCMTy = getLCMType(SrcTy, WideTy);
1956 
1957   Register WideSrc = SrcReg;
1958   if (LCMTy.getSizeInBits() != SrcTy.getSizeInBits()) {
1959     // TODO: If this is an integral address space, cast to integer and anyext.
1960     if (SrcTy.isPointer()) {
1961       LLVM_DEBUG(dbgs() << "Widening pointer source types not implemented\n");
1962       return UnableToLegalize;
1963     }
1964 
1965     WideSrc = MIRBuilder.buildAnyExt(LCMTy, WideSrc).getReg(0);
1966   }
1967 
1968   auto Unmerge = MIRBuilder.buildUnmerge(WideTy, WideSrc);
1969 
1970   // Create a sequence of unmerges and merges to the original results. Since we
1971   // may have widened the source, we will need to pad the results with dead defs
1972   // to cover the source register.
1973   // e.g. widen s48 to s64:
1974   // %1:_(s48), %2:_(s48) = G_UNMERGE_VALUES %0:_(s96)
1975   //
1976   // =>
1977   //  %4:_(s192) = G_ANYEXT %0:_(s96)
1978   //  %5:_(s64), %6, %7 = G_UNMERGE_VALUES %4 ; Requested unmerge
1979   //  ; unpack to GCD type, with extra dead defs
1980   //  %8:_(s16), %9, %10, %11 = G_UNMERGE_VALUES %5:_(s64)
1981   //  %12:_(s16), %13, dead %14, dead %15 = G_UNMERGE_VALUES %6:_(s64)
1982   //  dead %16:_(s16), dead %17, dead %18, dead %19 = G_UNMERGE_VALUES %7:_(s64)
1983   //  %1:_(s48) = G_MERGE_VALUES %8:_(s16), %9, %10   ; Remerge to destination
1984   //  %2:_(s48) = G_MERGE_VALUES %11:_(s16), %12, %13 ; Remerge to destination
1985   const LLT GCDTy = getGCDType(WideTy, DstTy);
1986   const int NumUnmerge = Unmerge->getNumOperands() - 1;
1987   const int PartsPerRemerge = DstTy.getSizeInBits() / GCDTy.getSizeInBits();
1988 
1989   // Directly unmerge to the destination without going through a GCD type
1990   // if possible
1991   if (PartsPerRemerge == 1) {
1992     const int PartsPerUnmerge = WideTy.getSizeInBits() / DstTy.getSizeInBits();
1993 
1994     for (int I = 0; I != NumUnmerge; ++I) {
1995       auto MIB = MIRBuilder.buildInstr(TargetOpcode::G_UNMERGE_VALUES);
1996 
1997       for (int J = 0; J != PartsPerUnmerge; ++J) {
1998         int Idx = I * PartsPerUnmerge + J;
1999         if (Idx < NumDst)
2000           MIB.addDef(MI.getOperand(Idx).getReg());
2001         else {
2002           // Create dead def for excess components.
2003           MIB.addDef(MRI.createGenericVirtualRegister(DstTy));
2004         }
2005       }
2006 
2007       MIB.addUse(Unmerge.getReg(I));
2008     }
2009   } else {
2010     SmallVector<Register, 16> Parts;
2011     for (int J = 0; J != NumUnmerge; ++J)
2012       extractGCDType(Parts, GCDTy, Unmerge.getReg(J));
2013 
2014     SmallVector<Register, 8> RemergeParts;
2015     for (int I = 0; I != NumDst; ++I) {
2016       for (int J = 0; J < PartsPerRemerge; ++J) {
2017         const int Idx = I * PartsPerRemerge + J;
2018         RemergeParts.emplace_back(Parts[Idx]);
2019       }
2020 
2021       MIRBuilder.buildMergeLikeInstr(MI.getOperand(I).getReg(), RemergeParts);
2022       RemergeParts.clear();
2023     }
2024   }
2025 
2026   MI.eraseFromParent();
2027   return Legalized;
2028 }
2029 
2030 LegalizerHelper::LegalizeResult
2031 LegalizerHelper::widenScalarExtract(MachineInstr &MI, unsigned TypeIdx,
2032                                     LLT WideTy) {
2033   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
2034   unsigned Offset = MI.getOperand(2).getImm();
2035 
2036   if (TypeIdx == 0) {
2037     if (SrcTy.isVector() || DstTy.isVector())
2038       return UnableToLegalize;
2039 
2040     SrcOp Src(SrcReg);
2041     if (SrcTy.isPointer()) {
2042       // Extracts from pointers can be handled only if they are really just
2043       // simple integers.
2044       const DataLayout &DL = MIRBuilder.getDataLayout();
2045       if (DL.isNonIntegralAddressSpace(SrcTy.getAddressSpace()))
2046         return UnableToLegalize;
2047 
2048       LLT SrcAsIntTy = LLT::scalar(SrcTy.getSizeInBits());
2049       Src = MIRBuilder.buildPtrToInt(SrcAsIntTy, Src);
2050       SrcTy = SrcAsIntTy;
2051     }
2052 
2053     if (DstTy.isPointer())
2054       return UnableToLegalize;
2055 
2056     if (Offset == 0) {
2057       // Avoid a shift in the degenerate case.
2058       MIRBuilder.buildTrunc(DstReg,
2059                             MIRBuilder.buildAnyExtOrTrunc(WideTy, Src));
2060       MI.eraseFromParent();
2061       return Legalized;
2062     }
2063 
2064     // Do a shift in the source type.
2065     LLT ShiftTy = SrcTy;
2066     if (WideTy.getSizeInBits() > SrcTy.getSizeInBits()) {
2067       Src = MIRBuilder.buildAnyExt(WideTy, Src);
2068       ShiftTy = WideTy;
2069     }
2070 
2071     auto LShr = MIRBuilder.buildLShr(
2072       ShiftTy, Src, MIRBuilder.buildConstant(ShiftTy, Offset));
2073     MIRBuilder.buildTrunc(DstReg, LShr);
2074     MI.eraseFromParent();
2075     return Legalized;
2076   }
2077 
2078   if (SrcTy.isScalar()) {
2079     Observer.changingInstr(MI);
2080     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2081     Observer.changedInstr(MI);
2082     return Legalized;
2083   }
2084 
2085   if (!SrcTy.isVector())
2086     return UnableToLegalize;
2087 
2088   if (DstTy != SrcTy.getElementType())
2089     return UnableToLegalize;
2090 
2091   if (Offset % SrcTy.getScalarSizeInBits() != 0)
2092     return UnableToLegalize;
2093 
2094   Observer.changingInstr(MI);
2095   widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2096 
2097   MI.getOperand(2).setImm((WideTy.getSizeInBits() / SrcTy.getSizeInBits()) *
2098                           Offset);
2099   widenScalarDst(MI, WideTy.getScalarType(), 0);
2100   Observer.changedInstr(MI);
2101   return Legalized;
2102 }
2103 
2104 LegalizerHelper::LegalizeResult
2105 LegalizerHelper::widenScalarInsert(MachineInstr &MI, unsigned TypeIdx,
2106                                    LLT WideTy) {
2107   if (TypeIdx != 0 || WideTy.isVector())
2108     return UnableToLegalize;
2109   Observer.changingInstr(MI);
2110   widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2111   widenScalarDst(MI, WideTy);
2112   Observer.changedInstr(MI);
2113   return Legalized;
2114 }
2115 
2116 LegalizerHelper::LegalizeResult
2117 LegalizerHelper::widenScalarAddSubOverflow(MachineInstr &MI, unsigned TypeIdx,
2118                                            LLT WideTy) {
2119   unsigned Opcode;
2120   unsigned ExtOpcode;
2121   std::optional<Register> CarryIn;
2122   switch (MI.getOpcode()) {
2123   default:
2124     llvm_unreachable("Unexpected opcode!");
2125   case TargetOpcode::G_SADDO:
2126     Opcode = TargetOpcode::G_ADD;
2127     ExtOpcode = TargetOpcode::G_SEXT;
2128     break;
2129   case TargetOpcode::G_SSUBO:
2130     Opcode = TargetOpcode::G_SUB;
2131     ExtOpcode = TargetOpcode::G_SEXT;
2132     break;
2133   case TargetOpcode::G_UADDO:
2134     Opcode = TargetOpcode::G_ADD;
2135     ExtOpcode = TargetOpcode::G_ZEXT;
2136     break;
2137   case TargetOpcode::G_USUBO:
2138     Opcode = TargetOpcode::G_SUB;
2139     ExtOpcode = TargetOpcode::G_ZEXT;
2140     break;
2141   case TargetOpcode::G_SADDE:
2142     Opcode = TargetOpcode::G_UADDE;
2143     ExtOpcode = TargetOpcode::G_SEXT;
2144     CarryIn = MI.getOperand(4).getReg();
2145     break;
2146   case TargetOpcode::G_SSUBE:
2147     Opcode = TargetOpcode::G_USUBE;
2148     ExtOpcode = TargetOpcode::G_SEXT;
2149     CarryIn = MI.getOperand(4).getReg();
2150     break;
2151   case TargetOpcode::G_UADDE:
2152     Opcode = TargetOpcode::G_UADDE;
2153     ExtOpcode = TargetOpcode::G_ZEXT;
2154     CarryIn = MI.getOperand(4).getReg();
2155     break;
2156   case TargetOpcode::G_USUBE:
2157     Opcode = TargetOpcode::G_USUBE;
2158     ExtOpcode = TargetOpcode::G_ZEXT;
2159     CarryIn = MI.getOperand(4).getReg();
2160     break;
2161   }
2162 
2163   if (TypeIdx == 1) {
2164     unsigned BoolExtOp = MIRBuilder.getBoolExtOp(WideTy.isVector(), false);
2165 
2166     Observer.changingInstr(MI);
2167     if (CarryIn)
2168       widenScalarSrc(MI, WideTy, 4, BoolExtOp);
2169     widenScalarDst(MI, WideTy, 1);
2170 
2171     Observer.changedInstr(MI);
2172     return Legalized;
2173   }
2174 
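  // e.g. (illustrative) widening s8 G_UADDO to s32: zero-extend both inputs,
  // add at s32, and report overflow when the wide sum differs from the
  // zero-extension of its low 8 bits.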
2175   auto LHSExt = MIRBuilder.buildInstr(ExtOpcode, {WideTy}, {MI.getOperand(2)});
2176   auto RHSExt = MIRBuilder.buildInstr(ExtOpcode, {WideTy}, {MI.getOperand(3)});
2177   // Do the arithmetic in the larger type.
2178   Register NewOp;
2179   if (CarryIn) {
2180     LLT CarryOutTy = MRI.getType(MI.getOperand(1).getReg());
2181     NewOp = MIRBuilder
2182                 .buildInstr(Opcode, {WideTy, CarryOutTy},
2183                             {LHSExt, RHSExt, *CarryIn})
2184                 .getReg(0);
2185   } else {
2186     NewOp = MIRBuilder.buildInstr(Opcode, {WideTy}, {LHSExt, RHSExt}).getReg(0);
2187   }
2188   LLT OrigTy = MRI.getType(MI.getOperand(0).getReg());
2189   auto TruncOp = MIRBuilder.buildTrunc(OrigTy, NewOp);
2190   auto ExtOp = MIRBuilder.buildInstr(ExtOpcode, {WideTy}, {TruncOp});
2191   // There is no overflow if the ExtOp is the same as NewOp.
2192   MIRBuilder.buildICmp(CmpInst::ICMP_NE, MI.getOperand(1), NewOp, ExtOp);
2193   // Now trunc the NewOp to the original result.
2194   MIRBuilder.buildTrunc(MI.getOperand(0), NewOp);
2195   MI.eraseFromParent();
2196   return Legalized;
2197 }
2198 
2199 LegalizerHelper::LegalizeResult
2200 LegalizerHelper::widenScalarAddSubShlSat(MachineInstr &MI, unsigned TypeIdx,
2201                                          LLT WideTy) {
2202   bool IsSigned = MI.getOpcode() == TargetOpcode::G_SADDSAT ||
2203                   MI.getOpcode() == TargetOpcode::G_SSUBSAT ||
2204                   MI.getOpcode() == TargetOpcode::G_SSHLSAT;
2205   bool IsShift = MI.getOpcode() == TargetOpcode::G_SSHLSAT ||
2206                  MI.getOpcode() == TargetOpcode::G_USHLSAT;
2207   // We can convert this to:
2208   //   1. Any extend iN to iM
2209   //   2. SHL by M-N
2210   //   3. [US][ADD|SUB|SHL]SAT
2211   //   4. L/ASHR by M-N
2212   //
2213   // It may be more efficient to lower this to a min and a max operation in
2214   // the higher precision arithmetic if the promoted operation isn't legal,
2215   // but this decision is up to the target's lowering request.
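  // e.g. (illustrative) widening s8 G_SADDSAT to s32: shift both operands
  // left by 24 so saturation happens at the top of the s32 value, perform
  // G_SADDSAT at s32, then arithmetic-shift right by 24 and truncate.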
2216   Register DstReg = MI.getOperand(0).getReg();
2217 
2218   unsigned NewBits = WideTy.getScalarSizeInBits();
2219   unsigned SHLAmount = NewBits - MRI.getType(DstReg).getScalarSizeInBits();
2220 
2221   // Shifts must zero-extend the RHS to preserve the unsigned quantity, and
2222   // must not left shift the RHS to preserve the shift amount.
2223   auto LHS = MIRBuilder.buildAnyExt(WideTy, MI.getOperand(1));
2224   auto RHS = IsShift ? MIRBuilder.buildZExt(WideTy, MI.getOperand(2))
2225                      : MIRBuilder.buildAnyExt(WideTy, MI.getOperand(2));
2226   auto ShiftK = MIRBuilder.buildConstant(WideTy, SHLAmount);
2227   auto ShiftL = MIRBuilder.buildShl(WideTy, LHS, ShiftK);
2228   auto ShiftR = IsShift ? RHS : MIRBuilder.buildShl(WideTy, RHS, ShiftK);
2229 
2230   auto WideInst = MIRBuilder.buildInstr(MI.getOpcode(), {WideTy},
2231                                         {ShiftL, ShiftR}, MI.getFlags());
2232 
2233   // Use a shift that will preserve the number of sign bits when the trunc is
2234   // folded away.
2235   auto Result = IsSigned ? MIRBuilder.buildAShr(WideTy, WideInst, ShiftK)
2236                          : MIRBuilder.buildLShr(WideTy, WideInst, ShiftK);
2237 
2238   MIRBuilder.buildTrunc(DstReg, Result);
2239   MI.eraseFromParent();
2240   return Legalized;
2241 }
2242 
2243 LegalizerHelper::LegalizeResult
2244 LegalizerHelper::widenScalarMulo(MachineInstr &MI, unsigned TypeIdx,
2245                                  LLT WideTy) {
2246   if (TypeIdx == 1) {
2247     Observer.changingInstr(MI);
2248     widenScalarDst(MI, WideTy, 1);
2249     Observer.changedInstr(MI);
2250     return Legalized;
2251   }
2252 
2253   bool IsSigned = MI.getOpcode() == TargetOpcode::G_SMULO;
2254   auto [Result, OriginalOverflow, LHS, RHS] = MI.getFirst4Regs();
2255   LLT SrcTy = MRI.getType(LHS);
2256   LLT OverflowTy = MRI.getType(OriginalOverflow);
2257   unsigned SrcBitWidth = SrcTy.getScalarSizeInBits();
2258 
2259   // To determine if the result overflowed in the larger type, we extend the
2260   // input to the larger type, do the multiply (checking if it overflows),
2261   // then also check the high bits of the result to see if overflow happened
2262   // there.
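  // e.g. (illustrative) widening s16 G_UMULO to s32: a 16 x 16 -> 32 bit
  // multiply cannot overflow at s32, so a plain G_MUL suffices and overflow
  // is simply "the high half of the product is nonzero".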
2263   unsigned ExtOp = IsSigned ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
2264   auto LeftOperand = MIRBuilder.buildInstr(ExtOp, {WideTy}, {LHS});
2265   auto RightOperand = MIRBuilder.buildInstr(ExtOp, {WideTy}, {RHS});
2266 
2267   // Multiplication cannot overflow if the WideTy is >= 2 * original width,
2268   // so we don't need to check the overflow result of the wider-type multiply.
2269   bool WideMulCanOverflow = WideTy.getScalarSizeInBits() < 2 * SrcBitWidth;
2270 
2271   unsigned MulOpc =
2272       WideMulCanOverflow ? MI.getOpcode() : (unsigned)TargetOpcode::G_MUL;
2273 
2274   MachineInstrBuilder Mulo;
2275   if (WideMulCanOverflow)
2276     Mulo = MIRBuilder.buildInstr(MulOpc, {WideTy, OverflowTy},
2277                                  {LeftOperand, RightOperand});
2278   else
2279     Mulo = MIRBuilder.buildInstr(MulOpc, {WideTy}, {LeftOperand, RightOperand});
2280 
2281   auto Mul = Mulo->getOperand(0);
2282   MIRBuilder.buildTrunc(Result, Mul);
2283 
2284   MachineInstrBuilder ExtResult;
2285   // Overflow occurred if it occurred in the larger type, or if the high part
2286   // of the result does not zero/sign-extend the low part.  Check this second
2287   // possibility first.
2288   if (IsSigned) {
2289     // For signed, overflow occurred when the high part does not sign-extend
2290     // the low part.
2291     ExtResult = MIRBuilder.buildSExtInReg(WideTy, Mul, SrcBitWidth);
2292   } else {
2293     // Unsigned overflow occurred when the high part does not zero-extend the
2294     // low part.
2295     ExtResult = MIRBuilder.buildZExtInReg(WideTy, Mul, SrcBitWidth);
2296   }
2297 
2298   if (WideMulCanOverflow) {
2299     auto Overflow =
2300         MIRBuilder.buildICmp(CmpInst::ICMP_NE, OverflowTy, Mul, ExtResult);
2301     // Finally check if the multiplication in the larger type itself overflowed.
2302     MIRBuilder.buildOr(OriginalOverflow, Mulo->getOperand(1), Overflow);
2303   } else {
2304     MIRBuilder.buildICmp(CmpInst::ICMP_NE, OriginalOverflow, Mul, ExtResult);
2305   }
2306   MI.eraseFromParent();
2307   return Legalized;
2308 }
2309 
2310 LegalizerHelper::LegalizeResult
2311 LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
2312   switch (MI.getOpcode()) {
2313   default:
2314     return UnableToLegalize;
2315   case TargetOpcode::G_ATOMICRMW_XCHG:
2316   case TargetOpcode::G_ATOMICRMW_ADD:
2317   case TargetOpcode::G_ATOMICRMW_SUB:
2318   case TargetOpcode::G_ATOMICRMW_AND:
2319   case TargetOpcode::G_ATOMICRMW_OR:
2320   case TargetOpcode::G_ATOMICRMW_XOR:
2321   case TargetOpcode::G_ATOMICRMW_MIN:
2322   case TargetOpcode::G_ATOMICRMW_MAX:
2323   case TargetOpcode::G_ATOMICRMW_UMIN:
2324   case TargetOpcode::G_ATOMICRMW_UMAX:
2325     assert(TypeIdx == 0 && "atomicrmw with second scalar type");
2326     Observer.changingInstr(MI);
2327     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
2328     widenScalarDst(MI, WideTy, 0);
2329     Observer.changedInstr(MI);
2330     return Legalized;
2331   case TargetOpcode::G_ATOMIC_CMPXCHG:
2332     assert(TypeIdx == 0 && "G_ATOMIC_CMPXCHG with second scalar type");
2333     Observer.changingInstr(MI);
2334     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
2335     widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ANYEXT);
2336     widenScalarDst(MI, WideTy, 0);
2337     Observer.changedInstr(MI);
2338     return Legalized;
2339   case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS:
2340     if (TypeIdx == 0) {
2341       Observer.changingInstr(MI);
2342       widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ANYEXT);
2343       widenScalarSrc(MI, WideTy, 4, TargetOpcode::G_ANYEXT);
2344       widenScalarDst(MI, WideTy, 0);
2345       Observer.changedInstr(MI);
2346       return Legalized;
2347     }
2348     assert(TypeIdx == 1 &&
2349            "G_ATOMIC_CMPXCHG_WITH_SUCCESS with third scalar type");
2350     Observer.changingInstr(MI);
2351     widenScalarDst(MI, WideTy, 1);
2352     Observer.changedInstr(MI);
2353     return Legalized;
2354   case TargetOpcode::G_EXTRACT:
2355     return widenScalarExtract(MI, TypeIdx, WideTy);
2356   case TargetOpcode::G_INSERT:
2357     return widenScalarInsert(MI, TypeIdx, WideTy);
2358   case TargetOpcode::G_MERGE_VALUES:
2359     return widenScalarMergeValues(MI, TypeIdx, WideTy);
2360   case TargetOpcode::G_UNMERGE_VALUES:
2361     return widenScalarUnmergeValues(MI, TypeIdx, WideTy);
2362   case TargetOpcode::G_SADDO:
2363   case TargetOpcode::G_SSUBO:
2364   case TargetOpcode::G_UADDO:
2365   case TargetOpcode::G_USUBO:
2366   case TargetOpcode::G_SADDE:
2367   case TargetOpcode::G_SSUBE:
2368   case TargetOpcode::G_UADDE:
2369   case TargetOpcode::G_USUBE:
2370     return widenScalarAddSubOverflow(MI, TypeIdx, WideTy);
2371   case TargetOpcode::G_UMULO:
2372   case TargetOpcode::G_SMULO:
2373     return widenScalarMulo(MI, TypeIdx, WideTy);
2374   case TargetOpcode::G_SADDSAT:
2375   case TargetOpcode::G_SSUBSAT:
2376   case TargetOpcode::G_SSHLSAT:
2377   case TargetOpcode::G_UADDSAT:
2378   case TargetOpcode::G_USUBSAT:
2379   case TargetOpcode::G_USHLSAT:
2380     return widenScalarAddSubShlSat(MI, TypeIdx, WideTy);
2381   case TargetOpcode::G_CTTZ:
2382   case TargetOpcode::G_CTTZ_ZERO_UNDEF:
2383   case TargetOpcode::G_CTLZ:
2384   case TargetOpcode::G_CTLZ_ZERO_UNDEF:
2385   case TargetOpcode::G_CTPOP: {
2386     if (TypeIdx == 0) {
2387       Observer.changingInstr(MI);
2388       widenScalarDst(MI, WideTy, 0);
2389       Observer.changedInstr(MI);
2390       return Legalized;
2391     }
2392 
2393     Register SrcReg = MI.getOperand(1).getReg();
2394 
2395     // First extend the input.
2396     unsigned ExtOpc = MI.getOpcode() == TargetOpcode::G_CTTZ ||
2397                               MI.getOpcode() == TargetOpcode::G_CTTZ_ZERO_UNDEF
2398                           ? TargetOpcode::G_ANYEXT
2399                           : TargetOpcode::G_ZEXT;
2400     auto MIBSrc = MIRBuilder.buildInstr(ExtOpc, {WideTy}, {SrcReg});
2401     LLT CurTy = MRI.getType(SrcReg);
2402     unsigned NewOpc = MI.getOpcode();
2403     if (NewOpc == TargetOpcode::G_CTTZ) {
2404       // The count is the same in the larger type except if the original
2405       // value was zero.  This can be handled by setting the bit just off
2406       // the top of the original type.
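      // e.g. (illustrative) widening s8 G_CTTZ to s32: OR in bit 8 so that a
      // zero s8 input still yields the expected count of 8.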
2407       auto TopBit =
2408           APInt::getOneBitSet(WideTy.getSizeInBits(), CurTy.getSizeInBits());
2409       MIBSrc = MIRBuilder.buildOr(
2410         WideTy, MIBSrc, MIRBuilder.buildConstant(WideTy, TopBit));
2411       // Now we know the operand is non-zero, use the more relaxed opcode.
2412       NewOpc = TargetOpcode::G_CTTZ_ZERO_UNDEF;
2413     }
2414 
2415     // Perform the operation at the larger size.
2416     auto MIBNewOp = MIRBuilder.buildInstr(NewOpc, {WideTy}, {MIBSrc});
2417     // This is already the correct result for CTPOP and CTTZs
2418     if (MI.getOpcode() == TargetOpcode::G_CTLZ ||
2419         MI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF) {
2420       // The correct result is NewOp - (Difference in widety and current ty).
2421       unsigned SizeDiff = WideTy.getSizeInBits() - CurTy.getSizeInBits();
2422       MIBNewOp = MIRBuilder.buildSub(
2423           WideTy, MIBNewOp, MIRBuilder.buildConstant(WideTy, SizeDiff));
2424     }
2425 
2426     MIRBuilder.buildZExtOrTrunc(MI.getOperand(0), MIBNewOp);
2427     MI.eraseFromParent();
2428     return Legalized;
2429   }
2430   case TargetOpcode::G_BSWAP: {
2431     Observer.changingInstr(MI);
2432     Register DstReg = MI.getOperand(0).getReg();
2433 
2434     Register ShrReg = MRI.createGenericVirtualRegister(WideTy);
2435     Register DstExt = MRI.createGenericVirtualRegister(WideTy);
2436     Register ShiftAmtReg = MRI.createGenericVirtualRegister(WideTy);
2437     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2438 
2439     MI.getOperand(0).setReg(DstExt);
2440 
2441     MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
2442 
2443     LLT Ty = MRI.getType(DstReg);
2444     unsigned DiffBits = WideTy.getScalarSizeInBits() - Ty.getScalarSizeInBits();
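    // After byte-swapping at the wider width, the original bytes sit in the
    // high DiffBits of the result; shift them back down before truncating.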
2445     MIRBuilder.buildConstant(ShiftAmtReg, DiffBits);
2446     MIRBuilder.buildLShr(ShrReg, DstExt, ShiftAmtReg);
2447 
2448     MIRBuilder.buildTrunc(DstReg, ShrReg);
2449     Observer.changedInstr(MI);
2450     return Legalized;
2451   }
2452   case TargetOpcode::G_BITREVERSE: {
2453     Observer.changingInstr(MI);
2454 
2455     Register DstReg = MI.getOperand(0).getReg();
2456     LLT Ty = MRI.getType(DstReg);
2457     unsigned DiffBits = WideTy.getScalarSizeInBits() - Ty.getScalarSizeInBits();
2458 
2459     Register DstExt = MRI.createGenericVirtualRegister(WideTy);
2460     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2461     MI.getOperand(0).setReg(DstExt);
2462     MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
2463 
2464     auto ShiftAmt = MIRBuilder.buildConstant(WideTy, DiffBits);
2465     auto Shift = MIRBuilder.buildLShr(WideTy, DstExt, ShiftAmt);
2466     MIRBuilder.buildTrunc(DstReg, Shift);
2467     Observer.changedInstr(MI);
2468     return Legalized;
2469   }
2470   case TargetOpcode::G_FREEZE:
2471     Observer.changingInstr(MI);
2472     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2473     widenScalarDst(MI, WideTy);
2474     Observer.changedInstr(MI);
2475     return Legalized;
2476 
2477   case TargetOpcode::G_ABS:
2478     Observer.changingInstr(MI);
2479     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_SEXT);
2480     widenScalarDst(MI, WideTy);
2481     Observer.changedInstr(MI);
2482     return Legalized;
2483 
2484   case TargetOpcode::G_ADD:
2485   case TargetOpcode::G_AND:
2486   case TargetOpcode::G_MUL:
2487   case TargetOpcode::G_OR:
2488   case TargetOpcode::G_XOR:
2489   case TargetOpcode::G_SUB:
2490     // Perform operation at larger width (any extension is fine here, high bits
2491     // don't affect the result) and then truncate the result back to the
2492     // original type.
2493     Observer.changingInstr(MI);
2494     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2495     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
2496     widenScalarDst(MI, WideTy);
2497     Observer.changedInstr(MI);
2498     return Legalized;
2499 
2500   case TargetOpcode::G_SBFX:
2501   case TargetOpcode::G_UBFX:
2502     Observer.changingInstr(MI);
2503 
2504     if (TypeIdx == 0) {
2505       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2506       widenScalarDst(MI, WideTy);
2507     } else {
2508       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
2509       widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ZEXT);
2510     }
2511 
2512     Observer.changedInstr(MI);
2513     return Legalized;
2514 
2515   case TargetOpcode::G_SHL:
2516     Observer.changingInstr(MI);
2517 
2518     if (TypeIdx == 0) {
2519       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2520       widenScalarDst(MI, WideTy);
2521     } else {
2522       assert(TypeIdx == 1);
2523       // The "number of bits to shift" operand must preserve its value as an
2524       // unsigned integer:
2525       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
2526     }
2527 
2528     Observer.changedInstr(MI);
2529     return Legalized;
2530 
2531   case TargetOpcode::G_ROTR:
2532   case TargetOpcode::G_ROTL:
2533     if (TypeIdx != 1)
2534       return UnableToLegalize;
2535 
2536     Observer.changingInstr(MI);
2537     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
2538     Observer.changedInstr(MI);
2539     return Legalized;
2540 
2541   case TargetOpcode::G_SDIV:
2542   case TargetOpcode::G_SREM:
2543   case TargetOpcode::G_SMIN:
2544   case TargetOpcode::G_SMAX:
2545     Observer.changingInstr(MI);
2546     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_SEXT);
2547     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
2548     widenScalarDst(MI, WideTy);
2549     Observer.changedInstr(MI);
2550     return Legalized;
2551 
2552   case TargetOpcode::G_SDIVREM:
2553     Observer.changingInstr(MI);
2554     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
2555     widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_SEXT);
2556     widenScalarDst(MI, WideTy);
2557     widenScalarDst(MI, WideTy, 1);
2558     Observer.changedInstr(MI);
2559     return Legalized;
2560 
2561   case TargetOpcode::G_ASHR:
2562   case TargetOpcode::G_LSHR:
2563     Observer.changingInstr(MI);
2564 
2565     if (TypeIdx == 0) {
2566       unsigned CvtOp = MI.getOpcode() == TargetOpcode::G_ASHR ?
2567         TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
2568 
2569       widenScalarSrc(MI, WideTy, 1, CvtOp);
2570       widenScalarDst(MI, WideTy);
2571     } else {
2572       assert(TypeIdx == 1);
2573       // The "number of bits to shift" operand must preserve its value as an
2574       // unsigned integer:
2575       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
2576     }
2577 
2578     Observer.changedInstr(MI);
2579     return Legalized;
2580   case TargetOpcode::G_UDIV:
2581   case TargetOpcode::G_UREM:
2582   case TargetOpcode::G_UMIN:
2583   case TargetOpcode::G_UMAX:
2584     Observer.changingInstr(MI);
2585     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ZEXT);
2586     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
2587     widenScalarDst(MI, WideTy);
2588     Observer.changedInstr(MI);
2589     return Legalized;
2590 
2591   case TargetOpcode::G_UDIVREM:
2592     Observer.changingInstr(MI);
2593     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
2594     widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ZEXT);
2595     widenScalarDst(MI, WideTy);
2596     widenScalarDst(MI, WideTy, 1);
2597     Observer.changedInstr(MI);
2598     return Legalized;
2599 
2600   case TargetOpcode::G_SELECT:
2601     Observer.changingInstr(MI);
2602     if (TypeIdx == 0) {
2603       // Perform operation at larger width (any extension is fine here, high
2604       // bits don't affect the result) and then truncate the result back to the
2605       // original type.
2606       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
2607       widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ANYEXT);
2608       widenScalarDst(MI, WideTy);
2609     } else {
2610       bool IsVec = MRI.getType(MI.getOperand(1).getReg()).isVector();
2611       // Explicit extension is required here since high bits affect the result.
2612       widenScalarSrc(MI, WideTy, 1, MIRBuilder.getBoolExtOp(IsVec, false));
2613     }
2614     Observer.changedInstr(MI);
2615     return Legalized;
2616 
2617   case TargetOpcode::G_FPTOSI:
2618   case TargetOpcode::G_FPTOUI:
2619   case TargetOpcode::G_IS_FPCLASS:
2620     Observer.changingInstr(MI);
2621 
2622     if (TypeIdx == 0)
2623       widenScalarDst(MI, WideTy);
2624     else
2625       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_FPEXT);
2626 
2627     Observer.changedInstr(MI);
2628     return Legalized;
2629   case TargetOpcode::G_SITOFP:
2630     Observer.changingInstr(MI);
2631 
2632     if (TypeIdx == 0)
2633       widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
2634     else
2635       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_SEXT);
2636 
2637     Observer.changedInstr(MI);
2638     return Legalized;
2639   case TargetOpcode::G_UITOFP:
2640     Observer.changingInstr(MI);
2641 
2642     if (TypeIdx == 0)
2643       widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
2644     else
2645       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ZEXT);
2646 
2647     Observer.changedInstr(MI);
2648     return Legalized;
2649   case TargetOpcode::G_LOAD:
2650   case TargetOpcode::G_SEXTLOAD:
2651   case TargetOpcode::G_ZEXTLOAD:
2652     Observer.changingInstr(MI);
2653     widenScalarDst(MI, WideTy);
2654     Observer.changedInstr(MI);
2655     return Legalized;
2656 
2657   case TargetOpcode::G_STORE: {
2658     if (TypeIdx != 0)
2659       return UnableToLegalize;
2660 
2661     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
2662     if (!Ty.isScalar())
2663       return UnableToLegalize;
2664 
2665     Observer.changingInstr(MI);
2666 
2667     unsigned ExtType = Ty.getScalarSizeInBits() == 1 ?
2668       TargetOpcode::G_ZEXT : TargetOpcode::G_ANYEXT;
2669     widenScalarSrc(MI, WideTy, 0, ExtType);
2670 
2671     Observer.changedInstr(MI);
2672     return Legalized;
2673   }
2674   case TargetOpcode::G_CONSTANT: {
2675     MachineOperand &SrcMO = MI.getOperand(1);
2676     LLVMContext &Ctx = MIRBuilder.getMF().getFunction().getContext();
2677     unsigned ExtOpc = LI.getExtOpcodeForWideningConstant(
2678         MRI.getType(MI.getOperand(0).getReg()));
2679     assert((ExtOpc == TargetOpcode::G_ZEXT || ExtOpc == TargetOpcode::G_SEXT ||
2680             ExtOpc == TargetOpcode::G_ANYEXT) &&
2681            "Illegal Extend");
2682     const APInt &SrcVal = SrcMO.getCImm()->getValue();
2683     const APInt &Val = (ExtOpc == TargetOpcode::G_SEXT)
2684                            ? SrcVal.sext(WideTy.getSizeInBits())
2685                            : SrcVal.zext(WideTy.getSizeInBits());
2686     Observer.changingInstr(MI);
2687     SrcMO.setCImm(ConstantInt::get(Ctx, Val));
2688 
2689     widenScalarDst(MI, WideTy);
2690     Observer.changedInstr(MI);
2691     return Legalized;
2692   }
2693   case TargetOpcode::G_FCONSTANT: {
2694     // To avoid changing the bits of the constant due to extension to a larger
2695     // type and then using G_FPTRUNC, we simply convert to a G_CONSTANT.
2696     MachineOperand &SrcMO = MI.getOperand(1);
2697     APInt Val = SrcMO.getFPImm()->getValueAPF().bitcastToAPInt();
2698     MIRBuilder.setInstrAndDebugLoc(MI);
2699     auto IntCst = MIRBuilder.buildConstant(MI.getOperand(0).getReg(), Val);
2700     widenScalarDst(*IntCst, WideTy, 0, TargetOpcode::G_TRUNC);
2701     MI.eraseFromParent();
2702     return Legalized;
2703   }
2704   case TargetOpcode::G_IMPLICIT_DEF: {
2705     Observer.changingInstr(MI);
2706     widenScalarDst(MI, WideTy);
2707     Observer.changedInstr(MI);
2708     return Legalized;
2709   }
2710   case TargetOpcode::G_BRCOND:
2711     Observer.changingInstr(MI);
2712     widenScalarSrc(MI, WideTy, 0, MIRBuilder.getBoolExtOp(false, false));
2713     Observer.changedInstr(MI);
2714     return Legalized;
2715 
2716   case TargetOpcode::G_FCMP:
2717     Observer.changingInstr(MI);
2718     if (TypeIdx == 0)
2719       widenScalarDst(MI, WideTy);
2720     else {
2721       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_FPEXT);
2722       widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_FPEXT);
2723     }
2724     Observer.changedInstr(MI);
2725     return Legalized;
2726 
2727   case TargetOpcode::G_ICMP:
2728     Observer.changingInstr(MI);
2729     if (TypeIdx == 0)
2730       widenScalarDst(MI, WideTy);
2731     else {
2732       unsigned ExtOpcode = CmpInst::isSigned(static_cast<CmpInst::Predicate>(
2733                                MI.getOperand(1).getPredicate()))
2734                                ? TargetOpcode::G_SEXT
2735                                : TargetOpcode::G_ZEXT;
2736       widenScalarSrc(MI, WideTy, 2, ExtOpcode);
2737       widenScalarSrc(MI, WideTy, 3, ExtOpcode);
2738     }
2739     Observer.changedInstr(MI);
2740     return Legalized;
2741 
2742   case TargetOpcode::G_PTR_ADD:
2743     assert(TypeIdx == 1 && "unable to legalize pointer of G_PTR_ADD");
2744     Observer.changingInstr(MI);
2745     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
2746     Observer.changedInstr(MI);
2747     return Legalized;
2748 
2749   case TargetOpcode::G_PHI: {
2750     assert(TypeIdx == 0 && "Expecting only Idx 0");
2751 
2752     Observer.changingInstr(MI);
2753     for (unsigned I = 1; I < MI.getNumOperands(); I += 2) {
2754       MachineBasicBlock &OpMBB = *MI.getOperand(I + 1).getMBB();
2755       MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminatorForward());
2756       widenScalarSrc(MI, WideTy, I, TargetOpcode::G_ANYEXT);
2757     }
2758 
2759     MachineBasicBlock &MBB = *MI.getParent();
2760     MIRBuilder.setInsertPt(MBB, --MBB.getFirstNonPHI());
2761     widenScalarDst(MI, WideTy);
2762     Observer.changedInstr(MI);
2763     return Legalized;
2764   }
2765   case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
2766     if (TypeIdx == 0) {
2767       Register VecReg = MI.getOperand(1).getReg();
2768       LLT VecTy = MRI.getType(VecReg);
2769       Observer.changingInstr(MI);
2770 
2771       widenScalarSrc(
2772           MI, LLT::vector(VecTy.getElementCount(), WideTy.getSizeInBits()), 1,
2773           TargetOpcode::G_ANYEXT);
2774 
2775       widenScalarDst(MI, WideTy, 0);
2776       Observer.changedInstr(MI);
2777       return Legalized;
2778     }
2779 
2780     if (TypeIdx != 2)
2781       return UnableToLegalize;
2782     Observer.changingInstr(MI);
2783     // TODO: Probably should be zext
2784     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
2785     Observer.changedInstr(MI);
2786     return Legalized;
2787   }
2788   case TargetOpcode::G_INSERT_VECTOR_ELT: {
2789     if (TypeIdx == 0) {
2790       Observer.changingInstr(MI);
2791       const LLT WideEltTy = WideTy.getElementType();
2792 
2793       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2794       widenScalarSrc(MI, WideEltTy, 2, TargetOpcode::G_ANYEXT);
2795       widenScalarDst(MI, WideTy, 0);
2796       Observer.changedInstr(MI);
2797       return Legalized;
2798     }
2799 
2800     if (TypeIdx == 1) {
2801       Observer.changingInstr(MI);
2802 
2803       Register VecReg = MI.getOperand(1).getReg();
2804       LLT VecTy = MRI.getType(VecReg);
2805       LLT WideVecTy = LLT::vector(VecTy.getElementCount(), WideTy);
2806 
2807       widenScalarSrc(MI, WideVecTy, 1, TargetOpcode::G_ANYEXT);
2808       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
2809       widenScalarDst(MI, WideVecTy, 0);
2810       Observer.changedInstr(MI);
2811       return Legalized;
2812     }
2813 
2814     if (TypeIdx == 2) {
2815       Observer.changingInstr(MI);
2816       // TODO: Probably should be zext
2817       widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_SEXT);
2818       Observer.changedInstr(MI);
2819       return Legalized;
2820     }
2821 
2822     return UnableToLegalize;
2823   }
2824   case TargetOpcode::G_FADD:
2825   case TargetOpcode::G_FMUL:
2826   case TargetOpcode::G_FSUB:
2827   case TargetOpcode::G_FMA:
2828   case TargetOpcode::G_FMAD:
2829   case TargetOpcode::G_FNEG:
2830   case TargetOpcode::G_FABS:
2831   case TargetOpcode::G_FCANONICALIZE:
2832   case TargetOpcode::G_FMINNUM:
2833   case TargetOpcode::G_FMAXNUM:
2834   case TargetOpcode::G_FMINNUM_IEEE:
2835   case TargetOpcode::G_FMAXNUM_IEEE:
2836   case TargetOpcode::G_FMINIMUM:
2837   case TargetOpcode::G_FMAXIMUM:
2838   case TargetOpcode::G_FDIV:
2839   case TargetOpcode::G_FREM:
2840   case TargetOpcode::G_FCEIL:
2841   case TargetOpcode::G_FFLOOR:
2842   case TargetOpcode::G_FCOS:
2843   case TargetOpcode::G_FSIN:
2844   case TargetOpcode::G_FLOG10:
2845   case TargetOpcode::G_FLOG:
2846   case TargetOpcode::G_FLOG2:
2847   case TargetOpcode::G_FRINT:
2848   case TargetOpcode::G_FNEARBYINT:
2849   case TargetOpcode::G_FSQRT:
2850   case TargetOpcode::G_FEXP:
2851   case TargetOpcode::G_FEXP2:
2852   case TargetOpcode::G_FEXP10:
2853   case TargetOpcode::G_FPOW:
2854   case TargetOpcode::G_INTRINSIC_TRUNC:
2855   case TargetOpcode::G_INTRINSIC_ROUND:
2856   case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
2857     assert(TypeIdx == 0);
2858     Observer.changingInstr(MI);
2859 
2860     for (unsigned I = 1, E = MI.getNumOperands(); I != E; ++I)
2861       widenScalarSrc(MI, WideTy, I, TargetOpcode::G_FPEXT);
2862 
2863     widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
2864     Observer.changedInstr(MI);
2865     return Legalized;
2866   case TargetOpcode::G_FPOWI:
2867   case TargetOpcode::G_FLDEXP:
2868   case TargetOpcode::G_STRICT_FLDEXP: {
2869     if (TypeIdx == 0) {
2870       if (MI.getOpcode() == TargetOpcode::G_STRICT_FLDEXP)
2871         return UnableToLegalize;
2872 
2873       Observer.changingInstr(MI);
2874       widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_FPEXT);
2875       widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
2876       Observer.changedInstr(MI);
2877       return Legalized;
2878     }
2879 
2880     if (TypeIdx == 1) {
2881       // For some reason SelectionDAG tries to promote to a libcall without
2882       // actually changing the integer type for promotion.
2883       Observer.changingInstr(MI);
2884       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
2885       Observer.changedInstr(MI);
2886       return Legalized;
2887     }
2888 
2889     return UnableToLegalize;
2890   }
2891   case TargetOpcode::G_FFREXP: {
2892     Observer.changingInstr(MI);
2893 
2894     if (TypeIdx == 0) {
2895       widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_FPEXT);
2896       widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
2897     } else {
2898       widenScalarDst(MI, WideTy, 1);
2899     }
2900 
2901     Observer.changedInstr(MI);
2902     return Legalized;
2903   }
2904   case TargetOpcode::G_INTTOPTR:
2905     if (TypeIdx != 1)
2906       return UnableToLegalize;
2907 
2908     Observer.changingInstr(MI);
2909     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ZEXT);
2910     Observer.changedInstr(MI);
2911     return Legalized;
2912   case TargetOpcode::G_PTRTOINT:
2913     if (TypeIdx != 0)
2914       return UnableToLegalize;
2915 
2916     Observer.changingInstr(MI);
2917     widenScalarDst(MI, WideTy, 0);
2918     Observer.changedInstr(MI);
2919     return Legalized;
2920   case TargetOpcode::G_BUILD_VECTOR: {
2921     Observer.changingInstr(MI);
2922 
2923     const LLT WideEltTy = TypeIdx == 1 ? WideTy : WideTy.getElementType();
2924     for (int I = 1, E = MI.getNumOperands(); I != E; ++I)
2925       widenScalarSrc(MI, WideEltTy, I, TargetOpcode::G_ANYEXT);
2926 
2927     // Avoid changing the result vector type if the source element type was
2928     // requested.
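    // E.g. with TypeIdx == 1, widening the s8 sources of a <4 x s8>
    // G_BUILD_VECTOR to s16 keeps the <4 x s8> result by switching the
    // opcode to G_BUILD_VECTOR_TRUNC.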
2929     if (TypeIdx == 1) {
2930       MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::G_BUILD_VECTOR_TRUNC));
2931     } else {
2932       widenScalarDst(MI, WideTy, 0);
2933     }
2934 
2935     Observer.changedInstr(MI);
2936     return Legalized;
2937   }
2938   case TargetOpcode::G_SEXT_INREG:
2939     if (TypeIdx != 0)
2940       return UnableToLegalize;
2941 
2942     Observer.changingInstr(MI);
2943     widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
2944     widenScalarDst(MI, WideTy, 0, TargetOpcode::G_TRUNC);
2945     Observer.changedInstr(MI);
2946     return Legalized;
2947   case TargetOpcode::G_PTRMASK: {
2948     if (TypeIdx != 1)
2949       return UnableToLegalize;
2950     Observer.changingInstr(MI);
2951     widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
2952     Observer.changedInstr(MI);
2953     return Legalized;
2954   }
2955   case TargetOpcode::G_VECREDUCE_FADD:
2956   case TargetOpcode::G_VECREDUCE_FMUL:
2957   case TargetOpcode::G_VECREDUCE_FMIN:
2958   case TargetOpcode::G_VECREDUCE_FMAX:
2959   case TargetOpcode::G_VECREDUCE_FMINIMUM:
2960   case TargetOpcode::G_VECREDUCE_FMAXIMUM:
2961     if (TypeIdx != 0)
2962       return UnableToLegalize;
2963     Observer.changingInstr(MI);
2964     Register VecReg = MI.getOperand(1).getReg();
2965     LLT VecTy = MRI.getType(VecReg);
2966     LLT WideVecTy = VecTy.isVector()
2967                         ? LLT::vector(VecTy.getElementCount(), WideTy)
2968                         : WideTy;
2969     widenScalarSrc(MI, WideVecTy, 1, TargetOpcode::G_FPEXT);
2970     widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
2971     Observer.changedInstr(MI);
2972     return Legalized;
2973   }
2974 }
2975 
2976 static void getUnmergePieces(SmallVectorImpl<Register> &Pieces,
2977                              MachineIRBuilder &B, Register Src, LLT Ty) {
2978   auto Unmerge = B.buildUnmerge(Ty, Src);
2979   for (int I = 0, E = Unmerge->getNumOperands() - 1; I != E; ++I)
2980     Pieces.push_back(Unmerge.getReg(I));
2981 }
2982 
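/// Lower a G_FCONSTANT by emitting the FP value into the function's constant
/// pool and replacing the definition with a G_LOAD from the pool entry.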
2983 LegalizerHelper::LegalizeResult
2984 LegalizerHelper::lowerFConstant(MachineInstr &MI) {
2985   Register Dst = MI.getOperand(0).getReg();
2986 
2987   MachineFunction &MF = MIRBuilder.getMF();
2988   const DataLayout &DL = MIRBuilder.getDataLayout();
2989 
2990   unsigned AddrSpace = DL.getDefaultGlobalsAddressSpace();
2991   LLT AddrPtrTy = LLT::pointer(AddrSpace, DL.getPointerSizeInBits(AddrSpace));
2992   Align Alignment = Align(DL.getABITypeAlign(
2993       getFloatTypeForLLT(MF.getFunction().getContext(), MRI.getType(Dst))));
2994 
2995   auto Addr = MIRBuilder.buildConstantPool(
2996       AddrPtrTy, MF.getConstantPool()->getConstantPoolIndex(
2997                      MI.getOperand(1).getFPImm(), Alignment));
2998 
2999   MachineMemOperand *MMO = MF.getMachineMemOperand(
3000       MachinePointerInfo::getConstantPool(MF), MachineMemOperand::MOLoad,
3001       MRI.getType(Dst), Alignment);
3002 
3003   MIRBuilder.buildLoadInstr(TargetOpcode::G_LOAD, Dst, Addr, *MMO);
3004   MI.eraseFromParent();
3005 
3006   return Legalized;
3007 }
3008 
3009 LegalizerHelper::LegalizeResult
3010 LegalizerHelper::lowerBitcast(MachineInstr &MI) {
3011   auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs();
3012   if (SrcTy.isVector()) {
3013     LLT SrcEltTy = SrcTy.getElementType();
3014     SmallVector<Register, 8> SrcRegs;
3015 
3016     if (DstTy.isVector()) {
3017       int NumDstElt = DstTy.getNumElements();
3018       int NumSrcElt = SrcTy.getNumElements();
3019 
3020       LLT DstEltTy = DstTy.getElementType();
3021       LLT DstCastTy = DstEltTy; // Intermediate bitcast result type
3022       LLT SrcPartTy = SrcEltTy; // Original unmerge result type.
3023 
3024       // If there's an element size mismatch, insert intermediate casts to match
3025       // the result element type.
3026       if (NumSrcElt < NumDstElt) { // Source element type is larger.
3027         // %1:_(<4 x s8>) = G_BITCAST %0:_(<2 x s16>)
3028         //
3029         // =>
3030         //
3031         // %2:_(s16), %3:_(s16) = G_UNMERGE_VALUES %0
3032         // %4:_(<2 x s8>) = G_BITCAST %2
3033         // %5:_(<2 x s8>) = G_BITCAST %3
3034         // %1:_(<4 x s8>) = G_CONCAT_VECTORS %4, %5
3035         DstCastTy = LLT::fixed_vector(NumDstElt / NumSrcElt, DstEltTy);
3036         SrcPartTy = SrcEltTy;
3037       } else if (NumSrcElt > NumDstElt) { // Source element type is smaller.
3038         //
3039         // %1:_(<2 x s16>) = G_BITCAST %0:_(<4 x s8>)
3040         //
3041         // =>
3042         //
3043         // %2:_(<2 x s8>), %3:_(<2 x s8>) = G_UNMERGE_VALUES %0
3044         // %4:_(s16) = G_BITCAST %2
3045         // %5:_(s16) = G_BITCAST %3
3046         // %1:_(<2 x s16>) = G_BUILD_VECTOR %4, %5
3047         SrcPartTy = LLT::fixed_vector(NumSrcElt / NumDstElt, SrcEltTy);
3048         DstCastTy = DstEltTy;
3049       }
3050 
3051       getUnmergePieces(SrcRegs, MIRBuilder, Src, SrcPartTy);
3052       for (Register &SrcReg : SrcRegs)
3053         SrcReg = MIRBuilder.buildBitcast(DstCastTy, SrcReg).getReg(0);
3054     } else
3055       getUnmergePieces(SrcRegs, MIRBuilder, Src, SrcEltTy);
3056 
3057     MIRBuilder.buildMergeLikeInstr(Dst, SrcRegs);
3058     MI.eraseFromParent();
3059     return Legalized;
3060   }
3061 
3062   if (DstTy.isVector()) {
3063     SmallVector<Register, 8> SrcRegs;
3064     getUnmergePieces(SrcRegs, MIRBuilder, Src, DstTy.getElementType());
3065     MIRBuilder.buildMergeLikeInstr(Dst, SrcRegs);
3066     MI.eraseFromParent();
3067     return Legalized;
3068   }
3069 
3070   return UnableToLegalize;
3071 }
3072 
3073 /// Figure out the bit offset into a register when coercing a vector index for
3074 /// the wide element type. This is only for the case when promoting a vector
3075 /// to one with larger elements.
3076 ///
3077 /// The offset is computed as:
3078 /// %offset_idx = G_AND %idx, ~(-1 << Log2(DstEltSize / SrcEltSize))
3079 /// %offset_bits = G_SHL %offset_idx, Log2(SrcEltSize)
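///
/// For example (an illustrative walk-through): with DstEltSize = 32 and
/// SrcEltSize = 8, the ratio is 4, so for %idx = 5:
///   %offset_idx  = 5 & 3 = 1    ; sub-element 1 within the wide element
///   %offset_bits = 1 << 3 = 8   ; the extracted bits start at bit 8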
3080 static Register getBitcastWiderVectorElementOffset(MachineIRBuilder &B,
3081                                                    Register Idx,
3082                                                    unsigned NewEltSize,
3083                                                    unsigned OldEltSize) {
3084   const unsigned Log2EltRatio = Log2_32(NewEltSize / OldEltSize);
3085   LLT IdxTy = B.getMRI()->getType(Idx);
3086 
3087   // Now figure out the amount we need to shift to get the target bits.
3088   auto OffsetMask = B.buildConstant(
3089       IdxTy, ~(APInt::getAllOnes(IdxTy.getSizeInBits()) << Log2EltRatio));
3090   auto OffsetIdx = B.buildAnd(IdxTy, Idx, OffsetMask);
3091   return B.buildShl(IdxTy, OffsetIdx,
3092                     B.buildConstant(IdxTy, Log2_32(OldEltSize))).getReg(0);
3093 }
3094 
3095 /// Perform a G_EXTRACT_VECTOR_ELT in a different sized vector element. If this
3096 /// is casting to a vector with a smaller element size, perform multiple element
3097 /// extracts and merge the results. If this is coercing to a vector with larger
3098 /// elements, index the bitcasted vector and extract the target element with bit
3099 /// operations. This is intended to force the indexing in the native register
3100 /// size for architectures that can dynamically index the register file.
3101 LegalizerHelper::LegalizeResult
3102 LegalizerHelper::bitcastExtractVectorElt(MachineInstr &MI, unsigned TypeIdx,
3103                                          LLT CastTy) {
3104   if (TypeIdx != 1)
3105     return UnableToLegalize;
3106 
3107   auto [Dst, DstTy, SrcVec, SrcVecTy, Idx, IdxTy] = MI.getFirst3RegLLTs();
3108 
3109   LLT SrcEltTy = SrcVecTy.getElementType();
3110   unsigned NewNumElts = CastTy.isVector() ? CastTy.getNumElements() : 1;
3111   unsigned OldNumElts = SrcVecTy.getNumElements();
3112 
3113   LLT NewEltTy = CastTy.isVector() ? CastTy.getElementType() : CastTy;
3114   Register CastVec = MIRBuilder.buildBitcast(CastTy, SrcVec).getReg(0);
3115 
3116   const unsigned NewEltSize = NewEltTy.getSizeInBits();
3117   const unsigned OldEltSize = SrcEltTy.getSizeInBits();
3118   if (NewNumElts > OldNumElts) {
3119     // Decreasing the vector element size
3120     //
3121     // e.g. i64 = extract_vector_elt x:v2i64, y:i32
3122     //  =>
3123     //  v4i32:castx = bitcast x:v2i64
3124     //
3125     // i64 = bitcast
3126     //   (v2i32 build_vector (i32 (extract_vector_elt castx, (2 * y))),
3127     //                       (i32 (extract_vector_elt castx, (2 * y + 1))))
3128     //
3129     if (NewNumElts % OldNumElts != 0)
3130       return UnableToLegalize;
3131 
3132     // Type of the intermediate result vector.
3133     const unsigned NewEltsPerOldElt = NewNumElts / OldNumElts;
3134     LLT MidTy =
3135         LLT::scalarOrVector(ElementCount::getFixed(NewEltsPerOldElt), NewEltTy);
3136 
3137     auto NewEltsPerOldEltK = MIRBuilder.buildConstant(IdxTy, NewEltsPerOldElt);
3138 
3139     SmallVector<Register, 8> NewOps(NewEltsPerOldElt);
3140     auto NewBaseIdx = MIRBuilder.buildMul(IdxTy, Idx, NewEltsPerOldEltK);
3141 
3142     for (unsigned I = 0; I < NewEltsPerOldElt; ++I) {
3143       auto IdxOffset = MIRBuilder.buildConstant(IdxTy, I);
3144       auto TmpIdx = MIRBuilder.buildAdd(IdxTy, NewBaseIdx, IdxOffset);
3145       auto Elt = MIRBuilder.buildExtractVectorElement(NewEltTy, CastVec, TmpIdx);
3146       NewOps[I] = Elt.getReg(0);
3147     }
3148 
3149     auto NewVec = MIRBuilder.buildBuildVector(MidTy, NewOps);
3150     MIRBuilder.buildBitcast(Dst, NewVec);
3151     MI.eraseFromParent();
3152     return Legalized;
3153   }
3154 
3155   if (NewNumElts < OldNumElts) {
3156     if (NewEltSize % OldEltSize != 0)
3157       return UnableToLegalize;
3158 
3159     // This only depends on powers of 2 because we use bit tricks to figure out
3160     // the bit offset we need to shift to get the target element. A general
3161     // expansion could emit division/multiply.
3162     if (!isPowerOf2_32(NewEltSize / OldEltSize))
3163       return UnableToLegalize;
3164 
3165     // Increasing the vector element size.
3166     // %elt:_(small_elt) = G_EXTRACT_VECTOR_ELT %vec:_(<N x small_elt>), %idx
3167     //
3168     //   =>
3169     //
3170     // %cast = G_BITCAST %vec
3171     // %scaled_idx = G_LSHR %idx, Log2(DstEltSize / SrcEltSize)
3172     // %wide_elt  = G_EXTRACT_VECTOR_ELT %cast, %scaled_idx
3173     // %offset_idx = G_AND %idx, ~(-1 << Log2(DstEltSize / SrcEltSize))
3174     // %offset_bits = G_SHL %offset_idx, Log2(SrcEltSize)
3175     // %elt_bits = G_LSHR %wide_elt, %offset_bits
3176     // %elt = G_TRUNC %elt_bits
3177 
3178     const unsigned Log2EltRatio = Log2_32(NewEltSize / OldEltSize);
3179     auto Log2Ratio = MIRBuilder.buildConstant(IdxTy, Log2EltRatio);
3180 
3181     // Divide to get the index in the wider element type.
3182     auto ScaledIdx = MIRBuilder.buildLShr(IdxTy, Idx, Log2Ratio);
3183 
3184     Register WideElt = CastVec;
3185     if (CastTy.isVector()) {
3186       WideElt = MIRBuilder.buildExtractVectorElement(NewEltTy, CastVec,
3187                                                      ScaledIdx).getReg(0);
3188     }
3189 
3190     // Compute the bit offset into the register of the target element.
3191     Register OffsetBits = getBitcastWiderVectorElementOffset(
3192       MIRBuilder, Idx, NewEltSize, OldEltSize);
3193 
3194     // Shift the wide element to get the target element.
3195     auto ExtractedBits = MIRBuilder.buildLShr(NewEltTy, WideElt, OffsetBits);
3196     MIRBuilder.buildTrunc(Dst, ExtractedBits);
3197     MI.eraseFromParent();
3198     return Legalized;
3199   }
3200 
3201   return UnableToLegalize;
3202 }
3203 
3204 /// Emit code to insert \p InsertReg into \p TargetReg at \p OffsetBits, while
3205 /// preserving the other bits in \p TargetReg.
3206 ///
3207 /// (zext(InsertReg) << Offset) | (TargetReg & ~(((1 << InsertReg.size()) - 1) << Offset))
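///
/// For example (illustrative values): inserting an s8 value at bit offset 8
/// of an s32 target masks the target with ~(0xff << 8) and ORs in
/// (zext(InsertReg) << 8).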
3208 static Register buildBitFieldInsert(MachineIRBuilder &B,
3209                                     Register TargetReg, Register InsertReg,
3210                                     Register OffsetBits) {
3211   LLT TargetTy = B.getMRI()->getType(TargetReg);
3212   LLT InsertTy = B.getMRI()->getType(InsertReg);
3213   auto ZextVal = B.buildZExt(TargetTy, InsertReg);
3214   auto ShiftedInsertVal = B.buildShl(TargetTy, ZextVal, OffsetBits);
3215 
3216   // Produce a bitmask of the value to insert
3217   auto EltMask = B.buildConstant(
3218     TargetTy, APInt::getLowBitsSet(TargetTy.getSizeInBits(),
3219                                    InsertTy.getSizeInBits()));
3220   // Shift it into position
3221   auto ShiftedMask = B.buildShl(TargetTy, EltMask, OffsetBits);
3222   auto InvShiftedMask = B.buildNot(TargetTy, ShiftedMask);
3223 
3224   // Clear out the bits in the wide element
3225   auto MaskedOldElt = B.buildAnd(TargetTy, TargetReg, InvShiftedMask);
3226 
3227   // The value to insert has all zeros already, so stick it into the masked
3228   // wide element.
3229   return B.buildOr(TargetTy, MaskedOldElt, ShiftedInsertVal).getReg(0);
3230 }
3231 
3232 /// Perform a G_INSERT_VECTOR_ELT in a different sized vector element. If this
3233 /// is increasing the element size, perform the indexing in the target element
3234 /// type, and use bit operations to insert at the element position. This is
3235 /// intended for architectures that can dynamically index the register file and
3236 /// want to force indexing in the native register size.
3237 LegalizerHelper::LegalizeResult
3238 LegalizerHelper::bitcastInsertVectorElt(MachineInstr &MI, unsigned TypeIdx,
3239                                         LLT CastTy) {
3240   if (TypeIdx != 0)
3241     return UnableToLegalize;
3242 
3243   auto [Dst, DstTy, SrcVec, SrcVecTy, Val, ValTy, Idx, IdxTy] =
3244       MI.getFirst4RegLLTs();
3245   LLT VecTy = DstTy;
3246 
3247   LLT VecEltTy = VecTy.getElementType();
3248   LLT NewEltTy = CastTy.isVector() ? CastTy.getElementType() : CastTy;
3249   const unsigned NewEltSize = NewEltTy.getSizeInBits();
3250   const unsigned OldEltSize = VecEltTy.getSizeInBits();
3251 
3252   unsigned NewNumElts = CastTy.isVector() ? CastTy.getNumElements() : 1;
3253   unsigned OldNumElts = VecTy.getNumElements();
3254 
3255   Register CastVec = MIRBuilder.buildBitcast(CastTy, SrcVec).getReg(0);
3256   if (NewNumElts < OldNumElts) {
3257     if (NewEltSize % OldEltSize != 0)
3258       return UnableToLegalize;
3259 
3260     // This only depends on powers of 2 because we use bit tricks to figure out
3261     // the bit offset we need to shift to get the target element. A general
3262     // expansion could emit division/multiply.
3263     if (!isPowerOf2_32(NewEltSize / OldEltSize))
3264       return UnableToLegalize;
3265 
3266     const unsigned Log2EltRatio = Log2_32(NewEltSize / OldEltSize);
3267     auto Log2Ratio = MIRBuilder.buildConstant(IdxTy, Log2EltRatio);
3268 
3269     // Divide to get the index in the wider element type.
3270     auto ScaledIdx = MIRBuilder.buildLShr(IdxTy, Idx, Log2Ratio);
3271 
3272     Register ExtractedElt = CastVec;
3273     if (CastTy.isVector()) {
3274       ExtractedElt = MIRBuilder.buildExtractVectorElement(NewEltTy, CastVec,
3275                                                           ScaledIdx).getReg(0);
3276     }
3277 
3278     // Compute the bit offset into the register of the target element.
3279     Register OffsetBits = getBitcastWiderVectorElementOffset(
3280       MIRBuilder, Idx, NewEltSize, OldEltSize);
3281 
3282     Register InsertedElt = buildBitFieldInsert(MIRBuilder, ExtractedElt,
3283                                                Val, OffsetBits);
3284     if (CastTy.isVector()) {
3285       InsertedElt = MIRBuilder.buildInsertVectorElement(
3286         CastTy, CastVec, InsertedElt, ScaledIdx).getReg(0);
3287     }
3288 
3289     MIRBuilder.buildBitcast(Dst, InsertedElt);
3290     MI.eraseFromParent();
3291     return Legalized;
3292   }
3293 
3294   return UnableToLegalize;
3295 }
3296 
3297 LegalizerHelper::LegalizeResult LegalizerHelper::lowerLoad(GAnyLoad &LoadMI) {
3298   // Lower to a memory-width G_LOAD and a G_SEXT/G_ZEXT/G_ANYEXT
3299   Register DstReg = LoadMI.getDstReg();
3300   Register PtrReg = LoadMI.getPointerReg();
3301   LLT DstTy = MRI.getType(DstReg);
3302   MachineMemOperand &MMO = LoadMI.getMMO();
3303   LLT MemTy = MMO.getMemoryType();
3304   MachineFunction &MF = MIRBuilder.getMF();
3305 
3306   unsigned MemSizeInBits = MemTy.getSizeInBits();
3307   unsigned MemStoreSizeInBits = 8 * MemTy.getSizeInBytes();
3308 
3309   if (MemSizeInBits != MemStoreSizeInBits) {
3310     if (MemTy.isVector())
3311       return UnableToLegalize;
3312 
3313     // Promote to a byte-sized load if not loading an integral number of
3314     // bytes.  For example, promote EXTLOAD:i20 -> EXTLOAD:i24.
3315     LLT WideMemTy = LLT::scalar(MemStoreSizeInBits);
3316     MachineMemOperand *NewMMO =
3317         MF.getMachineMemOperand(&MMO, MMO.getPointerInfo(), WideMemTy);
3318 
3319     Register LoadReg = DstReg;
3320     LLT LoadTy = DstTy;
3321 
3322     // If this wasn't already an extending load, we need to widen the result
3323     // register to avoid creating a load with a narrower result than the source.
3324     if (MemStoreSizeInBits > DstTy.getSizeInBits()) {
3325       LoadTy = WideMemTy;
3326       LoadReg = MRI.createGenericVirtualRegister(WideMemTy);
3327     }
3328 
3329     if (isa<GSExtLoad>(LoadMI)) {
3330       auto NewLoad = MIRBuilder.buildLoad(LoadTy, PtrReg, *NewMMO);
3331       MIRBuilder.buildSExtInReg(LoadReg, NewLoad, MemSizeInBits);
3332     } else if (isa<GZExtLoad>(LoadMI) || WideMemTy == LoadTy) {
3333       auto NewLoad = MIRBuilder.buildLoad(LoadTy, PtrReg, *NewMMO);
3334       // The extra bits are guaranteed to be zero, since we stored them that
3335       // way.  A zext load from Wide thus automatically gives zext from MemVT.
3336       MIRBuilder.buildAssertZExt(LoadReg, NewLoad, MemSizeInBits);
3337     } else {
3338       MIRBuilder.buildLoad(LoadReg, PtrReg, *NewMMO);
3339     }
3340 
3341     if (DstTy != LoadTy)
3342       MIRBuilder.buildTrunc(DstReg, LoadReg);
3343 
3344     LoadMI.eraseFromParent();
3345     return Legalized;
3346   }
3347 
3348   // Big endian lowering not implemented.
3349   if (MIRBuilder.getDataLayout().isBigEndian())
3350     return UnableToLegalize;
3351 
3352   // This load needs splitting into power of 2 sized loads.
3353   //
3354   // Our strategy here is to generate anyextending loads for the smaller
3355   // types up to the next power-of-2 result type, then combine the loaded
3356   // values together, before truncating back down to the non-pow-2
3357   // type.
3358   // E.g. v1 = i24 load =>
3359   // v2 = i32 zextload (2 byte)
3360   // v3 = i32 load (1 byte)
3361   // v4 = i32 shl v3, 16
3362   // v5 = i32 or v4, v2
3363   // v1 = i24 trunc v5
3364   // By doing this we generate the correct truncate which should get
3365   // combined away as an artifact with a matching extend.
3366 
3367   uint64_t LargeSplitSize, SmallSplitSize;
3368 
3369   if (!isPowerOf2_32(MemSizeInBits)) {
3370     // Not a power of 2: split at the largest power-of-2 boundary.
3371     LargeSplitSize = llvm::bit_floor(MemSizeInBits);
3372     SmallSplitSize = MemSizeInBits - LargeSplitSize;
3373   } else {
3374     // This is already a power of 2, but we still need to split this in half.
3375     //
3376     // Assume we're being asked to decompose an unaligned load.
3377     // TODO: If this requires multiple splits, handle them all at once.
3378     auto &Ctx = MF.getFunction().getContext();
3379     if (TLI.allowsMemoryAccess(Ctx, MIRBuilder.getDataLayout(), MemTy, MMO))
3380       return UnableToLegalize;
3381 
3382     SmallSplitSize = LargeSplitSize = MemSizeInBits / 2;
3383   }
3384 
3385   if (MemTy.isVector()) {
3386     // TODO: Handle vector extloads
3387     if (MemTy != DstTy)
3388       return UnableToLegalize;
3389 
3390     // TODO: We can do better than scalarizing the vector and at least split it
3391     // in half.
3392     return reduceLoadStoreWidth(LoadMI, 0, DstTy.getElementType());
3393   }
3394 
3395   MachineMemOperand *LargeMMO =
3396       MF.getMachineMemOperand(&MMO, 0, LargeSplitSize / 8);
3397   MachineMemOperand *SmallMMO =
3398       MF.getMachineMemOperand(&MMO, LargeSplitSize / 8, SmallSplitSize / 8);
3399 
3400   LLT PtrTy = MRI.getType(PtrReg);
3401   unsigned AnyExtSize = PowerOf2Ceil(DstTy.getSizeInBits());
3402   LLT AnyExtTy = LLT::scalar(AnyExtSize);
3403   auto LargeLoad = MIRBuilder.buildLoadInstr(TargetOpcode::G_ZEXTLOAD, AnyExtTy,
3404                                              PtrReg, *LargeMMO);
3405 
3406   auto OffsetCst = MIRBuilder.buildConstant(LLT::scalar(PtrTy.getSizeInBits()),
3407                                             LargeSplitSize / 8);
3408   Register PtrAddReg = MRI.createGenericVirtualRegister(PtrTy);
3409   auto SmallPtr = MIRBuilder.buildPtrAdd(PtrAddReg, PtrReg, OffsetCst);
3410   auto SmallLoad = MIRBuilder.buildLoadInstr(LoadMI.getOpcode(), AnyExtTy,
3411                                              SmallPtr, *SmallMMO);
3412 
3413   auto ShiftAmt = MIRBuilder.buildConstant(AnyExtTy, LargeSplitSize);
3414   auto Shift = MIRBuilder.buildShl(AnyExtTy, SmallLoad, ShiftAmt);
3415 
3416   if (AnyExtTy == DstTy)
3417     MIRBuilder.buildOr(DstReg, Shift, LargeLoad);
3418   else if (AnyExtTy.getSizeInBits() != DstTy.getSizeInBits()) {
3419     auto Or = MIRBuilder.buildOr(AnyExtTy, Shift, LargeLoad);
3420     MIRBuilder.buildTrunc(DstReg, {Or});
3421   } else {
3422     assert(DstTy.isPointer() && "expected pointer");
3423     auto Or = MIRBuilder.buildOr(AnyExtTy, Shift, LargeLoad);
3424 
3425     // FIXME: We currently consider this to be illegal for non-integral address
3426     // spaces, but we still need a way to reinterpret the bits.
3427     MIRBuilder.buildIntToPtr(DstReg, Or);
3428   }
3429 
3430   LoadMI.eraseFromParent();
3431   return Legalized;
3432 }
3433 
3434 LegalizerHelper::LegalizeResult LegalizerHelper::lowerStore(GStore &StoreMI) {
3435   // Lower a non-power of 2 store into multiple pow-2 stores.
3436   // E.g. split an i24 store into an i16 store + i8 store.
3437   // We do this by first extending the stored value to the next largest power
3438   // of 2 type, and then using truncating stores to store the components.
3439   // By doing this, as with G_LOAD, we generate an extend that can be
3440   // artifact-combined away instead of leaving behind extracts.
3441   Register SrcReg = StoreMI.getValueReg();
3442   Register PtrReg = StoreMI.getPointerReg();
3443   LLT SrcTy = MRI.getType(SrcReg);
3444   MachineFunction &MF = MIRBuilder.getMF();
3445   MachineMemOperand &MMO = **StoreMI.memoperands_begin();
3446   LLT MemTy = MMO.getMemoryType();
3447 
3448   unsigned StoreWidth = MemTy.getSizeInBits();
3449   unsigned StoreSizeInBits = 8 * MemTy.getSizeInBytes();
3450 
3451   if (StoreWidth != StoreSizeInBits) {
3452     if (SrcTy.isVector())
3453       return UnableToLegalize;
3454 
3455     // Promote to a byte-sized store with upper bits zero if not
3456     // storing an integral number of bytes.  For example, promote
3457     // TRUNCSTORE:i1 X -> TRUNCSTORE:i8 (and X, 1)
3458     LLT WideTy = LLT::scalar(StoreSizeInBits);
3459 
3460     if (StoreSizeInBits > SrcTy.getSizeInBits()) {
3461       // Avoid creating a store with a narrower source than result.
3462       SrcReg = MIRBuilder.buildAnyExt(WideTy, SrcReg).getReg(0);
3463       SrcTy = WideTy;
3464     }
3465 
3466     auto ZextInReg = MIRBuilder.buildZExtInReg(SrcTy, SrcReg, StoreWidth);
3467 
3468     MachineMemOperand *NewMMO =
3469         MF.getMachineMemOperand(&MMO, MMO.getPointerInfo(), WideTy);
3470     MIRBuilder.buildStore(ZextInReg, PtrReg, *NewMMO);
3471     StoreMI.eraseFromParent();
3472     return Legalized;
3473   }
3474 
3475   if (MemTy.isVector()) {
3476     // TODO: Handle vector trunc stores
3477     if (MemTy != SrcTy)
3478       return UnableToLegalize;
3479 
3480     // TODO: We can do better than scalarizing the vector and at least split it
3481     // in half.
3482     return reduceLoadStoreWidth(StoreMI, 0, SrcTy.getElementType());
3483   }
3484 
3485   unsigned MemSizeInBits = MemTy.getSizeInBits();
3486   uint64_t LargeSplitSize, SmallSplitSize;
3487 
3488   if (!isPowerOf2_32(MemSizeInBits)) {
3489     LargeSplitSize = llvm::bit_floor<uint64_t>(MemTy.getSizeInBits());
3490     SmallSplitSize = MemTy.getSizeInBits() - LargeSplitSize;
3491   } else {
3492     auto &Ctx = MF.getFunction().getContext();
3493     if (TLI.allowsMemoryAccess(Ctx, MIRBuilder.getDataLayout(), MemTy, MMO))
3494       return UnableToLegalize; // Don't know what we're being asked to do.
3495 
3496     SmallSplitSize = LargeSplitSize = MemSizeInBits / 2;
3497   }
3498 
3499   // Extend to the next pow-2. If this store was itself the result of lowering,
3500   // e.g. an s56 store being broken into s32 + s24, we might have a stored type
3501   // that's wider than the stored size.
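  // E.g. for an s24 store: LargeSplitSize = 16 and SmallSplitSize = 8, so the
  // value is any-extended to s32, its low 16 bits are stored at the base
  // pointer, and the bits shifted down by 16 are stored at base + 2.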
3502   unsigned AnyExtSize = PowerOf2Ceil(MemTy.getSizeInBits());
3503   const LLT NewSrcTy = LLT::scalar(AnyExtSize);
3504 
3505   if (SrcTy.isPointer()) {
3506     const LLT IntPtrTy = LLT::scalar(SrcTy.getSizeInBits());
3507     SrcReg = MIRBuilder.buildPtrToInt(IntPtrTy, SrcReg).getReg(0);
3508   }
3509 
3510   auto ExtVal = MIRBuilder.buildAnyExtOrTrunc(NewSrcTy, SrcReg);
3511 
3512   // Obtain the smaller value by shifting away the larger value.
3513   auto ShiftAmt = MIRBuilder.buildConstant(NewSrcTy, LargeSplitSize);
3514   auto SmallVal = MIRBuilder.buildLShr(NewSrcTy, ExtVal, ShiftAmt);
3515 
3516   // Generate the PtrAdd and truncating stores.
3517   LLT PtrTy = MRI.getType(PtrReg);
3518   auto OffsetCst = MIRBuilder.buildConstant(
3519     LLT::scalar(PtrTy.getSizeInBits()), LargeSplitSize / 8);
3520   auto SmallPtr =
3521     MIRBuilder.buildPtrAdd(PtrTy, PtrReg, OffsetCst);
3522 
3523   MachineMemOperand *LargeMMO =
3524     MF.getMachineMemOperand(&MMO, 0, LargeSplitSize / 8);
3525   MachineMemOperand *SmallMMO =
3526     MF.getMachineMemOperand(&MMO, LargeSplitSize / 8, SmallSplitSize / 8);
3527   MIRBuilder.buildStore(ExtVal, PtrReg, *LargeMMO);
3528   MIRBuilder.buildStore(SmallVal, SmallPtr, *SmallMMO);
3529   StoreMI.eraseFromParent();
3530   return Legalized;
3531 }
3532 
3533 LegalizerHelper::LegalizeResult
3534 LegalizerHelper::bitcast(MachineInstr &MI, unsigned TypeIdx, LLT CastTy) {
3535   switch (MI.getOpcode()) {
3536   case TargetOpcode::G_LOAD: {
3537     if (TypeIdx != 0)
3538       return UnableToLegalize;
3539     MachineMemOperand &MMO = **MI.memoperands_begin();
3540 
3541     // Not sure how to interpret a bitcast of an extending load.
3542     if (MMO.getMemoryType().getSizeInBits() != CastTy.getSizeInBits())
3543       return UnableToLegalize;
3544 
3545     Observer.changingInstr(MI);
3546     bitcastDst(MI, CastTy, 0);
3547     MMO.setType(CastTy);
3548     Observer.changedInstr(MI);
3549     return Legalized;
3550   }
3551   case TargetOpcode::G_STORE: {
3552     if (TypeIdx != 0)
3553       return UnableToLegalize;
3554 
3555     MachineMemOperand &MMO = **MI.memoperands_begin();
3556 
3557     // Not sure how to interpret a bitcast of a truncating store.
3558     if (MMO.getMemoryType().getSizeInBits() != CastTy.getSizeInBits())
3559       return UnableToLegalize;
3560 
3561     Observer.changingInstr(MI);
3562     bitcastSrc(MI, CastTy, 0);
3563     MMO.setType(CastTy);
3564     Observer.changedInstr(MI);
3565     return Legalized;
3566   }
3567   case TargetOpcode::G_SELECT: {
3568     if (TypeIdx != 0)
3569       return UnableToLegalize;
3570 
3571     if (MRI.getType(MI.getOperand(1).getReg()).isVector()) {
3572       LLVM_DEBUG(
3573           dbgs() << "bitcast action not implemented for vector select\n");
3574       return UnableToLegalize;
3575     }
3576 
3577     Observer.changingInstr(MI);
3578     bitcastSrc(MI, CastTy, 2);
3579     bitcastSrc(MI, CastTy, 3);
3580     bitcastDst(MI, CastTy, 0);
3581     Observer.changedInstr(MI);
3582     return Legalized;
3583   }
3584   case TargetOpcode::G_AND:
3585   case TargetOpcode::G_OR:
3586   case TargetOpcode::G_XOR: {
3587     Observer.changingInstr(MI);
3588     bitcastSrc(MI, CastTy, 1);
3589     bitcastSrc(MI, CastTy, 2);
3590     bitcastDst(MI, CastTy, 0);
3591     Observer.changedInstr(MI);
3592     return Legalized;
3593   }
3594   case TargetOpcode::G_EXTRACT_VECTOR_ELT:
3595     return bitcastExtractVectorElt(MI, TypeIdx, CastTy);
3596   case TargetOpcode::G_INSERT_VECTOR_ELT:
3597     return bitcastInsertVectorElt(MI, TypeIdx, CastTy);
3598   default:
3599     return UnableToLegalize;
3600   }
3601 }
3602 
3603 // Legalize an instruction by changing the opcode in place.
3604 void LegalizerHelper::changeOpcode(MachineInstr &MI, unsigned NewOpcode) {
3605   Observer.changingInstr(MI);
3606   MI.setDesc(MIRBuilder.getTII().get(NewOpcode));
3607   Observer.changedInstr(MI);
3608 }
3609 
3610 LegalizerHelper::LegalizeResult
3611 LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
3612   using namespace TargetOpcode;
3613 
3614   switch(MI.getOpcode()) {
3615   default:
3616     return UnableToLegalize;
3617   case TargetOpcode::G_FCONSTANT:
3618     return lowerFConstant(MI);
3619   case TargetOpcode::G_BITCAST:
3620     return lowerBitcast(MI);
3621   case TargetOpcode::G_SREM:
3622   case TargetOpcode::G_UREM: {
3623     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
3624     auto Quot =
3625         MIRBuilder.buildInstr(MI.getOpcode() == G_SREM ? G_SDIV : G_UDIV, {Ty},
3626                               {MI.getOperand(1), MI.getOperand(2)});
3627 
3628     auto Prod = MIRBuilder.buildMul(Ty, Quot, MI.getOperand(2));
3629     MIRBuilder.buildSub(MI.getOperand(0), MI.getOperand(1), Prod);
3630     MI.eraseFromParent();
3631     return Legalized;
3632   }
3633   case TargetOpcode::G_SADDO:
3634   case TargetOpcode::G_SSUBO:
3635     return lowerSADDO_SSUBO(MI);
3636   case TargetOpcode::G_UMULH:
3637   case TargetOpcode::G_SMULH:
3638     return lowerSMULH_UMULH(MI);
3639   case TargetOpcode::G_SMULO:
3640   case TargetOpcode::G_UMULO: {
3641     // Generate G_UMULH/G_SMULH to check for overflow and a normal G_MUL for the
3642     // result.
3643     auto [Res, Overflow, LHS, RHS] = MI.getFirst4Regs();
3644     LLT Ty = MRI.getType(Res);
3645 
3646     unsigned Opcode = MI.getOpcode() == TargetOpcode::G_SMULO
3647                           ? TargetOpcode::G_SMULH
3648                           : TargetOpcode::G_UMULH;
3649 
3650     Observer.changingInstr(MI);
3651     const auto &TII = MIRBuilder.getTII();
3652     MI.setDesc(TII.get(TargetOpcode::G_MUL));
3653     MI.removeOperand(1);
3654     Observer.changedInstr(MI);
3655 
3656     auto HiPart = MIRBuilder.buildInstr(Opcode, {Ty}, {LHS, RHS});
3657     auto Zero = MIRBuilder.buildConstant(Ty, 0);
3658 
3659     // Move insert point forward so we can use the Res register if needed.
3660     MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
3661 
3662     // For *signed* multiply, overflow is detected by checking:
3663     // (hi != (lo >> bitwidth-1))
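    // E.g. for s8: 64 * 2 = 128 gives lo = 0x80 and hi = 0x00; the arithmetic
    // shift (0x80 >> 7) yields 0xFF != 0x00, so overflow is flagged.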
3664     if (Opcode == TargetOpcode::G_SMULH) {
3665       auto ShiftAmt = MIRBuilder.buildConstant(Ty, Ty.getSizeInBits() - 1);
3666       auto Shifted = MIRBuilder.buildAShr(Ty, Res, ShiftAmt);
3667       MIRBuilder.buildICmp(CmpInst::ICMP_NE, Overflow, HiPart, Shifted);
3668     } else {
3669       MIRBuilder.buildICmp(CmpInst::ICMP_NE, Overflow, HiPart, Zero);
3670     }
3671     return Legalized;
3672   }
3673   case TargetOpcode::G_FNEG: {
3674     auto [Res, SubByReg] = MI.getFirst2Regs();
3675     LLT Ty = MRI.getType(Res);
3676 
3677     // TODO: Handle vector types once we are able to
3678     // represent them.
3679     if (Ty.isVector())
3680       return UnableToLegalize;
3681     auto SignMask =
3682         MIRBuilder.buildConstant(Ty, APInt::getSignMask(Ty.getSizeInBits()));
3683     MIRBuilder.buildXor(Res, SubByReg, SignMask);
3684     MI.eraseFromParent();
3685     return Legalized;
3686   }
3687   case TargetOpcode::G_FSUB:
3688   case TargetOpcode::G_STRICT_FSUB: {
3689     auto [Res, LHS, RHS] = MI.getFirst3Regs();
3690     LLT Ty = MRI.getType(Res);
3691 
3692     // Lower (G_FSUB LHS, RHS) to (G_FADD LHS, (G_FNEG RHS)).
3693     auto Neg = MIRBuilder.buildFNeg(Ty, RHS);
3694 
3695     if (MI.getOpcode() == TargetOpcode::G_STRICT_FSUB)
3696       MIRBuilder.buildStrictFAdd(Res, LHS, Neg, MI.getFlags());
3697     else
3698       MIRBuilder.buildFAdd(Res, LHS, Neg, MI.getFlags());
3699 
3700     MI.eraseFromParent();
3701     return Legalized;
3702   }
3703   case TargetOpcode::G_FMAD:
3704     return lowerFMad(MI);
3705   case TargetOpcode::G_FFLOOR:
3706     return lowerFFloor(MI);
3707   case TargetOpcode::G_INTRINSIC_ROUND:
3708     return lowerIntrinsicRound(MI);
3709   case TargetOpcode::G_FRINT: {
3710     // Since round even is the assumed rounding mode for unconstrained FP
3711     // operations, rint and roundeven are the same operation.
3712     changeOpcode(MI, TargetOpcode::G_INTRINSIC_ROUNDEVEN);
3713     return Legalized;
3714   }
3715   case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS: {
3716     auto [OldValRes, SuccessRes, Addr, CmpVal, NewVal] = MI.getFirst5Regs();
3717     MIRBuilder.buildAtomicCmpXchg(OldValRes, Addr, CmpVal, NewVal,
3718                                   **MI.memoperands_begin());
3719     MIRBuilder.buildICmp(CmpInst::ICMP_EQ, SuccessRes, OldValRes, CmpVal);
3720     MI.eraseFromParent();
3721     return Legalized;
3722   }
3723   case TargetOpcode::G_LOAD:
3724   case TargetOpcode::G_SEXTLOAD:
3725   case TargetOpcode::G_ZEXTLOAD:
3726     return lowerLoad(cast<GAnyLoad>(MI));
3727   case TargetOpcode::G_STORE:
3728     return lowerStore(cast<GStore>(MI));
3729   case TargetOpcode::G_CTLZ_ZERO_UNDEF:
3730   case TargetOpcode::G_CTTZ_ZERO_UNDEF:
3731   case TargetOpcode::G_CTLZ:
3732   case TargetOpcode::G_CTTZ:
3733   case TargetOpcode::G_CTPOP:
3734     return lowerBitCount(MI);
3735   case G_UADDO: {
3736     auto [Res, CarryOut, LHS, RHS] = MI.getFirst4Regs();
3737 
3738     MIRBuilder.buildAdd(Res, LHS, RHS);
3739     MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CarryOut, Res, RHS);
3740 
3741     MI.eraseFromParent();
3742     return Legalized;
3743   }
3744   case G_UADDE: {
3745     auto [Res, CarryOut, LHS, RHS, CarryIn] = MI.getFirst5Regs();
3746     const LLT CondTy = MRI.getType(CarryOut);
3747     const LLT Ty = MRI.getType(Res);
3748 
3749     // Initial add of the two operands.
3750     auto TmpRes = MIRBuilder.buildAdd(Ty, LHS, RHS);
3751 
3752     // Initial check for carry.
3753     auto Carry = MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CondTy, TmpRes, LHS);
3754 
3755     // Add the sum and the carry.
3756     auto ZExtCarryIn = MIRBuilder.buildZExt(Ty, CarryIn);
3757     MIRBuilder.buildAdd(Res, TmpRes, ZExtCarryIn);
3758 
3759     // Second check for carry. We can only carry if the initial sum is all 1s
3760     // and the carry is set, resulting in a new sum of 0.
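    // E.g. for s8: LHS = 0xFF, RHS = 0x00, CarryIn = 1: the initial sum 0xFF
    // sets no carry, but adding the carry wraps Res to 0x00, so the second
    // check sets CarryOut.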
3761     auto Zero = MIRBuilder.buildConstant(Ty, 0);
3762     auto ResEqZero = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, CondTy, Res, Zero);
3763     auto Carry2 = MIRBuilder.buildAnd(CondTy, ResEqZero, CarryIn);
3764     MIRBuilder.buildOr(CarryOut, Carry, Carry2);
3765 
3766     MI.eraseFromParent();
3767     return Legalized;
3768   }
3769   case G_USUBO: {
3770     auto [Res, BorrowOut, LHS, RHS] = MI.getFirst4Regs();
3771 
3772     MIRBuilder.buildSub(Res, LHS, RHS);
3773     MIRBuilder.buildICmp(CmpInst::ICMP_ULT, BorrowOut, LHS, RHS);
3774 
3775     MI.eraseFromParent();
3776     return Legalized;
3777   }
3778   case G_USUBE: {
3779     auto [Res, BorrowOut, LHS, RHS, BorrowIn] = MI.getFirst5Regs();
3780     const LLT CondTy = MRI.getType(BorrowOut);
3781     const LLT Ty = MRI.getType(Res);
3782 
3783     // Initial subtract of the two operands.
3784     auto TmpRes = MIRBuilder.buildSub(Ty, LHS, RHS);
3785 
3786     // Initial check for borrow.
3787     auto Borrow = MIRBuilder.buildICmp(CmpInst::ICMP_UGT, CondTy, TmpRes, LHS);
3788 
3789     // Subtract the borrow from the first subtract.
3790     auto ZExtBorrowIn = MIRBuilder.buildZExt(Ty, BorrowIn);
3791     MIRBuilder.buildSub(Res, TmpRes, ZExtBorrowIn);
3792 
3793     // Second check for borrow. We can only borrow if the initial difference is
3794     // 0 and the borrow is set, resulting in a new difference of all 1s.
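    // E.g. for s8: LHS = RHS = 0x00 with BorrowIn = 1: the initial difference
    // 0x00 sets no borrow, but subtracting the borrow wraps Res to 0xFF, so
    // the second check sets BorrowOut.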
3795     auto Zero = MIRBuilder.buildConstant(Ty, 0);
3796     auto TmpResEqZero =
3797         MIRBuilder.buildICmp(CmpInst::ICMP_EQ, CondTy, TmpRes, Zero);
3798     auto Borrow2 = MIRBuilder.buildAnd(CondTy, TmpResEqZero, BorrowIn);
3799     MIRBuilder.buildOr(BorrowOut, Borrow, Borrow2);
3800 
3801     MI.eraseFromParent();
3802     return Legalized;
3803   }
3804   case G_UITOFP:
3805     return lowerUITOFP(MI);
3806   case G_SITOFP:
3807     return lowerSITOFP(MI);
3808   case G_FPTOUI:
3809     return lowerFPTOUI(MI);
3810   case G_FPTOSI:
3811     return lowerFPTOSI(MI);
3812   case G_FPTRUNC:
3813     return lowerFPTRUNC(MI);
3814   case G_FPOWI:
3815     return lowerFPOWI(MI);
3816   case G_SMIN:
3817   case G_SMAX:
3818   case G_UMIN:
3819   case G_UMAX:
3820     return lowerMinMax(MI);
3821   case G_FCOPYSIGN:
3822     return lowerFCopySign(MI);
3823   case G_FMINNUM:
3824   case G_FMAXNUM:
3825     return lowerFMinNumMaxNum(MI);
3826   case G_MERGE_VALUES:
3827     return lowerMergeValues(MI);
3828   case G_UNMERGE_VALUES:
3829     return lowerUnmergeValues(MI);
3830   case TargetOpcode::G_SEXT_INREG: {
3831     assert(MI.getOperand(2).isImm() && "Expected immediate");
3832     int64_t SizeInBits = MI.getOperand(2).getImm();
3833 
3834     auto [DstReg, SrcReg] = MI.getFirst2Regs();
3835     LLT DstTy = MRI.getType(DstReg);
3836     Register TmpRes = MRI.createGenericVirtualRegister(DstTy);
3837 
3838     auto MIBSz = MIRBuilder.buildConstant(DstTy, DstTy.getScalarSizeInBits() - SizeInBits);
3839     MIRBuilder.buildShl(TmpRes, SrcReg, MIBSz->getOperand(0));
3840     MIRBuilder.buildAShr(DstReg, TmpRes, MIBSz->getOperand(0));
3841     MI.eraseFromParent();
3842     return Legalized;
3843   }
3844   case G_EXTRACT_VECTOR_ELT:
3845   case G_INSERT_VECTOR_ELT:
3846     return lowerExtractInsertVectorElt(MI);
3847   case G_SHUFFLE_VECTOR:
3848     return lowerShuffleVector(MI);
3849   case G_DYN_STACKALLOC:
3850     return lowerDynStackAlloc(MI);
3851   case G_STACKSAVE:
3852     return lowerStackSave(MI);
3853   case G_STACKRESTORE:
3854     return lowerStackRestore(MI);
3855   case G_EXTRACT:
3856     return lowerExtract(MI);
3857   case G_INSERT:
3858     return lowerInsert(MI);
3859   case G_BSWAP:
3860     return lowerBswap(MI);
3861   case G_BITREVERSE:
3862     return lowerBitreverse(MI);
3863   case G_READ_REGISTER:
3864   case G_WRITE_REGISTER:
3865     return lowerReadWriteRegister(MI);
3866   case G_UADDSAT:
3867   case G_USUBSAT: {
3868     // Try to make a reasonable guess about which lowering strategy to use. The
3869     // target can override this by requesting custom lowering and calling the
3870     // implementation functions directly.
3871     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
3872     if (LI.isLegalOrCustom({G_UMIN, Ty}))
3873       return lowerAddSubSatToMinMax(MI);
3874     return lowerAddSubSatToAddoSubo(MI);
3875   }
3876   case G_SADDSAT:
3877   case G_SSUBSAT: {
3878     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
3879 
3880     // FIXME: It would probably make more sense to see if G_SADDO is preferred,
3881     // since it's a shorter expansion. However, we would need to figure out the
3882     // preferred boolean type for the carry out for the query.
3883     if (LI.isLegalOrCustom({G_SMIN, Ty}) && LI.isLegalOrCustom({G_SMAX, Ty}))
3884       return lowerAddSubSatToMinMax(MI);
3885     return lowerAddSubSatToAddoSubo(MI);
3886   }
3887   case G_SSHLSAT:
3888   case G_USHLSAT:
3889     return lowerShlSat(MI);
3890   case G_ABS:
3891     return lowerAbsToAddXor(MI);
3892   case G_SELECT:
3893     return lowerSelect(MI);
3894   case G_IS_FPCLASS:
3895     return lowerISFPCLASS(MI);
3896   case G_SDIVREM:
3897   case G_UDIVREM:
3898     return lowerDIVREM(MI);
3899   case G_FSHL:
3900   case G_FSHR:
3901     return lowerFunnelShift(MI);
3902   case G_ROTL:
3903   case G_ROTR:
3904     return lowerRotate(MI);
3905   case G_MEMSET:
3906   case G_MEMCPY:
3907   case G_MEMMOVE:
3908     return lowerMemCpyFamily(MI);
3909   case G_MEMCPY_INLINE:
3910     return lowerMemcpyInline(MI);
3911   case G_ZEXT:
3912   case G_SEXT:
3913   case G_ANYEXT:
3914     return lowerEXT(MI);
3915   case G_TRUNC:
3916     return lowerTRUNC(MI);
3917   GISEL_VECREDUCE_CASES_NONSEQ
3918     return lowerVectorReduction(MI);
3919   case G_VAARG:
3920     return lowerVAArg(MI);
3921   }
3922 }
3923 
3924 Align LegalizerHelper::getStackTemporaryAlignment(LLT Ty,
3925                                                   Align MinAlign) const {
3926   // FIXME: We're missing a way to go back from LLT to llvm::Type to query the
3927   // datalayout for the preferred alignment. Also there should be a target hook
3928   // for this to allow targets to reduce the alignment and ignore the
3929   // datalayout. e.g. AMDGPU should always use a 4-byte alignment, regardless of
3930   // the type.
3931   return std::max(Align(PowerOf2Ceil(Ty.getSizeInBytes())), MinAlign);
3932 }
3933 
3934 MachineInstrBuilder
3935 LegalizerHelper::createStackTemporary(TypeSize Bytes, Align Alignment,
3936                                       MachinePointerInfo &PtrInfo) {
3937   MachineFunction &MF = MIRBuilder.getMF();
3938   const DataLayout &DL = MIRBuilder.getDataLayout();
3939   int FrameIdx = MF.getFrameInfo().CreateStackObject(Bytes, Alignment, false);
3940 
3941   unsigned AddrSpace = DL.getAllocaAddrSpace();
3942   LLT FramePtrTy = LLT::pointer(AddrSpace, DL.getPointerSizeInBits(AddrSpace));
3943 
3944   PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIdx);
3945   return MIRBuilder.buildFrameIndex(FramePtrTy, FrameIdx);
3946 }
3947 
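/// Clamp a dynamic vector index into the valid range for \p VecTy: a
/// known-constant index is returned as-is, a power-of-2 element count is
/// handled with a cheap AND mask, and any other count is clamped with G_UMIN.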
3948 static Register clampDynamicVectorIndex(MachineIRBuilder &B, Register IdxReg,
3949                                         LLT VecTy) {
3950   int64_t IdxVal;
3951   if (mi_match(IdxReg, *B.getMRI(), m_ICst(IdxVal)))
3952     return IdxReg;
3953 
3954   LLT IdxTy = B.getMRI()->getType(IdxReg);
3955   unsigned NElts = VecTy.getNumElements();
3956   if (isPowerOf2_32(NElts)) {
3957     APInt Imm = APInt::getLowBitsSet(IdxTy.getSizeInBits(), Log2_32(NElts));
3958     return B.buildAnd(IdxTy, IdxReg, B.buildConstant(IdxTy, Imm)).getReg(0);
3959   }
3960 
3961   return B.buildUMin(IdxTy, IdxReg, B.buildConstant(IdxTy, NElts - 1))
3962       .getReg(0);
3963 }
3964 
3965 Register LegalizerHelper::getVectorElementPointer(Register VecPtr, LLT VecTy,
3966                                                   Register Index) {
3967   LLT EltTy = VecTy.getElementType();
3968 
3969   // Calculate the element offset and add it to the pointer.
3970   unsigned EltSize = EltTy.getSizeInBits() / 8; // FIXME: should be ABI size.
3971   assert(EltSize * 8 == EltTy.getSizeInBits() &&
3972          "Converting bits to bytes lost precision");
3973 
3974   Index = clampDynamicVectorIndex(MIRBuilder, Index, VecTy);
3975 
3976   LLT IdxTy = MRI.getType(Index);
3977   auto Mul = MIRBuilder.buildMul(IdxTy, Index,
3978                                  MIRBuilder.buildConstant(IdxTy, EltSize));
3979 
3980   LLT PtrTy = MRI.getType(VecPtr);
3981   return MIRBuilder.buildPtrAdd(PtrTy, VecPtr, Mul).getReg(0);
3982 }
3983 
3984 #ifndef NDEBUG
3985 /// Check that all vector operands have the same number of elements. Any other
3986 /// operands should be listed in \p NonVecOpIndices.
3987 static bool hasSameNumEltsOnAllVectorOperands(
3988     GenericMachineInstr &MI, MachineRegisterInfo &MRI,
3989     std::initializer_list<unsigned> NonVecOpIndices) {
3990   if (MI.getNumMemOperands() != 0)
3991     return false;
3992 
3993   LLT VecTy = MRI.getType(MI.getReg(0));
3994   if (!VecTy.isVector())
3995     return false;
3996   unsigned NumElts = VecTy.getNumElements();
3997 
3998   for (unsigned OpIdx = 1; OpIdx < MI.getNumOperands(); ++OpIdx) {
3999     MachineOperand &Op = MI.getOperand(OpIdx);
4000     if (!Op.isReg()) {
4001       if (!is_contained(NonVecOpIndices, OpIdx))
4002         return false;
4003       continue;
4004     }
4005 
4006     LLT Ty = MRI.getType(Op.getReg());
4007     if (!Ty.isVector()) {
4008       if (!is_contained(NonVecOpIndices, OpIdx))
4009         return false;
4010       continue;
4011     }
4012 
4013     if (Ty.getNumElements() != NumElts)
4014       return false;
4015   }
4016 
4017   return true;
4018 }
4019 #endif
4020 
4021 /// Fill \p DstOps with DstOps whose element counts together cover \p Ty. Each
4022 /// DstOp is a scalar when \p NumElts = 1, and otherwise a vector with
4023 /// \p NumElts elements. When Ty.getNumElements() is not a multiple of
4024 /// \p NumElts, the last DstOp (the leftover) has fewer than \p NumElts elements.
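/// For example (illustrative): Ty = <7 x s8> with \p NumElts = 4 yields
/// DstOps = { <4 x s8>, <3 x s8> }, the second entry being the leftover.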
4025 static void makeDstOps(SmallVectorImpl<DstOp> &DstOps, LLT Ty,
4026                        unsigned NumElts) {
4027   LLT LeftoverTy;
4028   assert(Ty.isVector() && "Expected vector type");
4029   LLT EltTy = Ty.getElementType();
4030   LLT NarrowTy = (NumElts == 1) ? EltTy : LLT::fixed_vector(NumElts, EltTy);
4031   int NumParts, NumLeftover;
4032   std::tie(NumParts, NumLeftover) =
4033       getNarrowTypeBreakDown(Ty, NarrowTy, LeftoverTy);
4034 
4035   assert(NumParts > 0 && "Error in getNarrowTypeBreakDown");
4036   for (int i = 0; i < NumParts; ++i) {
4037     DstOps.push_back(NarrowTy);
4038   }
4039 
4040   if (LeftoverTy.isValid()) {
4041     assert(NumLeftover == 1 && "expected exactly one leftover");
4042     DstOps.push_back(LeftoverTy);
4043   }
4044 }
4045 
4046 /// Operand \p Op is used by \p N sub-instructions. Fill \p Ops with \p N
4047 /// SrcOps made from \p Op depending on the operand type.
4048 static void broadcastSrcOp(SmallVectorImpl<SrcOp> &Ops, unsigned N,
4049                            MachineOperand &Op) {
4050   for (unsigned i = 0; i < N; ++i) {
4051     if (Op.isReg())
4052       Ops.push_back(Op.getReg());
4053     else if (Op.isImm())
4054       Ops.push_back(Op.getImm());
4055     else if (Op.isPredicate())
4056       Ops.push_back(static_cast<CmpInst::Predicate>(Op.getPredicate()));
4057     else
4058       llvm_unreachable("Unsupported type");
4059   }
4060 }
4061 
4062 // Handle splitting vector operations which need to have the same number of
4063 // elements in each type index, but each type index may have a different element
4064 // type.
4065 //
4066 // e.g.  <4 x s64> = G_SHL <4 x s64>, <4 x s32> ->
4067 //       <2 x s64> = G_SHL <2 x s64>, <2 x s32>
4068 //       <2 x s64> = G_SHL <2 x s64>, <2 x s32>
4069 //
4070 // Also handles some irregular breakdown cases, e.g.
4071 // e.g.  <3 x s64> = G_SHL <3 x s64>, <3 x s32> ->
4072 //       <2 x s64> = G_SHL <2 x s64>, <2 x s32>
4073 //             s64 = G_SHL s64, s32
4074 LegalizerHelper::LegalizeResult
4075 LegalizerHelper::fewerElementsVectorMultiEltType(
4076     GenericMachineInstr &MI, unsigned NumElts,
4077     std::initializer_list<unsigned> NonVecOpIndices) {
4078   assert(hasSameNumEltsOnAllVectorOperands(MI, MRI, NonVecOpIndices) &&
4079          "Non-compatible opcode or not specified non-vector operands");
4080   unsigned OrigNumElts = MRI.getType(MI.getReg(0)).getNumElements();
4081 
4082   unsigned NumInputs = MI.getNumOperands() - MI.getNumDefs();
4083   unsigned NumDefs = MI.getNumDefs();
4084 
4085   // Create DstOps (sub-vectors with NumElts elements plus a leftover) for each
4086   // output. Build instructions with DstOps so an instruction found by CSE is
4087   // used directly; with a vreg destination, CSE would copy into the given vreg.
4088   SmallVector<SmallVector<DstOp, 8>, 2> OutputOpsPieces(NumDefs);
4089   // Output registers will be taken from created instructions.
4090   SmallVector<SmallVector<Register, 8>, 2> OutputRegs(NumDefs);
4091   for (unsigned i = 0; i < NumDefs; ++i) {
4092     makeDstOps(OutputOpsPieces[i], MRI.getType(MI.getReg(i)), NumElts);
4093   }
4094 
4095   // Split vector input operands into sub-vectors with NumElts elts + Leftover.
4096   // Operands listed in NonVecOpIndices will be used as is without splitting;
4097   // examples: compare predicate in icmp and fcmp (op 1), vector select with i1
4098   // scalar condition (op 1), immediate in sext_inreg (op 2).
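  // e.g. with NumElts = 2, %c:_(<4 x s1>) = G_ICMP intpred(eq), %a:_(<4 x s32>), %b
  // splits into two <2 x s1> = G_ICMP pieces, and the predicate (op 1) is
  // broadcast to both pieces rather than split.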
4099   SmallVector<SmallVector<SrcOp, 8>, 3> InputOpsPieces(NumInputs);
4100   for (unsigned UseIdx = NumDefs, UseNo = 0; UseIdx < MI.getNumOperands();
4101        ++UseIdx, ++UseNo) {
4102     if (is_contained(NonVecOpIndices, UseIdx)) {
4103       broadcastSrcOp(InputOpsPieces[UseNo], OutputOpsPieces[0].size(),
4104                      MI.getOperand(UseIdx));
4105     } else {
4106       SmallVector<Register, 8> SplitPieces;
4107       extractVectorParts(MI.getReg(UseIdx), NumElts, SplitPieces, MIRBuilder,
4108                          MRI);
4109       for (auto Reg : SplitPieces)
4110         InputOpsPieces[UseNo].push_back(Reg);
4111     }
4112   }
4113 
4114   unsigned NumLeftovers = OrigNumElts % NumElts ? 1 : 0;
4115 
4116   // Take the i-th piece of each split input operand and build a sub-vector or
4117   // scalar instruction, with the i-th DstOp(s) from OutputOpsPieces as destination(s).
4118   for (unsigned i = 0; i < OrigNumElts / NumElts + NumLeftovers; ++i) {
4119     SmallVector<DstOp, 2> Defs;
4120     for (unsigned DstNo = 0; DstNo < NumDefs; ++DstNo)
4121       Defs.push_back(OutputOpsPieces[DstNo][i]);
4122 
4123     SmallVector<SrcOp, 3> Uses;
4124     for (unsigned InputNo = 0; InputNo < NumInputs; ++InputNo)
4125       Uses.push_back(InputOpsPieces[InputNo][i]);
4126 
4127     auto I = MIRBuilder.buildInstr(MI.getOpcode(), Defs, Uses, MI.getFlags());
4128     for (unsigned DstNo = 0; DstNo < NumDefs; ++DstNo)
4129       OutputRegs[DstNo].push_back(I.getReg(DstNo));
4130   }
4131 
4132   // Merge small outputs into MI's output for each def operand.
4133   if (NumLeftovers) {
4134     for (unsigned i = 0; i < NumDefs; ++i)
4135       mergeMixedSubvectors(MI.getReg(i), OutputRegs[i]);
4136   } else {
4137     for (unsigned i = 0; i < NumDefs; ++i)
4138       MIRBuilder.buildMergeLikeInstr(MI.getReg(i), OutputRegs[i]);
4139   }
4140 
4141   MI.eraseFromParent();
4142   return Legalized;
4143 }
4144 
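// Split a vector G_PHI into G_PHIs over narrower pieces, e.g. (a sketch):
//   %v:_(<4 x s32>) = G_PHI %a(<4 x s32>), %bb.1, %b(<4 x s32>), %bb.2
// with NumElts = 2 becomes two <2 x s32> G_PHIs whose results are merged back
// into %v; the unmerges of %a and %b are emitted in their defining blocks.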
4145 LegalizerHelper::LegalizeResult
4146 LegalizerHelper::fewerElementsVectorPhi(GenericMachineInstr &MI,
4147                                         unsigned NumElts) {
4148   unsigned OrigNumElts = MRI.getType(MI.getReg(0)).getNumElements();
4149 
4150   unsigned NumInputs = MI.getNumOperands() - MI.getNumDefs();
4151   unsigned NumDefs = MI.getNumDefs();
4152 
4153   SmallVector<DstOp, 8> OutputOpsPieces;
4154   SmallVector<Register, 8> OutputRegs;
4155   makeDstOps(OutputOpsPieces, MRI.getType(MI.getReg(0)), NumElts);
4156 
4157   // Instructions that split a register are inserted in the basic block where
4158   // the register is defined (that basic block is the next PHI operand).
4159   SmallVector<SmallVector<Register, 8>, 3> InputOpsPieces(NumInputs / 2);
4160   for (unsigned UseIdx = NumDefs, UseNo = 0; UseIdx < MI.getNumOperands();
4161        UseIdx += 2, ++UseNo) {
4162     MachineBasicBlock &OpMBB = *MI.getOperand(UseIdx + 1).getMBB();
4163     MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminatorForward());
4164     extractVectorParts(MI.getReg(UseIdx), NumElts, InputOpsPieces[UseNo],
4165                        MIRBuilder, MRI);
4166   }
4167 
4168   // Build PHIs with fewer elements.
4169   unsigned NumLeftovers = OrigNumElts % NumElts ? 1 : 0;
4170   MIRBuilder.setInsertPt(*MI.getParent(), MI);
4171   for (unsigned i = 0; i < OrigNumElts / NumElts + NumLeftovers; ++i) {
4172     auto Phi = MIRBuilder.buildInstr(TargetOpcode::G_PHI);
4173     Phi.addDef(
4174         MRI.createGenericVirtualRegister(OutputOpsPieces[i].getLLTTy(MRI)));
4175     OutputRegs.push_back(Phi.getReg(0));
4176 
4177     for (unsigned j = 0; j < NumInputs / 2; ++j) {
4178       Phi.addUse(InputOpsPieces[j][i]);
4179       Phi.add(MI.getOperand(1 + j * 2 + 1));
4180     }
4181   }
4182 
4183   // Set the insert point after the existing PHIs
4184   MachineBasicBlock &MBB = *MI.getParent();
4185   MIRBuilder.setInsertPt(MBB, MBB.getFirstNonPHI());
4186 
4187   // Merge small outputs into MI's def.
4188   if (NumLeftovers) {
4189     mergeMixedSubvectors(MI.getReg(0), OutputRegs);
4190   } else {
4191     MIRBuilder.buildMergeLikeInstr(MI.getReg(0), OutputRegs);
4192   }
4193 
4194   MI.eraseFromParent();
4195   return Legalized;
4196 }
4197 
4198 LegalizerHelper::LegalizeResult
4199 LegalizerHelper::fewerElementsVectorUnmergeValues(MachineInstr &MI,
4200                                                   unsigned TypeIdx,
4201                                                   LLT NarrowTy) {
4202   const int NumDst = MI.getNumOperands() - 1;
4203   const Register SrcReg = MI.getOperand(NumDst).getReg();
4204   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
4205   LLT SrcTy = MRI.getType(SrcReg);
4206 
4207   if (TypeIdx != 1 || NarrowTy == DstTy)
4208     return UnableToLegalize;
4209 
4210   // Requires compatible types. Otherwise SrcReg should have been defined by a
4211   // merge-like instruction that would have been artifact-combined. Most likely
4212   // the instruction that defines SrcReg has to perform a more/fewer-elements
4213   // legalization compatible with NarrowTy.
4214   assert(SrcTy.isVector() && NarrowTy.isVector() && "Expected vector types");
4215   assert((SrcTy.getScalarType() == NarrowTy.getScalarType()) && "bad type");
4216 
4217   if ((SrcTy.getSizeInBits() % NarrowTy.getSizeInBits() != 0) ||
4218       (NarrowTy.getSizeInBits() % DstTy.getSizeInBits() != 0))
4219     return UnableToLegalize;
4220 
4221   // This is most likely DstTy (smaller than register size) packed in SrcTy
4222   // (larger than register size), and since the unmerge was not combined it will
4223   // be lowered to bit-sequence extracts from a register. Unpack SrcTy into
4224   // NarrowTy (register size) pieces first, then unpack each piece to DstTy.
4225 
4226   // %1:_(DstTy), %2, %3, %4 = G_UNMERGE_VALUES %0:_(SrcTy)
4227   //
4228   // %5:_(NarrowTy), %6 = G_UNMERGE_VALUES %0:_(SrcTy) - reg sequence
4229   // %1:_(DstTy), %2 = G_UNMERGE_VALUES %5:_(NarrowTy) - sequence of bits in reg
4230   // %3:_(DstTy), %4 = G_UNMERGE_VALUES %6:_(NarrowTy)
4231   auto Unmerge = MIRBuilder.buildUnmerge(NarrowTy, SrcReg);
4232   const int NumUnmerge = Unmerge->getNumOperands() - 1;
4233   const int PartsPerUnmerge = NumDst / NumUnmerge;
4234 
4235   for (int I = 0; I != NumUnmerge; ++I) {
4236     auto MIB = MIRBuilder.buildInstr(TargetOpcode::G_UNMERGE_VALUES);
4237 
4238     for (int J = 0; J != PartsPerUnmerge; ++J)
4239       MIB.addDef(MI.getOperand(I * PartsPerUnmerge + J).getReg());
4240     MIB.addUse(Unmerge.getReg(I));
4241   }
4242 
4243   MI.eraseFromParent();
4244   return Legalized;
4245 }
4246 
4247 LegalizerHelper::LegalizeResult
4248 LegalizerHelper::fewerElementsVectorMerge(MachineInstr &MI, unsigned TypeIdx,
4249                                           LLT NarrowTy) {
4250   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
4251   // Requires compatible types; otherwise the user of DstReg did not perform an
4252   // unmerge that should have been artifact-combined. Most likely the user of
4253   // DstReg has to do a more/fewer-elements legalization compatible with NarrowTy.
4254   assert(DstTy.isVector() && NarrowTy.isVector() && "Expected vector types");
4255   assert((DstTy.getScalarType() == NarrowTy.getScalarType()) && "bad type");
4256   if (NarrowTy == SrcTy)
4257     return UnableToLegalize;
4258 
4259   // This attempts to lower part of an LCMTy merge/unmerge sequence; the
4260   // intended use is for old MIR tests. Since the more/fewer-elements changes
4261   // it should no longer be possible to generate MIR like this from LLVM IR,
4262   // because the LCMTy approach was replaced with merge/unmerge to vector elements.
4263   if (TypeIdx == 1) {
4264     assert(SrcTy.isVector() && "Expected vector types");
4265     assert((SrcTy.getScalarType() == NarrowTy.getScalarType()) && "bad type");
4266     if ((DstTy.getSizeInBits() % NarrowTy.getSizeInBits() != 0) ||
4267         (NarrowTy.getNumElements() >= SrcTy.getNumElements()))
4268       return UnableToLegalize;
4269     // %2:_(DstTy) = G_CONCAT_VECTORS %0:_(SrcTy), %1:_(SrcTy)
4270     //
4271     // %3:_(EltTy), %4, %5 = G_UNMERGE_VALUES %0:_(SrcTy)
4272     // %6:_(EltTy), %7, %8 = G_UNMERGE_VALUES %1:_(SrcTy)
4273     // %9:_(NarrowTy) = G_BUILD_VECTOR %3:_(EltTy), %4
4274     // %10:_(NarrowTy) = G_BUILD_VECTOR %5:_(EltTy), %6
4275     // %11:_(NarrowTy) = G_BUILD_VECTOR %7:_(EltTy), %8
4276     // %2:_(DstTy) = G_CONCAT_VECTORS %9:_(NarrowTy), %10, %11
4277 
4278     SmallVector<Register, 8> Elts;
4279     LLT EltTy = MRI.getType(MI.getOperand(1).getReg()).getScalarType();
4280     for (unsigned i = 1; i < MI.getNumOperands(); ++i) {
4281       auto Unmerge = MIRBuilder.buildUnmerge(EltTy, MI.getOperand(i).getReg());
4282       for (unsigned j = 0; j < Unmerge->getNumDefs(); ++j)
4283         Elts.push_back(Unmerge.getReg(j));
4284     }
4285 
4286     SmallVector<Register, 8> NarrowTyElts;
4287     unsigned NumNarrowTyElts = NarrowTy.getNumElements();
4288     unsigned NumNarrowTyPieces = DstTy.getNumElements() / NumNarrowTyElts;
4289     for (unsigned i = 0, Offset = 0; i < NumNarrowTyPieces;
4290          ++i, Offset += NumNarrowTyElts) {
4291       ArrayRef<Register> Pieces(&Elts[Offset], NumNarrowTyElts);
4292       NarrowTyElts.push_back(
4293           MIRBuilder.buildMergeLikeInstr(NarrowTy, Pieces).getReg(0));
4294     }
4295 
4296     MIRBuilder.buildMergeLikeInstr(DstReg, NarrowTyElts);
4297     MI.eraseFromParent();
4298     return Legalized;
4299   }
4300 
4301   assert(TypeIdx == 0 && "Bad type index");
4302   if ((NarrowTy.getSizeInBits() % SrcTy.getSizeInBits() != 0) ||
4303       (DstTy.getSizeInBits() % NarrowTy.getSizeInBits() != 0))
4304     return UnableToLegalize;
4305 
4306   // This is most likely SrcTy (smaller than register size) packed in DstTy
4307   // (larger than register size), and since the merge was not combined it will be
4308   // lowered to bit-sequence packing into a register. Merge SrcTy into NarrowTy
4309   // (register size) pieces first, then merge the NarrowTy pieces into DstTy.
4310 
4311   // %0:_(DstTy) = G_MERGE_VALUES %1:_(SrcTy), %2, %3, %4
4312   //
4313   // %5:_(NarrowTy) = G_MERGE_VALUES %1:_(SrcTy), %2 - sequence of bits in reg
4314   // %6:_(NarrowTy) = G_MERGE_VALUES %3:_(SrcTy), %4
4315   // %0:_(DstTy)  = G_MERGE_VALUES %5:_(NarrowTy), %6 - reg sequence
4316   SmallVector<Register, 8> NarrowTyElts;
4317   unsigned NumParts = DstTy.getNumElements() / NarrowTy.getNumElements();
4318   unsigned NumSrcElts = SrcTy.isVector() ? SrcTy.getNumElements() : 1;
4319   unsigned NumElts = NarrowTy.getNumElements() / NumSrcElts;
4320   for (unsigned i = 0; i < NumParts; ++i) {
4321     SmallVector<Register, 8> Sources;
4322     for (unsigned j = 0; j < NumElts; ++j)
4323       Sources.push_back(MI.getOperand(1 + i * NumElts + j).getReg());
4324     NarrowTyElts.push_back(
4325         MIRBuilder.buildMergeLikeInstr(NarrowTy, Sources).getReg(0));
4326   }
4327 
4328   MIRBuilder.buildMergeLikeInstr(DstReg, NarrowTyElts);
4329   MI.eraseFromParent();
4330   return Legalized;
4331 }
4332 
4333 LegalizerHelper::LegalizeResult
4334 LegalizerHelper::fewerElementsVectorExtractInsertVectorElt(MachineInstr &MI,
4335                                                            unsigned TypeIdx,
4336                                                            LLT NarrowVecTy) {
4337   auto [DstReg, SrcVec] = MI.getFirst2Regs();
4338   Register InsertVal;
4339   bool IsInsert = MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT;
4340 
4341   assert((IsInsert ? TypeIdx == 0 : TypeIdx == 1) && "not a vector type index");
4342   if (IsInsert)
4343     InsertVal = MI.getOperand(2).getReg();
4344 
4345   Register Idx = MI.getOperand(MI.getNumOperands() - 1).getReg();
4346 
4347   // TODO: Handle total scalarization case.
4348   if (!NarrowVecTy.isVector())
4349     return UnableToLegalize;
4350 
4351   LLT VecTy = MRI.getType(SrcVec);
4352 
4353   // If the index is a constant, we can really break this down as you would
4354   // expect, and index into the target size pieces.
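  // e.g. extracting element 5 of an <8 x s32> vector with NarrowVecTy =
  // <4 x s32> becomes an extract of element 5 - 4 = 1 from the second
  // <4 x s32> piece.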
4355   int64_t IdxVal;
4356   auto MaybeCst = getIConstantVRegValWithLookThrough(Idx, MRI);
4357   if (MaybeCst) {
4358     IdxVal = MaybeCst->Value.getSExtValue();
4359     // Avoid out of bounds indexing the pieces.
4360     if (IdxVal >= VecTy.getNumElements()) {
4361       MIRBuilder.buildUndef(DstReg);
4362       MI.eraseFromParent();
4363       return Legalized;
4364     }
4365 
4366     SmallVector<Register, 8> VecParts;
4367     LLT GCDTy = extractGCDType(VecParts, VecTy, NarrowVecTy, SrcVec);
4368 
4369     // Build a sequence of NarrowTy pieces in VecParts for this operand.
4370     LLT LCMTy = buildLCMMergePieces(VecTy, NarrowVecTy, GCDTy, VecParts,
4371                                     TargetOpcode::G_ANYEXT);
4372 
4373     unsigned NewNumElts = NarrowVecTy.getNumElements();
4374 
4375     LLT IdxTy = MRI.getType(Idx);
4376     int64_t PartIdx = IdxVal / NewNumElts;
4377     auto NewIdx =
4378         MIRBuilder.buildConstant(IdxTy, IdxVal - NewNumElts * PartIdx);
4379 
4380     if (IsInsert) {
4381       LLT PartTy = MRI.getType(VecParts[PartIdx]);
4382 
4383       // Use the adjusted index to insert into one of the subvectors.
4384       auto InsertPart = MIRBuilder.buildInsertVectorElement(
4385           PartTy, VecParts[PartIdx], InsertVal, NewIdx);
4386       VecParts[PartIdx] = InsertPart.getReg(0);
4387 
4388       // Recombine the inserted subvector with the others to reform the result
4389       // vector.
4390       buildWidenedRemergeToDst(DstReg, LCMTy, VecParts);
4391     } else {
4392       MIRBuilder.buildExtractVectorElement(DstReg, VecParts[PartIdx], NewIdx);
4393     }
4394 
4395     MI.eraseFromParent();
4396     return Legalized;
4397   }
4398 
4399   // With a variable index, we can't perform the operation in a smaller type, so
4400   // we're forced to expand this.
4401   //
4402   // TODO: We could emit a chain of compare/select to figure out which piece to
4403   // index.
4404   return lowerExtractInsertVectorElt(MI);
4405 }
4406 
4407 LegalizerHelper::LegalizeResult
4408 LegalizerHelper::reduceLoadStoreWidth(GLoadStore &LdStMI, unsigned TypeIdx,
4409                                       LLT NarrowTy) {
4410   // FIXME: Don't know how to handle secondary types yet.
4411   if (TypeIdx != 0)
4412     return UnableToLegalize;
4413 
4414   // This implementation doesn't work for atomics. Give up instead of doing
4415   // something invalid.
4416   if (LdStMI.isAtomic())
4417     return UnableToLegalize;
4418 
4419   bool IsLoad = isa<GLoad>(LdStMI);
4420   Register ValReg = LdStMI.getReg(0);
4421   Register AddrReg = LdStMI.getPointerReg();
4422   LLT ValTy = MRI.getType(ValReg);
4423 
4424   // FIXME: Do we need a distinct NarrowMemory legalize action?
4425   if (ValTy.getSizeInBits() != 8 * LdStMI.getMemSize()) {
4426     LLVM_DEBUG(dbgs() << "Can't narrow extload/truncstore\n");
4427     return UnableToLegalize;
4428   }
4429 
4430   int NumParts = -1;
4431   int NumLeftover = -1;
4432   LLT LeftoverTy;
4433   SmallVector<Register, 8> NarrowRegs, NarrowLeftoverRegs;
4434   if (IsLoad) {
4435     std::tie(NumParts, NumLeftover) = getNarrowTypeBreakDown(ValTy, NarrowTy, LeftoverTy);
4436   } else {
4437     if (extractParts(ValReg, ValTy, NarrowTy, LeftoverTy, NarrowRegs,
4438                      NarrowLeftoverRegs, MIRBuilder, MRI)) {
4439       NumParts = NarrowRegs.size();
4440       NumLeftover = NarrowLeftoverRegs.size();
4441     }
4442   }
4443 
4444   if (NumParts == -1)
4445     return UnableToLegalize;
4446 
4447   LLT PtrTy = MRI.getType(AddrReg);
4448   const LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
4449 
4450   unsigned TotalSize = ValTy.getSizeInBits();
4451 
4452   // Split the load/store into PartTy-sized pieces starting at Offset. If this
4453   // is a load, return the new registers in ValRegs. For a store, each element
4454   // of ValRegs should be PartTy. Returns the next offset that needs to be
4455   // handled.
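  // e.g. (little endian, a sketch) narrowing an s96 load with NarrowTy = s64
  // emits an s64 load at byte offset 0, then an s32 leftover load at offset 8.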
4456   bool isBigEndian = MIRBuilder.getDataLayout().isBigEndian();
4457   auto MMO = LdStMI.getMMO();
4458   auto splitTypePieces = [=](LLT PartTy, SmallVectorImpl<Register> &ValRegs,
4459                              unsigned NumParts, unsigned Offset) -> unsigned {
4460     MachineFunction &MF = MIRBuilder.getMF();
4461     unsigned PartSize = PartTy.getSizeInBits();
4462     for (unsigned Idx = 0, E = NumParts; Idx != E && Offset < TotalSize;
4463          ++Idx) {
4464       unsigned ByteOffset = Offset / 8;
4465       Register NewAddrReg;
4466 
4467       MIRBuilder.materializePtrAdd(NewAddrReg, AddrReg, OffsetTy, ByteOffset);
4468 
4469       MachineMemOperand *NewMMO =
4470           MF.getMachineMemOperand(&MMO, ByteOffset, PartTy);
4471 
4472       if (IsLoad) {
4473         Register Dst = MRI.createGenericVirtualRegister(PartTy);
4474         ValRegs.push_back(Dst);
4475         MIRBuilder.buildLoad(Dst, NewAddrReg, *NewMMO);
4476       } else {
4477         MIRBuilder.buildStore(ValRegs[Idx], NewAddrReg, *NewMMO);
4478       }
4479       Offset = isBigEndian ? Offset - PartSize : Offset + PartSize;
4480     }
4481 
4482     return Offset;
4483   };
4484 
4485   unsigned Offset = isBigEndian ? TotalSize - NarrowTy.getSizeInBits() : 0;
4486   unsigned HandledOffset =
4487       splitTypePieces(NarrowTy, NarrowRegs, NumParts, Offset);
4488 
4489   // Handle the rest of the register if this isn't an even type breakdown.
4490   if (LeftoverTy.isValid())
4491     splitTypePieces(LeftoverTy, NarrowLeftoverRegs, NumLeftover, HandledOffset);
4492 
4493   if (IsLoad) {
4494     insertParts(ValReg, ValTy, NarrowTy, NarrowRegs,
4495                 LeftoverTy, NarrowLeftoverRegs);
4496   }
4497 
4498   LdStMI.eraseFromParent();
4499   return Legalized;
4500 }
4501 
4502 LegalizerHelper::LegalizeResult
4503 LegalizerHelper::fewerElementsVector(MachineInstr &MI, unsigned TypeIdx,
4504                                      LLT NarrowTy) {
4505   using namespace TargetOpcode;
4506   GenericMachineInstr &GMI = cast<GenericMachineInstr>(MI);
4507   unsigned NumElts = NarrowTy.isVector() ? NarrowTy.getNumElements() : 1;
4508 
4509   switch (MI.getOpcode()) {
4510   case G_IMPLICIT_DEF:
4511   case G_TRUNC:
4512   case G_AND:
4513   case G_OR:
4514   case G_XOR:
4515   case G_ADD:
4516   case G_SUB:
4517   case G_MUL:
4518   case G_PTR_ADD:
4519   case G_SMULH:
4520   case G_UMULH:
4521   case G_FADD:
4522   case G_FMUL:
4523   case G_FSUB:
4524   case G_FNEG:
4525   case G_FABS:
4526   case G_FCANONICALIZE:
4527   case G_FDIV:
4528   case G_FREM:
4529   case G_FMA:
4530   case G_FMAD:
4531   case G_FPOW:
4532   case G_FEXP:
4533   case G_FEXP2:
4534   case G_FEXP10:
4535   case G_FLOG:
4536   case G_FLOG2:
4537   case G_FLOG10:
4538   case G_FLDEXP:
4539   case G_FNEARBYINT:
4540   case G_FCEIL:
4541   case G_FFLOOR:
4542   case G_FRINT:
4543   case G_INTRINSIC_ROUND:
4544   case G_INTRINSIC_ROUNDEVEN:
4545   case G_INTRINSIC_TRUNC:
4546   case G_FCOS:
4547   case G_FSIN:
4548   case G_FSQRT:
4549   case G_BSWAP:
4550   case G_BITREVERSE:
4551   case G_SDIV:
4552   case G_UDIV:
4553   case G_SREM:
4554   case G_UREM:
4555   case G_SDIVREM:
4556   case G_UDIVREM:
4557   case G_SMIN:
4558   case G_SMAX:
4559   case G_UMIN:
4560   case G_UMAX:
4561   case G_ABS:
4562   case G_FMINNUM:
4563   case G_FMAXNUM:
4564   case G_FMINNUM_IEEE:
4565   case G_FMAXNUM_IEEE:
4566   case G_FMINIMUM:
4567   case G_FMAXIMUM:
4568   case G_FSHL:
4569   case G_FSHR:
4570   case G_ROTL:
4571   case G_ROTR:
4572   case G_FREEZE:
4573   case G_SADDSAT:
4574   case G_SSUBSAT:
4575   case G_UADDSAT:
4576   case G_USUBSAT:
4577   case G_UMULO:
4578   case G_SMULO:
4579   case G_SHL:
4580   case G_LSHR:
4581   case G_ASHR:
4582   case G_SSHLSAT:
4583   case G_USHLSAT:
4584   case G_CTLZ:
4585   case G_CTLZ_ZERO_UNDEF:
4586   case G_CTTZ:
4587   case G_CTTZ_ZERO_UNDEF:
4588   case G_CTPOP:
4589   case G_FCOPYSIGN:
4590   case G_ZEXT:
4591   case G_SEXT:
4592   case G_ANYEXT:
4593   case G_FPEXT:
4594   case G_FPTRUNC:
4595   case G_SITOFP:
4596   case G_UITOFP:
4597   case G_FPTOSI:
4598   case G_FPTOUI:
4599   case G_INTTOPTR:
4600   case G_PTRTOINT:
4601   case G_ADDRSPACE_CAST:
4602   case G_UADDO:
4603   case G_USUBO:
4604   case G_UADDE:
4605   case G_USUBE:
4606   case G_SADDO:
4607   case G_SSUBO:
4608   case G_SADDE:
4609   case G_SSUBE:
4610   case G_STRICT_FADD:
4611   case G_STRICT_FSUB:
4612   case G_STRICT_FMUL:
4613   case G_STRICT_FMA:
4614   case G_STRICT_FLDEXP:
4615   case G_FFREXP:
4616     return fewerElementsVectorMultiEltType(GMI, NumElts);
4617   case G_ICMP:
4618   case G_FCMP:
4619     return fewerElementsVectorMultiEltType(GMI, NumElts, {1 /*cmp predicate*/});
4620   case G_IS_FPCLASS:
4621     return fewerElementsVectorMultiEltType(GMI, NumElts, {2, 3 /*mask,fpsem*/});
4622   case G_SELECT:
4623     if (MRI.getType(MI.getOperand(1).getReg()).isVector())
4624       return fewerElementsVectorMultiEltType(GMI, NumElts);
4625     return fewerElementsVectorMultiEltType(GMI, NumElts, {1 /*scalar cond*/});
4626   case G_PHI:
4627     return fewerElementsVectorPhi(GMI, NumElts);
4628   case G_UNMERGE_VALUES:
4629     return fewerElementsVectorUnmergeValues(MI, TypeIdx, NarrowTy);
4630   case G_BUILD_VECTOR:
4631     assert(TypeIdx == 0 && "not a vector type index");
4632     return fewerElementsVectorMerge(MI, TypeIdx, NarrowTy);
4633   case G_CONCAT_VECTORS:
4634     if (TypeIdx != 1) // TODO: This probably does work as expected already.
4635       return UnableToLegalize;
4636     return fewerElementsVectorMerge(MI, TypeIdx, NarrowTy);
4637   case G_EXTRACT_VECTOR_ELT:
4638   case G_INSERT_VECTOR_ELT:
4639     return fewerElementsVectorExtractInsertVectorElt(MI, TypeIdx, NarrowTy);
4640   case G_LOAD:
4641   case G_STORE:
4642     return reduceLoadStoreWidth(cast<GLoadStore>(MI), TypeIdx, NarrowTy);
4643   case G_SEXT_INREG:
4644     return fewerElementsVectorMultiEltType(GMI, NumElts, {2 /*imm*/});
4645   GISEL_VECREDUCE_CASES_NONSEQ
4646     return fewerElementsVectorReductions(MI, TypeIdx, NarrowTy);
4647   case TargetOpcode::G_VECREDUCE_SEQ_FADD:
4648   case TargetOpcode::G_VECREDUCE_SEQ_FMUL:
4649     return fewerElementsVectorSeqReductions(MI, TypeIdx, NarrowTy);
4650   case G_SHUFFLE_VECTOR:
4651     return fewerElementsVectorShuffle(MI, TypeIdx, NarrowTy);
4652   case G_FPOWI:
4653     return fewerElementsVectorMultiEltType(GMI, NumElts, {2 /*pow*/});
4654   default:
4655     return UnableToLegalize;
4656   }
4657 }
4658 
4659 LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorShuffle(
4660     MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
4661   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
4662   if (TypeIdx != 0)
4663     return UnableToLegalize;
4664 
4665   auto [DstReg, DstTy, Src1Reg, Src1Ty, Src2Reg, Src2Ty] =
4666       MI.getFirst3RegLLTs();
4667   ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
4668   // The shuffle should be canonicalized by now.
4669   if (DstTy != Src1Ty)
4670     return UnableToLegalize;
4671   if (DstTy != Src2Ty)
4672     return UnableToLegalize;
4673 
4674   if (!isPowerOf2_32(DstTy.getNumElements()))
4675     return UnableToLegalize;
4676 
4677   // We only support splitting a shuffle into 2, so adjust NarrowTy accordingly.
4678   // Further legalization attempts will be needed to split further.
4679   NarrowTy =
4680       DstTy.changeElementCount(DstTy.getElementCount().divideCoefficientBy(2));
4681   unsigned NewElts = NarrowTy.getNumElements();
4682 
4683   SmallVector<Register> SplitSrc1Regs, SplitSrc2Regs;
4684   extractParts(Src1Reg, NarrowTy, 2, SplitSrc1Regs, MIRBuilder, MRI);
4685   extractParts(Src2Reg, NarrowTy, 2, SplitSrc2Regs, MIRBuilder, MRI);
4686   Register Inputs[4] = {SplitSrc1Regs[0], SplitSrc1Regs[1], SplitSrc2Regs[0],
4687                         SplitSrc2Regs[1]};
4688 
4689   Register Hi, Lo;
4690 
4691   // If Lo or Hi uses elements from at most two of the four input vectors, then
4692   // express it as a vector shuffle of those two inputs.  Otherwise extract the
4693   // input elements by hand and construct the Lo/Hi output using a BUILD_VECTOR.
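  // e.g. a <4 x s32> shuffle with mask <0, 4, 1, 5> builds Lo as a shuffle of
  // Inputs[0] and Inputs[2] with mask <0, 2> and Hi as a shuffle of the same
  // pair with mask <1, 3>.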
4694   SmallVector<int, 16> Ops;
4695   for (unsigned High = 0; High < 2; ++High) {
4696     Register &Output = High ? Hi : Lo;
4697 
4698     // Build a shuffle mask for the output, discovering on the fly which
4699     // input vectors to use as shuffle operands (recorded in InputUsed).
4700     // If building a suitable shuffle vector proves too hard, then bail
4701     // out with useBuildVector set.
4702     unsigned InputUsed[2] = {-1U, -1U}; // Not yet discovered.
4703     unsigned FirstMaskIdx = High * NewElts;
4704     bool UseBuildVector = false;
4705     for (unsigned MaskOffset = 0; MaskOffset < NewElts; ++MaskOffset) {
4706       // The mask element.  This indexes into the input.
4707       int Idx = Mask[FirstMaskIdx + MaskOffset];
4708 
4709       // The input vector this mask element indexes into.
4710       unsigned Input = (unsigned)Idx / NewElts;
4711 
4712       if (Input >= std::size(Inputs)) {
4713         // The mask element does not index into any input vector.
4714         Ops.push_back(-1);
4715         continue;
4716       }
4717 
4718       // Turn the index into an offset from the start of the input vector.
4719       Idx -= Input * NewElts;
4720 
4721       // Find or create a shuffle vector operand to hold this input.
4722       unsigned OpNo;
4723       for (OpNo = 0; OpNo < std::size(InputUsed); ++OpNo) {
4724         if (InputUsed[OpNo] == Input) {
4725           // This input vector is already an operand.
4726           break;
4727         } else if (InputUsed[OpNo] == -1U) {
4728           // Create a new operand for this input vector.
4729           InputUsed[OpNo] = Input;
4730           break;
4731         }
4732       }
4733 
4734       if (OpNo >= std::size(InputUsed)) {
4735         // More than two input vectors used!  Give up on trying to create a
4736         // shuffle vector.  Insert all elements into a BUILD_VECTOR instead.
4737         UseBuildVector = true;
4738         break;
4739       }
4740 
4741       // Add the mask index for the new shuffle vector.
4742       Ops.push_back(Idx + OpNo * NewElts);
4743     }
4744 
4745     if (UseBuildVector) {
4746       LLT EltTy = NarrowTy.getElementType();
4747       SmallVector<Register, 16> SVOps;
4748 
4749       // Extract the input elements by hand.
4750       for (unsigned MaskOffset = 0; MaskOffset < NewElts; ++MaskOffset) {
4751         // The mask element.  This indexes into the input.
4752         int Idx = Mask[FirstMaskIdx + MaskOffset];
4753 
4754         // The input vector this mask element indexes into.
4755         unsigned Input = (unsigned)Idx / NewElts;
4756 
4757         if (Input >= std::size(Inputs)) {
4758           // The mask element is "undef" or indexes off the end of the input.
4759           SVOps.push_back(MIRBuilder.buildUndef(EltTy).getReg(0));
4760           continue;
4761         }
4762 
4763         // Turn the index into an offset from the start of the input vector.
4764         Idx -= Input * NewElts;
4765 
4766         // Extract the vector element by hand.
4767         SVOps.push_back(MIRBuilder
4768                             .buildExtractVectorElement(
4769                                 EltTy, Inputs[Input],
4770                                 MIRBuilder.buildConstant(LLT::scalar(32), Idx))
4771                             .getReg(0));
4772       }
4773 
4774       // Construct the Lo/Hi output using a G_BUILD_VECTOR.
4775       Output = MIRBuilder.buildBuildVector(NarrowTy, SVOps).getReg(0);
4776     } else if (InputUsed[0] == -1U) {
4777       // No input vectors were used! The result is undefined.
4778       Output = MIRBuilder.buildUndef(NarrowTy).getReg(0);
4779     } else {
4780       Register Op0 = Inputs[InputUsed[0]];
4781       // If only one input was used, use an undefined vector for the other.
4782       Register Op1 = InputUsed[1] == -1U
4783                          ? MIRBuilder.buildUndef(NarrowTy).getReg(0)
4784                          : Inputs[InputUsed[1]];
4785       // At least one input vector was used. Create a new shuffle vector.
4786       Output = MIRBuilder.buildShuffleVector(NarrowTy, Op0, Op1, Ops).getReg(0);
4787     }
4788 
4789     Ops.clear();
4790   }
4791 
4792   MIRBuilder.buildConcatVectors(DstReg, {Lo, Hi});
4793   MI.eraseFromParent();
4794   return Legalized;
4795 }
4796 
4797 LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
4798     MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
4799   auto &RdxMI = cast<GVecReduce>(MI);
4800 
4801   if (TypeIdx != 1)
4802     return UnableToLegalize;
4803 
4804   // The semantics of the normal non-sequential reductions allow us to freely
4805   // re-associate the operation.
4806   auto [DstReg, DstTy, SrcReg, SrcTy] = RdxMI.getFirst2RegLLTs();
4807 
4808   if (NarrowTy.isVector() &&
4809       (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0))
4810     return UnableToLegalize;
4811 
4812   unsigned ScalarOpc = RdxMI.getScalarOpcForReduction();
4813   SmallVector<Register> SplitSrcs;
4814   // If NarrowTy is a scalar then we're being asked to scalarize.
4815   const unsigned NumParts =
4816       NarrowTy.isVector() ? SrcTy.getNumElements() / NarrowTy.getNumElements()
4817                           : SrcTy.getNumElements();
4818 
4819   extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs, MIRBuilder, MRI);
4820   if (NarrowTy.isScalar()) {
4821     if (DstTy != NarrowTy)
4822       return UnableToLegalize; // FIXME: handle implicit extensions.
4823 
4824     if (isPowerOf2_32(NumParts)) {
4825       // Generate a tree of scalar operations to reduce the critical path.
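      // e.g. four scalar parts reduce as (p0 op p1) op (p2 op p3) rather than
      // the sequential ((p0 op p1) op p2) op p3.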
4826       SmallVector<Register> PartialResults;
4827       unsigned NumPartsLeft = NumParts;
4828       while (NumPartsLeft > 1) {
4829         for (unsigned Idx = 0; Idx < NumPartsLeft - 1; Idx += 2) {
4830           PartialResults.emplace_back(
4831               MIRBuilder
4832                   .buildInstr(ScalarOpc, {NarrowTy},
4833                               {SplitSrcs[Idx], SplitSrcs[Idx + 1]})
4834                   .getReg(0));
4835         }
4836         SplitSrcs = PartialResults;
4837         PartialResults.clear();
4838         NumPartsLeft = SplitSrcs.size();
4839       }
4840       assert(SplitSrcs.size() == 1);
4841       MIRBuilder.buildCopy(DstReg, SplitSrcs[0]);
4842       MI.eraseFromParent();
4843       return Legalized;
4844     }
4845     // If we can't generate a tree, then just do sequential operations.
4846     Register Acc = SplitSrcs[0];
4847     for (unsigned Idx = 1; Idx < NumParts; ++Idx)
4848       Acc = MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {Acc, SplitSrcs[Idx]})
4849                 .getReg(0);
4850     MIRBuilder.buildCopy(DstReg, Acc);
4851     MI.eraseFromParent();
4852     return Legalized;
4853   }
4854   SmallVector<Register> PartialReductions;
4855   for (unsigned Part = 0; Part < NumParts; ++Part) {
4856     PartialReductions.push_back(
4857         MIRBuilder.buildInstr(RdxMI.getOpcode(), {DstTy}, {SplitSrcs[Part]})
4858             .getReg(0));
4859   }
4860 
4861   // If the types involved are powers of 2, we can generate intermediate vector
4862   // ops before generating a final reduction operation.
4863   if (isPowerOf2_32(SrcTy.getNumElements()) &&
4864       isPowerOf2_32(NarrowTy.getNumElements())) {
4865     return tryNarrowPow2Reduction(MI, SrcReg, SrcTy, NarrowTy, ScalarOpc);
4866   }
4867 
4868   Register Acc = PartialReductions[0];
4869   for (unsigned Part = 1; Part < NumParts; ++Part) {
4870     if (Part == NumParts - 1) {
4871       MIRBuilder.buildInstr(ScalarOpc, {DstReg},
4872                             {Acc, PartialReductions[Part]});
4873     } else {
4874       Acc = MIRBuilder
4875                 .buildInstr(ScalarOpc, {DstTy}, {Acc, PartialReductions[Part]})
4876                 .getReg(0);
4877     }
4878   }
4879   MI.eraseFromParent();
4880   return Legalized;
4881 }
4882 
4883 LegalizerHelper::LegalizeResult
4884 LegalizerHelper::fewerElementsVectorSeqReductions(MachineInstr &MI,
4885                                                   unsigned int TypeIdx,
4886                                                   LLT NarrowTy) {
4887   auto [DstReg, DstTy, ScalarReg, ScalarTy, SrcReg, SrcTy] =
4888       MI.getFirst3RegLLTs();
4889   if (!NarrowTy.isScalar() || TypeIdx != 2 || DstTy != ScalarTy ||
4890       DstTy != NarrowTy)
4891     return UnableToLegalize;
4892 
4893   assert((MI.getOpcode() == TargetOpcode::G_VECREDUCE_SEQ_FADD ||
4894           MI.getOpcode() == TargetOpcode::G_VECREDUCE_SEQ_FMUL) &&
4895          "Unexpected vecreduce opcode");
4896   unsigned ScalarOpc = MI.getOpcode() == TargetOpcode::G_VECREDUCE_SEQ_FADD
4897                            ? TargetOpcode::G_FADD
4898                            : TargetOpcode::G_FMUL;
4899 
4900   SmallVector<Register> SplitSrcs;
4901   unsigned NumParts = SrcTy.getNumElements();
4902   extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs, MIRBuilder, MRI);
4903   Register Acc = ScalarReg;
4904   for (unsigned i = 0; i < NumParts; i++)
4905     Acc = MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {Acc, SplitSrcs[i]})
4906               .getReg(0);
4907 
4908   MIRBuilder.buildCopy(DstReg, Acc);
4909   MI.eraseFromParent();
4910   return Legalized;
4911 }
4912 
4913 LegalizerHelper::LegalizeResult
4914 LegalizerHelper::tryNarrowPow2Reduction(MachineInstr &MI, Register SrcReg,
4915                                         LLT SrcTy, LLT NarrowTy,
4916                                         unsigned ScalarOpc) {
4917   SmallVector<Register> SplitSrcs;
4918   // Split the sources into NarrowTy size pieces.
4919   extractParts(SrcReg, NarrowTy,
4920                SrcTy.getNumElements() / NarrowTy.getNumElements(), SplitSrcs,
4921                MIRBuilder, MRI);
4922   // We're going to do a tree reduction using vector operations until we have
4923   // one NarrowTy size value left.
4924   while (SplitSrcs.size() > 1) {
4925     SmallVector<Register> PartialRdxs;
4926     for (unsigned Idx = 0; Idx < SplitSrcs.size()-1; Idx += 2) {
4927       Register LHS = SplitSrcs[Idx];
4928       Register RHS = SplitSrcs[Idx + 1];
4929       // Create the intermediate vector op.
4930       Register Res =
4931           MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {LHS, RHS}).getReg(0);
4932       PartialRdxs.push_back(Res);
4933     }
4934     SplitSrcs = std::move(PartialRdxs);
4935   }
4936   // Finally generate the requested NarrowTy based reduction.
4937   Observer.changingInstr(MI);
4938   MI.getOperand(1).setReg(SplitSrcs[0]);
4939   Observer.changedInstr(MI);
4940   return Legalized;
4941 }
4942 
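// Narrow a shift by a constant amount by splitting the value into two halves.
// E.g. (a sketch) an s64 G_SHL by 40 with HalfTy = s32 produces
//   Lo = 0
//   Hi = G_SHL InL, 8   ; 8 = 40 - 32
// since all low-half bits are shifted out.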
4943 LegalizerHelper::LegalizeResult
4944 LegalizerHelper::narrowScalarShiftByConstant(MachineInstr &MI, const APInt &Amt,
4945                                              const LLT HalfTy, const LLT AmtTy) {
4946 
4947   Register InL = MRI.createGenericVirtualRegister(HalfTy);
4948   Register InH = MRI.createGenericVirtualRegister(HalfTy);
4949   MIRBuilder.buildUnmerge({InL, InH}, MI.getOperand(1));
4950 
4951   if (Amt.isZero()) {
4952     MIRBuilder.buildMergeLikeInstr(MI.getOperand(0), {InL, InH});
4953     MI.eraseFromParent();
4954     return Legalized;
4955   }
4956 
4957   LLT NVT = HalfTy;
4958   unsigned NVTBits = HalfTy.getSizeInBits();
4959   unsigned VTBits = 2 * NVTBits;
4960 
4961   SrcOp Lo(Register(0)), Hi(Register(0));
4962   if (MI.getOpcode() == TargetOpcode::G_SHL) {
4963     if (Amt.ugt(VTBits)) {
4964       Lo = Hi = MIRBuilder.buildConstant(NVT, 0);
4965     } else if (Amt.ugt(NVTBits)) {
4966       Lo = MIRBuilder.buildConstant(NVT, 0);
4967       Hi = MIRBuilder.buildShl(NVT, InL,
4968                                MIRBuilder.buildConstant(AmtTy, Amt - NVTBits));
4969     } else if (Amt == NVTBits) {
4970       Lo = MIRBuilder.buildConstant(NVT, 0);
4971       Hi = InL;
4972     } else {
4973       Lo = MIRBuilder.buildShl(NVT, InL, MIRBuilder.buildConstant(AmtTy, Amt));
4974       auto OrLHS =
4975           MIRBuilder.buildShl(NVT, InH, MIRBuilder.buildConstant(AmtTy, Amt));
4976       auto OrRHS = MIRBuilder.buildLShr(
4977           NVT, InL, MIRBuilder.buildConstant(AmtTy, -Amt + NVTBits));
4978       Hi = MIRBuilder.buildOr(NVT, OrLHS, OrRHS);
4979     }
4980   } else if (MI.getOpcode() == TargetOpcode::G_LSHR) {
4981     if (Amt.ugt(VTBits)) {
4982       Lo = Hi = MIRBuilder.buildConstant(NVT, 0);
4983     } else if (Amt.ugt(NVTBits)) {
4984       Lo = MIRBuilder.buildLShr(NVT, InH,
4985                                 MIRBuilder.buildConstant(AmtTy, Amt - NVTBits));
4986       Hi = MIRBuilder.buildConstant(NVT, 0);
4987     } else if (Amt == NVTBits) {
4988       Lo = InH;
4989       Hi = MIRBuilder.buildConstant(NVT, 0);
4990     } else {
4991       auto ShiftAmtConst = MIRBuilder.buildConstant(AmtTy, Amt);
4992 
4993       auto OrLHS = MIRBuilder.buildLShr(NVT, InL, ShiftAmtConst);
4994       auto OrRHS = MIRBuilder.buildShl(
4995           NVT, InH, MIRBuilder.buildConstant(AmtTy, -Amt + NVTBits));
4996 
4997       Lo = MIRBuilder.buildOr(NVT, OrLHS, OrRHS);
4998       Hi = MIRBuilder.buildLShr(NVT, InH, ShiftAmtConst);
4999     }
5000   } else {
5001     if (Amt.ugt(VTBits)) {
5002       Hi = Lo = MIRBuilder.buildAShr(
5003           NVT, InH, MIRBuilder.buildConstant(AmtTy, NVTBits - 1));
5004     } else if (Amt.ugt(NVTBits)) {
5005       Lo = MIRBuilder.buildAShr(NVT, InH,
5006                                 MIRBuilder.buildConstant(AmtTy, Amt - NVTBits));
5007       Hi = MIRBuilder.buildAShr(NVT, InH,
5008                                 MIRBuilder.buildConstant(AmtTy, NVTBits - 1));
5009     } else if (Amt == NVTBits) {
5010       Lo = InH;
5011       Hi = MIRBuilder.buildAShr(NVT, InH,
5012                                 MIRBuilder.buildConstant(AmtTy, NVTBits - 1));
5013     } else {
5014       auto ShiftAmtConst = MIRBuilder.buildConstant(AmtTy, Amt);
5015 
5016       auto OrLHS = MIRBuilder.buildLShr(NVT, InL, ShiftAmtConst);
5017       auto OrRHS = MIRBuilder.buildShl(
5018           NVT, InH, MIRBuilder.buildConstant(AmtTy, -Amt + NVTBits));
5019 
5020       Lo = MIRBuilder.buildOr(NVT, OrLHS, OrRHS);
5021       Hi = MIRBuilder.buildAShr(NVT, InH, ShiftAmtConst);
5022     }
5023   }
5024 
5025   MIRBuilder.buildMergeLikeInstr(MI.getOperand(0), {Lo, Hi});
5026   MI.eraseFromParent();
5027 
5028   return Legalized;
5029 }
5030 
5031 // TODO: Optimize if constant shift amount.
5032 LegalizerHelper::LegalizeResult
5033 LegalizerHelper::narrowScalarShift(MachineInstr &MI, unsigned TypeIdx,
5034                                    LLT RequestedTy) {
5035   if (TypeIdx == 1) {
5036     Observer.changingInstr(MI);
5037     narrowScalarSrc(MI, RequestedTy, 2);
5038     Observer.changedInstr(MI);
5039     return Legalized;
5040   }
5041 
5042   Register DstReg = MI.getOperand(0).getReg();
5043   LLT DstTy = MRI.getType(DstReg);
5044   if (DstTy.isVector())
5045     return UnableToLegalize;
5046 
5047   Register Amt = MI.getOperand(2).getReg();
5048   LLT ShiftAmtTy = MRI.getType(Amt);
5049   const unsigned DstEltSize = DstTy.getScalarSizeInBits();
5050   if (DstEltSize % 2 != 0)
5051     return UnableToLegalize;
5052 
5053   // Ignore the input type. We can only go to exactly half the size of the
5054   // input. If that isn't small enough, the resulting pieces will be further
5055   // legalized.
5056   const unsigned NewBitSize = DstEltSize / 2;
5057   const LLT HalfTy = LLT::scalar(NewBitSize);
5058   const LLT CondTy = LLT::scalar(1);
5059 
5060   if (auto VRegAndVal = getIConstantVRegValWithLookThrough(Amt, MRI)) {
5061     return narrowScalarShiftByConstant(MI, VRegAndVal->Value, HalfTy,
5062                                        ShiftAmtTy);
5063   }
5064 
5065   // TODO: Expand with known bits.
5066 
5067   // Handle the fully general expansion by an unknown amount.
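  // e.g. for s64 split into two s32 halves, G_SHL by an unknown Amt selects
  // between { InL << Amt, (InH << Amt) | (InL >> (32 - Amt)) } for Amt < 32
  // and { 0, InL << (Amt - 32) } otherwise, with Amt == 0 handled separately.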
5068   auto NewBits = MIRBuilder.buildConstant(ShiftAmtTy, NewBitSize);
5069 
5070   Register InL = MRI.createGenericVirtualRegister(HalfTy);
5071   Register InH = MRI.createGenericVirtualRegister(HalfTy);
5072   MIRBuilder.buildUnmerge({InL, InH}, MI.getOperand(1));
5073 
5074   auto AmtExcess = MIRBuilder.buildSub(ShiftAmtTy, Amt, NewBits);
5075   auto AmtLack = MIRBuilder.buildSub(ShiftAmtTy, NewBits, Amt);
5076 
5077   auto Zero = MIRBuilder.buildConstant(ShiftAmtTy, 0);
5078   auto IsShort = MIRBuilder.buildICmp(ICmpInst::ICMP_ULT, CondTy, Amt, NewBits);
5079   auto IsZero = MIRBuilder.buildICmp(ICmpInst::ICMP_EQ, CondTy, Amt, Zero);
5080 
5081   Register ResultRegs[2];
5082   switch (MI.getOpcode()) {
5083   case TargetOpcode::G_SHL: {
5084     // Short: ShAmt < NewBitSize
5085     auto LoS = MIRBuilder.buildShl(HalfTy, InL, Amt);
5086 
5087     auto LoOr = MIRBuilder.buildLShr(HalfTy, InL, AmtLack);
5088     auto HiOr = MIRBuilder.buildShl(HalfTy, InH, Amt);
5089     auto HiS = MIRBuilder.buildOr(HalfTy, LoOr, HiOr);
5090 
5091     // Long: ShAmt >= NewBitSize
5092     auto LoL = MIRBuilder.buildConstant(HalfTy, 0);         // Lo part is zero.
5093     auto HiL = MIRBuilder.buildShl(HalfTy, InL, AmtExcess); // Hi from Lo part.
5094 
5095     auto Lo = MIRBuilder.buildSelect(HalfTy, IsShort, LoS, LoL);
5096     auto Hi = MIRBuilder.buildSelect(
5097         HalfTy, IsZero, InH, MIRBuilder.buildSelect(HalfTy, IsShort, HiS, HiL));
5098 
5099     ResultRegs[0] = Lo.getReg(0);
5100     ResultRegs[1] = Hi.getReg(0);
5101     break;
5102   }
5103   case TargetOpcode::G_LSHR:
5104   case TargetOpcode::G_ASHR: {
5105     // Short: ShAmt < NewBitSize
5106     auto HiS = MIRBuilder.buildInstr(MI.getOpcode(), {HalfTy}, {InH, Amt});
5107 
5108     auto LoOr = MIRBuilder.buildLShr(HalfTy, InL, Amt);
5109     auto HiOr = MIRBuilder.buildShl(HalfTy, InH, AmtLack);
5110     auto LoS = MIRBuilder.buildOr(HalfTy, LoOr, HiOr);
5111 
5112     // Long: ShAmt >= NewBitSize
5113     MachineInstrBuilder HiL;
5114     if (MI.getOpcode() == TargetOpcode::G_LSHR) {
5115       HiL = MIRBuilder.buildConstant(HalfTy, 0);            // Hi part is zero.
5116     } else {
5117       auto ShiftAmt = MIRBuilder.buildConstant(ShiftAmtTy, NewBitSize - 1);
5118       HiL = MIRBuilder.buildAShr(HalfTy, InH, ShiftAmt);    // Sign of Hi part.
5119     }
5120     auto LoL = MIRBuilder.buildInstr(MI.getOpcode(), {HalfTy},
5121                                      {InH, AmtExcess});     // Lo from Hi part.
5122 
5123     auto Lo = MIRBuilder.buildSelect(
5124         HalfTy, IsZero, InL, MIRBuilder.buildSelect(HalfTy, IsShort, LoS, LoL));
5125 
5126     auto Hi = MIRBuilder.buildSelect(HalfTy, IsShort, HiS, HiL);
5127 
5128     ResultRegs[0] = Lo.getReg(0);
5129     ResultRegs[1] = Hi.getReg(0);
5130     break;
5131   }
5132   default:
5133     llvm_unreachable("not a shift");
5134   }
5135 
5136   MIRBuilder.buildMergeLikeInstr(DstReg, ResultRegs);
5137   MI.eraseFromParent();
5138   return Legalized;
5139 }
5140 
5141 LegalizerHelper::LegalizeResult
5142 LegalizerHelper::moreElementsVectorPhi(MachineInstr &MI, unsigned TypeIdx,
5143                                        LLT MoreTy) {
5144   assert(TypeIdx == 0 && "Expecting only Idx 0");
5145 
5146   Observer.changingInstr(MI);
5147   for (unsigned I = 1, E = MI.getNumOperands(); I != E; I += 2) {
5148     MachineBasicBlock &OpMBB = *MI.getOperand(I + 1).getMBB();
5149     MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminator());
5150     moreElementsVectorSrc(MI, MoreTy, I);
5151   }
5152 
5153   MachineBasicBlock &MBB = *MI.getParent();
5154   MIRBuilder.setInsertPt(MBB, --MBB.getFirstNonPHI());
5155   moreElementsVectorDst(MI, MoreTy, 0);
5156   Observer.changedInstr(MI);
5157   return Legalized;
5158 }
5159 
5160 LegalizerHelper::LegalizeResult
5161 LegalizerHelper::moreElementsVector(MachineInstr &MI, unsigned TypeIdx,
5162                                     LLT MoreTy) {
5163   unsigned Opc = MI.getOpcode();
5164   switch (Opc) {
5165   case TargetOpcode::G_IMPLICIT_DEF:
5166   case TargetOpcode::G_LOAD: {
5167     if (TypeIdx != 0)
5168       return UnableToLegalize;
5169     Observer.changingInstr(MI);
5170     moreElementsVectorDst(MI, MoreTy, 0);
5171     Observer.changedInstr(MI);
5172     return Legalized;
5173   }
5174   case TargetOpcode::G_STORE:
5175     if (TypeIdx != 0)
5176       return UnableToLegalize;
5177     Observer.changingInstr(MI);
5178     moreElementsVectorSrc(MI, MoreTy, 0);
5179     Observer.changedInstr(MI);
5180     return Legalized;
5181   case TargetOpcode::G_AND:
5182   case TargetOpcode::G_OR:
5183   case TargetOpcode::G_XOR:
5184   case TargetOpcode::G_ADD:
5185   case TargetOpcode::G_SUB:
5186   case TargetOpcode::G_MUL:
5187   case TargetOpcode::G_FADD:
5188   case TargetOpcode::G_FSUB:
5189   case TargetOpcode::G_FMUL:
5190   case TargetOpcode::G_FDIV:
5191   case TargetOpcode::G_UADDSAT:
5192   case TargetOpcode::G_USUBSAT:
5193   case TargetOpcode::G_SADDSAT:
5194   case TargetOpcode::G_SSUBSAT:
5195   case TargetOpcode::G_SMIN:
5196   case TargetOpcode::G_SMAX:
5197   case TargetOpcode::G_UMIN:
5198   case TargetOpcode::G_UMAX:
5199   case TargetOpcode::G_FMINNUM:
5200   case TargetOpcode::G_FMAXNUM:
5201   case TargetOpcode::G_FMINNUM_IEEE:
5202   case TargetOpcode::G_FMAXNUM_IEEE:
5203   case TargetOpcode::G_FMINIMUM:
5204   case TargetOpcode::G_FMAXIMUM:
5205   case TargetOpcode::G_STRICT_FADD:
5206   case TargetOpcode::G_STRICT_FSUB:
5207   case TargetOpcode::G_STRICT_FMUL:
5208   case TargetOpcode::G_SHL:
5209   case TargetOpcode::G_ASHR:
5210   case TargetOpcode::G_LSHR: {
5211     Observer.changingInstr(MI);
5212     moreElementsVectorSrc(MI, MoreTy, 1);
5213     moreElementsVectorSrc(MI, MoreTy, 2);
5214     moreElementsVectorDst(MI, MoreTy, 0);
5215     Observer.changedInstr(MI);
5216     return Legalized;
5217   }
5218   case TargetOpcode::G_FMA:
5219   case TargetOpcode::G_STRICT_FMA:
5220   case TargetOpcode::G_FSHR:
5221   case TargetOpcode::G_FSHL: {
5222     Observer.changingInstr(MI);
5223     moreElementsVectorSrc(MI, MoreTy, 1);
5224     moreElementsVectorSrc(MI, MoreTy, 2);
5225     moreElementsVectorSrc(MI, MoreTy, 3);
5226     moreElementsVectorDst(MI, MoreTy, 0);
5227     Observer.changedInstr(MI);
5228     return Legalized;
5229   }
5230   case TargetOpcode::G_EXTRACT_VECTOR_ELT:
5231   case TargetOpcode::G_EXTRACT:
5232     if (TypeIdx != 1)
5233       return UnableToLegalize;
5234     Observer.changingInstr(MI);
5235     moreElementsVectorSrc(MI, MoreTy, 1);
5236     Observer.changedInstr(MI);
5237     return Legalized;
5238   case TargetOpcode::G_INSERT:
5239   case TargetOpcode::G_INSERT_VECTOR_ELT:
5240   case TargetOpcode::G_FREEZE:
5241   case TargetOpcode::G_FNEG:
5242   case TargetOpcode::G_FABS:
5243   case TargetOpcode::G_FSQRT:
5244   case TargetOpcode::G_FCEIL:
5245   case TargetOpcode::G_FFLOOR:
5246   case TargetOpcode::G_FNEARBYINT:
5247   case TargetOpcode::G_FRINT:
5248   case TargetOpcode::G_INTRINSIC_ROUND:
5249   case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
5250   case TargetOpcode::G_INTRINSIC_TRUNC:
5251   case TargetOpcode::G_BSWAP:
5252   case TargetOpcode::G_FCANONICALIZE:
5253   case TargetOpcode::G_SEXT_INREG:
5254     if (TypeIdx != 0)
5255       return UnableToLegalize;
5256     Observer.changingInstr(MI);
5257     moreElementsVectorSrc(MI, MoreTy, 1);
5258     moreElementsVectorDst(MI, MoreTy, 0);
5259     Observer.changedInstr(MI);
5260     return Legalized;
5261   case TargetOpcode::G_SELECT: {
5262     auto [DstReg, DstTy, CondReg, CondTy] = MI.getFirst2RegLLTs();
5263     if (TypeIdx == 1) {
5264       if (!CondTy.isScalar() ||
5265           DstTy.getElementCount() != MoreTy.getElementCount())
5266         return UnableToLegalize;
5267 
5268       // This is turning a scalar select of vectors into a vector
5269       // select. Broadcast the select condition.
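      // e.g. (a sketch) with MoreTy = <4 x s1>, the s1 condition of
      // %r:_(<4 x s32>) = G_SELECT %c:_(s1), %a, %b is splatted to <4 x s1>,
      // turning this into a vector select.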
5270       auto ShufSplat = MIRBuilder.buildShuffleSplat(MoreTy, CondReg);
5271       Observer.changingInstr(MI);
5272       MI.getOperand(1).setReg(ShufSplat.getReg(0));
5273       Observer.changedInstr(MI);
5274       return Legalized;
5275     }
5276 
5277     if (CondTy.isVector())
5278       return UnableToLegalize;
5279 
5280     Observer.changingInstr(MI);
5281     moreElementsVectorSrc(MI, MoreTy, 2);
5282     moreElementsVectorSrc(MI, MoreTy, 3);
5283     moreElementsVectorDst(MI, MoreTy, 0);
5284     Observer.changedInstr(MI);
5285     return Legalized;
5286   }
5287   case TargetOpcode::G_UNMERGE_VALUES:
5288     return UnableToLegalize;
5289   case TargetOpcode::G_PHI:
5290     return moreElementsVectorPhi(MI, TypeIdx, MoreTy);
5291   case TargetOpcode::G_SHUFFLE_VECTOR:
5292     return moreElementsVectorShuffle(MI, TypeIdx, MoreTy);
5293   case TargetOpcode::G_BUILD_VECTOR: {
5294     SmallVector<SrcOp, 8> Elts;
5295     for (auto Op : MI.uses()) {
5296       Elts.push_back(Op.getReg());
5297     }
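    // Pad with undef elements up to MoreTy, e.g. (a sketch) widening
    //   %v:_(<3 x s32>) = G_BUILD_VECTOR %a, %b, %c
    // to <4 x s32> appends one undef s32; the trailing elements are then
    // deleted again to produce the original <3 x s32> destination.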
5298 
5299     for (unsigned i = Elts.size(); i < MoreTy.getNumElements(); ++i) {
5300       Elts.push_back(MIRBuilder.buildUndef(MoreTy.getScalarType()));
5301     }
5302 
5303     MIRBuilder.buildDeleteTrailingVectorElements(
5304         MI.getOperand(0).getReg(), MIRBuilder.buildInstr(Opc, {MoreTy}, Elts));
5305     MI.eraseFromParent();
5306     return Legalized;
5307   }
5308   case TargetOpcode::G_TRUNC:
5309   case TargetOpcode::G_FPTRUNC:
5310   case TargetOpcode::G_FPEXT:
5311   case TargetOpcode::G_FPTOSI:
5312   case TargetOpcode::G_FPTOUI:
5313   case TargetOpcode::G_SITOFP:
5314   case TargetOpcode::G_UITOFP: {
5315     if (TypeIdx != 0)
5316       return UnableToLegalize;
5317     Observer.changingInstr(MI);
5318     LLT SrcTy = LLT::fixed_vector(
5319         MoreTy.getNumElements(),
5320         MRI.getType(MI.getOperand(1).getReg()).getElementType());
5321     moreElementsVectorSrc(MI, SrcTy, 1);
5322     moreElementsVectorDst(MI, MoreTy, 0);
5323     Observer.changedInstr(MI);
5324     return Legalized;
5325   }
5326   case TargetOpcode::G_ICMP: {
5327     // TODO: the symmetric MoreTy works for targets such as NEON.
5328     // For targets like MVE, the result is a predicate vector (i1).
5329     // This will need some refactoring.
5330     Observer.changingInstr(MI);
5331     moreElementsVectorSrc(MI, MoreTy, 2);
5332     moreElementsVectorSrc(MI, MoreTy, 3);
5333     moreElementsVectorDst(MI, MoreTy, 0);
5334     Observer.changedInstr(MI);
5335     return Legalized;
5336   }
5337   default:
5338     return UnableToLegalize;
5339   }
5340 }
5341 
5342 LegalizerHelper::LegalizeResult
5343 LegalizerHelper::equalizeVectorShuffleLengths(MachineInstr &MI) {
5344   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
5345   ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
5346   unsigned MaskNumElts = Mask.size();
5347   unsigned SrcNumElts = SrcTy.getNumElements();
5348   LLT DestEltTy = DstTy.getElementType();
5349 
5350   if (MaskNumElts == SrcNumElts)
5351     return Legalized;
5352 
5353   if (MaskNumElts < SrcNumElts) {
5354     // Extend mask to match new destination vector size with
5355     // undef values.
5356     SmallVector<int, 16> NewMask(Mask);
5357     for (unsigned I = MaskNumElts; I < SrcNumElts; ++I)
5358       NewMask.push_back(-1);
5359 
5360     moreElementsVectorDst(MI, SrcTy, 0);
5361     MIRBuilder.setInstrAndDebugLoc(MI);
5362     MIRBuilder.buildShuffleVector(MI.getOperand(0).getReg(),
5363                                   MI.getOperand(1).getReg(),
5364                                   MI.getOperand(2).getReg(), NewMask);
5365     MI.eraseFromParent();
5366 
5367     return Legalized;
5368   }
5369 
5370   unsigned PaddedMaskNumElts = alignTo(MaskNumElts, SrcNumElts);
5371   unsigned NumConcat = PaddedMaskNumElts / SrcNumElts;
5372   LLT PaddedTy = LLT::fixed_vector(PaddedMaskNumElts, DestEltTy);
5373 
5374   // Create new source vectors by concatenating the initial
5375   // source vectors with undefined vectors of the same size.
5376   auto Undef = MIRBuilder.buildUndef(SrcTy);
5377   SmallVector<Register, 8> MOps1(NumConcat, Undef.getReg(0));
5378   SmallVector<Register, 8> MOps2(NumConcat, Undef.getReg(0));
5379   MOps1[0] = MI.getOperand(1).getReg();
5380   MOps2[0] = MI.getOperand(2).getReg();
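  // e.g. with SrcTy = <2 x s32> and a 3-element mask: PaddedMaskNumElts = 4,
  // NumConcat = 2, so each padded source is a G_CONCAT_VECTORS of the original
  // source and one undef <2 x s32>.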
5381 
5382   auto Src1 = MIRBuilder.buildConcatVectors(PaddedTy, MOps1);
5383   auto Src2 = MIRBuilder.buildConcatVectors(PaddedTy, MOps2);
5384 
5385   // Readjust mask for new input vector length.
5386   SmallVector<int, 8> MappedOps(PaddedMaskNumElts, -1);
5387   for (unsigned I = 0; I != MaskNumElts; ++I) {
5388     int Idx = Mask[I];
5389     if (Idx >= static_cast<int>(SrcNumElts))
5390       Idx += PaddedMaskNumElts - SrcNumElts;
5391     MappedOps[I] = Idx;
5392   }
5393 
5394   // If we got more elements than required, extract subvector.
5395   if (MaskNumElts != PaddedMaskNumElts) {
5396     auto Shuffle =
5397         MIRBuilder.buildShuffleVector(PaddedTy, Src1, Src2, MappedOps);
5398 
5399     SmallVector<Register, 16> Elts(MaskNumElts);
5400     for (unsigned I = 0; I < MaskNumElts; ++I) {
5401       Elts[I] =
5402           MIRBuilder.buildExtractVectorElementConstant(DestEltTy, Shuffle, I)
5403               .getReg(0);
5404     }
5405     MIRBuilder.buildBuildVector(DstReg, Elts);
5406   } else {
5407     MIRBuilder.buildShuffleVector(DstReg, Src1, Src2, MappedOps);
5408   }
5409 
5410   MI.eraseFromParent();
5411   return LegalizerHelper::LegalizeResult::Legalized;
5412 }
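// For illustration (a sketch, not compiled here): take <2 x s32> sources with
// a 3-element mask. MaskNumElts = 3 > SrcNumElts = 2, so
// PaddedMaskNumElts = alignTo(3, 2) = 4 and NumConcat = 2:
//   %undef:_(<2 x s32>) = G_IMPLICIT_DEF
//   %pad1:_(<4 x s32>) = G_CONCAT_VECTORS %src1:_(<2 x s32>), %undef
//   %pad2:_(<4 x s32>) = G_CONCAT_VECTORS %src2:_(<2 x s32>), %undef
// Mask indices referring to the second source (Idx >= 2) are shifted up by
// PaddedMaskNumElts - SrcNumElts = 2, the padded shuffle is emitted, and the
// first 3 lanes are extracted and rebuilt with G_BUILD_VECTOR.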
5413 
5414 LegalizerHelper::LegalizeResult
5415 LegalizerHelper::moreElementsVectorShuffle(MachineInstr &MI,
5416                                            unsigned int TypeIdx, LLT MoreTy) {
5417   auto [DstTy, Src1Ty, Src2Ty] = MI.getFirst3LLTs();
5418   ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
5419   unsigned NumElts = DstTy.getNumElements();
5420   unsigned WidenNumElts = MoreTy.getNumElements();
5421 
5422   if (DstTy.isVector() && Src1Ty.isVector() &&
5423       DstTy.getNumElements() != Src1Ty.getNumElements()) {
5424     return equalizeVectorShuffleLengths(MI);
5425   }
5426 
5427   if (TypeIdx != 0)
5428     return UnableToLegalize;
5429 
5430   // Expect a canonicalized shuffle.
5431   if (DstTy != Src1Ty || DstTy != Src2Ty)
5432     return UnableToLegalize;
5433 
5434   moreElementsVectorSrc(MI, MoreTy, 1);
5435   moreElementsVectorSrc(MI, MoreTy, 2);
5436 
5437   // Adjust mask based on new input vector length.
5438   SmallVector<int, 16> NewMask;
5439   for (unsigned I = 0; I != NumElts; ++I) {
5440     int Idx = Mask[I];
5441     if (Idx < static_cast<int>(NumElts))
5442       NewMask.push_back(Idx);
5443     else
5444       NewMask.push_back(Idx - NumElts + WidenNumElts);
5445   }
5446   for (unsigned I = NumElts; I != WidenNumElts; ++I)
5447     NewMask.push_back(-1);
5448   moreElementsVectorDst(MI, MoreTy, 0);
5449   MIRBuilder.setInstrAndDebugLoc(MI);
5450   MIRBuilder.buildShuffleVector(MI.getOperand(0).getReg(),
5451                                 MI.getOperand(1).getReg(),
5452                                 MI.getOperand(2).getReg(), NewMask);
5453   MI.eraseFromParent();
5454   return Legalized;
5455 }
5456 
5457 void LegalizerHelper::multiplyRegisters(SmallVectorImpl<Register> &DstRegs,
5458                                         ArrayRef<Register> Src1Regs,
5459                                         ArrayRef<Register> Src2Regs,
5460                                         LLT NarrowTy) {
5461   MachineIRBuilder &B = MIRBuilder;
5462   unsigned SrcParts = Src1Regs.size();
5463   unsigned DstParts = DstRegs.size();
5464 
5465   unsigned DstIdx = 0; // Low bits of the result.
5466   Register FactorSum =
5467       B.buildMul(NarrowTy, Src1Regs[DstIdx], Src2Regs[DstIdx]).getReg(0);
5468   DstRegs[DstIdx] = FactorSum;
5469 
5470   unsigned CarrySumPrevDstIdx;
5471   SmallVector<Register, 4> Factors;
5472 
5473   for (DstIdx = 1; DstIdx < DstParts; DstIdx++) {
5474     // Collect low parts of muls for DstIdx.
5475     for (unsigned i = DstIdx + 1 < SrcParts ? 0 : DstIdx - SrcParts + 1;
5476          i <= std::min(DstIdx, SrcParts - 1); ++i) {
5477       MachineInstrBuilder Mul =
5478           B.buildMul(NarrowTy, Src1Regs[DstIdx - i], Src2Regs[i]);
5479       Factors.push_back(Mul.getReg(0));
5480     }
5481     // Collect high parts of muls from previous DstIdx.
5482     for (unsigned i = DstIdx < SrcParts ? 0 : DstIdx - SrcParts;
5483          i <= std::min(DstIdx - 1, SrcParts - 1); ++i) {
5484       MachineInstrBuilder Umulh =
5485           B.buildUMulH(NarrowTy, Src1Regs[DstIdx - 1 - i], Src2Regs[i]);
5486       Factors.push_back(Umulh.getReg(0));
5487     }
5488     // Add CarrySum from additions calculated for previous DstIdx.
5489     if (DstIdx != 1) {
5490       Factors.push_back(CarrySumPrevDstIdx);
5491     }
5492 
5493     Register CarrySum;
5494     // Add all factors and accumulate all carries into CarrySum.
5495     if (DstIdx != DstParts - 1) {
5496       MachineInstrBuilder Uaddo =
5497           B.buildUAddo(NarrowTy, LLT::scalar(1), Factors[0], Factors[1]);
5498       FactorSum = Uaddo.getReg(0);
5499       CarrySum = B.buildZExt(NarrowTy, Uaddo.getReg(1)).getReg(0);
5500       for (unsigned i = 2; i < Factors.size(); ++i) {
5501         MachineInstrBuilder Uaddo =
5502             B.buildUAddo(NarrowTy, LLT::scalar(1), FactorSum, Factors[i]);
5503         FactorSum = Uaddo.getReg(0);
5504         MachineInstrBuilder Carry = B.buildZExt(NarrowTy, Uaddo.getReg(1));
5505         CarrySum = B.buildAdd(NarrowTy, CarrySum, Carry).getReg(0);
5506       }
5507     } else {
5508       // Since the value for the next index is not calculated, neither is CarrySum.
5509       FactorSum = B.buildAdd(NarrowTy, Factors[0], Factors[1]).getReg(0);
5510       for (unsigned i = 2; i < Factors.size(); ++i)
5511         FactorSum = B.buildAdd(NarrowTy, FactorSum, Factors[i]).getReg(0);
5512     }
5513 
5514     CarrySumPrevDstIdx = CarrySum;
5515     DstRegs[DstIdx] = FactorSum;
5516     Factors.clear();
5517   }
5518 }
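// A sketch of what the loop above emits for the simplest split,
// SrcParts == DstParts == 2 (illustrative only; register names invented):
//   %d0:_(sN) = G_MUL   %a0, %b0    // low result part
//   %p0:_(sN) = G_MUL   %a1, %b0    // low halves contributing to part 1
//   %p1:_(sN) = G_MUL   %a0, %b1
//   %h0:_(sN) = G_UMULH %a0, %b0    // high half carried up from part 0
//   %t0:_(sN) = G_ADD   %p0, %p1    // last part: plain adds, since no
//   %d1:_(sN) = G_ADD   %t0, %h0    // carry needs to propagate further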
5519 
5520 LegalizerHelper::LegalizeResult
5521 LegalizerHelper::narrowScalarAddSub(MachineInstr &MI, unsigned TypeIdx,
5522                                     LLT NarrowTy) {
5523   if (TypeIdx != 0)
5524     return UnableToLegalize;
5525 
5526   Register DstReg = MI.getOperand(0).getReg();
5527   LLT DstType = MRI.getType(DstReg);
5528   // FIXME: add support for vector types
5529   if (DstType.isVector())
5530     return UnableToLegalize;
5531 
5532   unsigned Opcode = MI.getOpcode();
5533   unsigned OpO, OpE, OpF;
5534   switch (Opcode) {
5535   case TargetOpcode::G_SADDO:
5536   case TargetOpcode::G_SADDE:
5537   case TargetOpcode::G_UADDO:
5538   case TargetOpcode::G_UADDE:
5539   case TargetOpcode::G_ADD:
5540     OpO = TargetOpcode::G_UADDO;
5541     OpE = TargetOpcode::G_UADDE;
5542     OpF = TargetOpcode::G_UADDE;
5543     if (Opcode == TargetOpcode::G_SADDO || Opcode == TargetOpcode::G_SADDE)
5544       OpF = TargetOpcode::G_SADDE;
5545     break;
5546   case TargetOpcode::G_SSUBO:
5547   case TargetOpcode::G_SSUBE:
5548   case TargetOpcode::G_USUBO:
5549   case TargetOpcode::G_USUBE:
5550   case TargetOpcode::G_SUB:
5551     OpO = TargetOpcode::G_USUBO;
5552     OpE = TargetOpcode::G_USUBE;
5553     OpF = TargetOpcode::G_USUBE;
5554     if (Opcode == TargetOpcode::G_SSUBO || Opcode == TargetOpcode::G_SSUBE)
5555       OpF = TargetOpcode::G_SSUBE;
5556     break;
5557   default:
5558     llvm_unreachable("Unexpected add/sub opcode!");
5559   }
5560 
5561   // 1 for a plain add/sub, 2 if this is an operation with a carry-out.
5562   unsigned NumDefs = MI.getNumExplicitDefs();
5563   Register Src1 = MI.getOperand(NumDefs).getReg();
5564   Register Src2 = MI.getOperand(NumDefs + 1).getReg();
5565   Register CarryDst, CarryIn;
5566   if (NumDefs == 2)
5567     CarryDst = MI.getOperand(1).getReg();
5568   if (MI.getNumOperands() == NumDefs + 3)
5569     CarryIn = MI.getOperand(NumDefs + 2).getReg();
5570 
5571   LLT RegTy = MRI.getType(MI.getOperand(0).getReg());
5572   LLT LeftoverTy, DummyTy;
5573   SmallVector<Register, 2> Src1Regs, Src2Regs, Src1Left, Src2Left, DstRegs;
5574   extractParts(Src1, RegTy, NarrowTy, LeftoverTy, Src1Regs, Src1Left,
5575                MIRBuilder, MRI);
5576   extractParts(Src2, RegTy, NarrowTy, DummyTy, Src2Regs, Src2Left, MIRBuilder,
5577                MRI);
5578 
5579   int NarrowParts = Src1Regs.size();
5580   for (int I = 0, E = Src1Left.size(); I != E; ++I) {
5581     Src1Regs.push_back(Src1Left[I]);
5582     Src2Regs.push_back(Src2Left[I]);
5583   }
5584   DstRegs.reserve(Src1Regs.size());
5585 
5586   for (int i = 0, e = Src1Regs.size(); i != e; ++i) {
5587     Register DstReg =
5588         MRI.createGenericVirtualRegister(MRI.getType(Src1Regs[i]));
5589     Register CarryOut = MRI.createGenericVirtualRegister(LLT::scalar(1));
5590     // Forward the final carry-out to the destination register
5591     if (i == e - 1 && CarryDst)
5592       CarryOut = CarryDst;
5593 
5594     if (!CarryIn) {
5595       MIRBuilder.buildInstr(OpO, {DstReg, CarryOut},
5596                             {Src1Regs[i], Src2Regs[i]});
5597     } else if (i == e - 1) {
5598       MIRBuilder.buildInstr(OpF, {DstReg, CarryOut},
5599                             {Src1Regs[i], Src2Regs[i], CarryIn});
5600     } else {
5601       MIRBuilder.buildInstr(OpE, {DstReg, CarryOut},
5602                             {Src1Regs[i], Src2Regs[i], CarryIn});
5603     }
5604 
5605     DstRegs.push_back(DstReg);
5606     CarryIn = CarryOut;
5607   }
5608   insertParts(MI.getOperand(0).getReg(), RegTy, NarrowTy,
5609               ArrayRef(DstRegs).take_front(NarrowParts), LeftoverTy,
5610               ArrayRef(DstRegs).drop_front(NarrowParts));
5611 
5612   MI.eraseFromParent();
5613   return Legalized;
5614 }
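// For example (an illustrative sketch), narrowing %sum:_(s64) = G_ADD %a, %b
// with NarrowTy = s32 produces roughly:
//   %a0:_(s32), %a1:_(s32) = G_UNMERGE_VALUES %a:_(s64)
//   %b0:_(s32), %b1:_(s32) = G_UNMERGE_VALUES %b:_(s64)
//   %s0:_(s32), %c0:_(s1) = G_UADDO %a0, %b0       // carry out of low half
//   %s1:_(s32), %c1:_(s1) = G_UADDE %a1, %b1, %c0  // consumes that carry
//   %sum:_(s64) = G_MERGE_VALUES %s0:_(s32), %s1:_(s32)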
5615 
5616 LegalizerHelper::LegalizeResult
5617 LegalizerHelper::narrowScalarMul(MachineInstr &MI, LLT NarrowTy) {
5618   auto [DstReg, Src1, Src2] = MI.getFirst3Regs();
5619 
5620   LLT Ty = MRI.getType(DstReg);
5621   if (Ty.isVector())
5622     return UnableToLegalize;
5623 
5624   unsigned Size = Ty.getSizeInBits();
5625   unsigned NarrowSize = NarrowTy.getSizeInBits();
5626   if (Size % NarrowSize != 0)
5627     return UnableToLegalize;
5628 
5629   unsigned NumParts = Size / NarrowSize;
5630   bool IsMulHigh = MI.getOpcode() == TargetOpcode::G_UMULH;
5631   unsigned DstTmpParts = NumParts * (IsMulHigh ? 2 : 1);
5632 
5633   SmallVector<Register, 2> Src1Parts, Src2Parts;
5634   SmallVector<Register, 2> DstTmpRegs(DstTmpParts);
5635   extractParts(Src1, NarrowTy, NumParts, Src1Parts, MIRBuilder, MRI);
5636   extractParts(Src2, NarrowTy, NumParts, Src2Parts, MIRBuilder, MRI);
5637   multiplyRegisters(DstTmpRegs, Src1Parts, Src2Parts, NarrowTy);
5638 
5639   // Take only high half of registers if this is high mul.
5640   ArrayRef<Register> DstRegs(&DstTmpRegs[DstTmpParts - NumParts], NumParts);
5641   MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
5642   MI.eraseFromParent();
5643   return Legalized;
5644 }
5645 
5646 LegalizerHelper::LegalizeResult
5647 LegalizerHelper::narrowScalarFPTOI(MachineInstr &MI, unsigned TypeIdx,
5648                                    LLT NarrowTy) {
5649   if (TypeIdx != 0)
5650     return UnableToLegalize;
5651 
5652   bool IsSigned = MI.getOpcode() == TargetOpcode::G_FPTOSI;
5653 
5654   Register Src = MI.getOperand(1).getReg();
5655   LLT SrcTy = MRI.getType(Src);
5656 
5657   // If all finite floats fit into the narrowed integer type, we can just swap
5658   // out the result type. The largest finite half is 65504, so this is only
5659   // useful for conversions from half to at least 16 bits; handle just that case.
5660   if (SrcTy.getScalarType() != LLT::scalar(16) ||
5661       NarrowTy.getScalarSizeInBits() < (IsSigned ? 17u : 16u))
5662     return UnableToLegalize;
5663 
5664   Observer.changingInstr(MI);
5665   narrowScalarDst(MI, NarrowTy, 0,
5666                   IsSigned ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT);
5667   Observer.changedInstr(MI);
5668   return Legalized;
5669 }
5670 
5671 LegalizerHelper::LegalizeResult
5672 LegalizerHelper::narrowScalarExtract(MachineInstr &MI, unsigned TypeIdx,
5673                                      LLT NarrowTy) {
5674   if (TypeIdx != 1)
5675     return UnableToLegalize;
5676 
5677   uint64_t NarrowSize = NarrowTy.getSizeInBits();
5678 
5679   int64_t SizeOp1 = MRI.getType(MI.getOperand(1).getReg()).getSizeInBits();
5680   // FIXME: add support for when SizeOp1 isn't an exact multiple of
5681   // NarrowSize.
5682   if (SizeOp1 % NarrowSize != 0)
5683     return UnableToLegalize;
5684   int NumParts = SizeOp1 / NarrowSize;
5685 
5686   SmallVector<Register, 2> SrcRegs, DstRegs;
5687   SmallVector<uint64_t, 2> Indexes;
5688   extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs,
5689                MIRBuilder, MRI);
5690 
5691   Register OpReg = MI.getOperand(0).getReg();
5692   uint64_t OpStart = MI.getOperand(2).getImm();
5693   uint64_t OpSize = MRI.getType(OpReg).getSizeInBits();
5694   for (int i = 0; i < NumParts; ++i) {
5695     unsigned SrcStart = i * NarrowSize;
5696 
5697     if (SrcStart + NarrowSize <= OpStart || SrcStart >= OpStart + OpSize) {
5698       // No part of the extract uses this subregister, ignore it.
5699       continue;
5700     } else if (SrcStart == OpStart && NarrowTy == MRI.getType(OpReg)) {
5701       // The entire subregister is extracted, forward the value.
5702       DstRegs.push_back(SrcRegs[i]);
5703       continue;
5704     }
5705 
5706     // OpSegStart is where this destination segment would start in OpReg if it
5707     // extended infinitely in both directions.
5708     int64_t ExtractOffset;
5709     uint64_t SegSize;
5710     if (OpStart < SrcStart) {
5711       ExtractOffset = 0;
5712       SegSize = std::min(NarrowSize, OpStart + OpSize - SrcStart);
5713     } else {
5714       ExtractOffset = OpStart - SrcStart;
5715       SegSize = std::min(SrcStart + NarrowSize - OpStart, OpSize);
5716     }
5717 
5718     Register SegReg = SrcRegs[i];
5719     if (ExtractOffset != 0 || SegSize != NarrowSize) {
5720       // A genuine extract is needed.
5721       SegReg = MRI.createGenericVirtualRegister(LLT::scalar(SegSize));
5722       MIRBuilder.buildExtract(SegReg, SrcRegs[i], ExtractOffset);
5723     }
5724 
5725     DstRegs.push_back(SegReg);
5726   }
5727 
5728   Register DstReg = MI.getOperand(0).getReg();
5729   if (MRI.getType(DstReg).isVector())
5730     MIRBuilder.buildBuildVector(DstReg, DstRegs);
5731   else if (DstRegs.size() > 1)
5732     MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
5733   else
5734     MIRBuilder.buildCopy(DstReg, DstRegs[0]);
5735   MI.eraseFromParent();
5736   return Legalized;
5737 }
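// As a worked illustration: extracting an s16 at bit offset 24 from an s64
// with NarrowTy = s32 touches both parts. Part 0 covers bits [0,32) and
// contributes an s8 extracted at offset 24; part 1 covers bits [32,64) and
// contributes an s8 extracted at offset 0; the two s8 segments are then
// merged back into the s16 result.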
5738 
5739 LegalizerHelper::LegalizeResult
5740 LegalizerHelper::narrowScalarInsert(MachineInstr &MI, unsigned TypeIdx,
5741                                     LLT NarrowTy) {
5742   // FIXME: Don't know how to handle secondary types yet.
5743   if (TypeIdx != 0)
5744     return UnableToLegalize;
5745 
5746   SmallVector<Register, 2> SrcRegs, LeftoverRegs, DstRegs;
5747   SmallVector<uint64_t, 2> Indexes;
5748   LLT RegTy = MRI.getType(MI.getOperand(0).getReg());
5749   LLT LeftoverTy;
5750   extractParts(MI.getOperand(1).getReg(), RegTy, NarrowTy, LeftoverTy, SrcRegs,
5751                LeftoverRegs, MIRBuilder, MRI);
5752 
5753   for (Register Reg : LeftoverRegs)
5754     SrcRegs.push_back(Reg);
5755 
5756   uint64_t NarrowSize = NarrowTy.getSizeInBits();
5757   Register OpReg = MI.getOperand(2).getReg();
5758   uint64_t OpStart = MI.getOperand(3).getImm();
5759   uint64_t OpSize = MRI.getType(OpReg).getSizeInBits();
5760   for (int I = 0, E = SrcRegs.size(); I != E; ++I) {
5761     unsigned DstStart = I * NarrowSize;
5762 
5763     if (DstStart == OpStart && NarrowTy == MRI.getType(OpReg)) {
5764       // The entire subregister is defined by this insert, forward the new
5765       // value.
5766       DstRegs.push_back(OpReg);
5767       continue;
5768     }
5769 
5770     Register SrcReg = SrcRegs[I];
5771     if (MRI.getType(SrcRegs[I]) == LeftoverTy) {
5772       // The leftover reg is smaller than NarrowTy, so we need to extend it.
5773       SrcReg = MRI.createGenericVirtualRegister(NarrowTy);
5774       MIRBuilder.buildAnyExt(SrcReg, SrcRegs[I]);
5775     }
5776 
5777     if (DstStart + NarrowSize <= OpStart || DstStart >= OpStart + OpSize) {
5778       // No part of the insert affects this subregister, forward the original.
5779       DstRegs.push_back(SrcReg);
5780       continue;
5781     }
5782 
5783     // OpSegStart is where this destination segment would start in OpReg if it
5784     // extended infinitely in both directions.
5785     int64_t ExtractOffset, InsertOffset;
5786     uint64_t SegSize;
5787     if (OpStart < DstStart) {
5788       InsertOffset = 0;
5789       ExtractOffset = DstStart - OpStart;
5790       SegSize = std::min(NarrowSize, OpStart + OpSize - DstStart);
5791     } else {
5792       InsertOffset = OpStart - DstStart;
5793       ExtractOffset = 0;
5794       SegSize =
5795         std::min(NarrowSize - InsertOffset, OpStart + OpSize - DstStart);
5796     }
5797 
5798     Register SegReg = OpReg;
5799     if (ExtractOffset != 0 || SegSize != OpSize) {
5800       // A genuine extract is needed.
5801       SegReg = MRI.createGenericVirtualRegister(LLT::scalar(SegSize));
5802       MIRBuilder.buildExtract(SegReg, OpReg, ExtractOffset);
5803     }
5804 
5805     Register DstReg = MRI.createGenericVirtualRegister(NarrowTy);
5806     MIRBuilder.buildInsert(DstReg, SrcReg, SegReg, InsertOffset);
5807     DstRegs.push_back(DstReg);
5808   }
5809 
5810   uint64_t WideSize = DstRegs.size() * NarrowSize;
5811   Register DstReg = MI.getOperand(0).getReg();
5812   if (WideSize > RegTy.getSizeInBits()) {
5813     Register MergeReg = MRI.createGenericVirtualRegister(LLT::scalar(WideSize));
5814     MIRBuilder.buildMergeLikeInstr(MergeReg, DstRegs);
5815     MIRBuilder.buildTrunc(DstReg, MergeReg);
5816   } else
5817     MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
5818 
5819   MI.eraseFromParent();
5820   return Legalized;
5821 }
5822 
5823 LegalizerHelper::LegalizeResult
5824 LegalizerHelper::narrowScalarBasic(MachineInstr &MI, unsigned TypeIdx,
5825                                    LLT NarrowTy) {
5826   Register DstReg = MI.getOperand(0).getReg();
5827   LLT DstTy = MRI.getType(DstReg);
5828 
5829   assert(MI.getNumOperands() == 3 && TypeIdx == 0);
5830 
5831   SmallVector<Register, 4> DstRegs, DstLeftoverRegs;
5832   SmallVector<Register, 4> Src0Regs, Src0LeftoverRegs;
5833   SmallVector<Register, 4> Src1Regs, Src1LeftoverRegs;
5834   LLT LeftoverTy;
5835   if (!extractParts(MI.getOperand(1).getReg(), DstTy, NarrowTy, LeftoverTy,
5836                     Src0Regs, Src0LeftoverRegs, MIRBuilder, MRI))
5837     return UnableToLegalize;
5838 
5839   LLT Unused;
5840   if (!extractParts(MI.getOperand(2).getReg(), DstTy, NarrowTy, Unused,
5841                     Src1Regs, Src1LeftoverRegs, MIRBuilder, MRI))
5842     llvm_unreachable("inconsistent extractParts result");
5843 
5844   for (unsigned I = 0, E = Src1Regs.size(); I != E; ++I) {
5845     auto Inst = MIRBuilder.buildInstr(MI.getOpcode(), {NarrowTy},
5846                                         {Src0Regs[I], Src1Regs[I]});
5847     DstRegs.push_back(Inst.getReg(0));
5848   }
5849 
5850   for (unsigned I = 0, E = Src1LeftoverRegs.size(); I != E; ++I) {
5851     auto Inst = MIRBuilder.buildInstr(
5852       MI.getOpcode(),
5853       {LeftoverTy}, {Src0LeftoverRegs[I], Src1LeftoverRegs[I]});
5854     DstLeftoverRegs.push_back(Inst.getReg(0));
5855   }
5856 
5857   insertParts(DstReg, DstTy, NarrowTy, DstRegs,
5858               LeftoverTy, DstLeftoverRegs);
5859 
5860   MI.eraseFromParent();
5861   return Legalized;
5862 }
5863 
5864 LegalizerHelper::LegalizeResult
5865 LegalizerHelper::narrowScalarExt(MachineInstr &MI, unsigned TypeIdx,
5866                                  LLT NarrowTy) {
5867   if (TypeIdx != 0)
5868     return UnableToLegalize;
5869 
5870   auto [DstReg, SrcReg] = MI.getFirst2Regs();
5871 
5872   LLT DstTy = MRI.getType(DstReg);
5873   if (DstTy.isVector())
5874     return UnableToLegalize;
5875 
5876   SmallVector<Register, 8> Parts;
5877   LLT GCDTy = extractGCDType(Parts, DstTy, NarrowTy, SrcReg);
5878   LLT LCMTy = buildLCMMergePieces(DstTy, NarrowTy, GCDTy, Parts, MI.getOpcode());
5879   buildWidenedRemergeToDst(DstReg, LCMTy, Parts);
5880 
5881   MI.eraseFromParent();
5882   return Legalized;
5883 }
5884 
5885 LegalizerHelper::LegalizeResult
5886 LegalizerHelper::narrowScalarSelect(MachineInstr &MI, unsigned TypeIdx,
5887                                     LLT NarrowTy) {
5888   if (TypeIdx != 0)
5889     return UnableToLegalize;
5890 
5891   Register CondReg = MI.getOperand(1).getReg();
5892   LLT CondTy = MRI.getType(CondReg);
5893   if (CondTy.isVector()) // TODO: Handle vselect
5894     return UnableToLegalize;
5895 
5896   Register DstReg = MI.getOperand(0).getReg();
5897   LLT DstTy = MRI.getType(DstReg);
5898 
5899   SmallVector<Register, 4> DstRegs, DstLeftoverRegs;
5900   SmallVector<Register, 4> Src1Regs, Src1LeftoverRegs;
5901   SmallVector<Register, 4> Src2Regs, Src2LeftoverRegs;
5902   LLT LeftoverTy;
5903   if (!extractParts(MI.getOperand(2).getReg(), DstTy, NarrowTy, LeftoverTy,
5904                     Src1Regs, Src1LeftoverRegs, MIRBuilder, MRI))
5905     return UnableToLegalize;
5906 
5907   LLT Unused;
5908   if (!extractParts(MI.getOperand(3).getReg(), DstTy, NarrowTy, Unused,
5909                     Src2Regs, Src2LeftoverRegs, MIRBuilder, MRI))
5910     llvm_unreachable("inconsistent extractParts result");
5911 
5912   for (unsigned I = 0, E = Src1Regs.size(); I != E; ++I) {
5913     auto Select = MIRBuilder.buildSelect(NarrowTy,
5914                                          CondReg, Src1Regs[I], Src2Regs[I]);
5915     DstRegs.push_back(Select.getReg(0));
5916   }
5917 
5918   for (unsigned I = 0, E = Src1LeftoverRegs.size(); I != E; ++I) {
5919     auto Select = MIRBuilder.buildSelect(
5920       LeftoverTy, CondReg, Src1LeftoverRegs[I], Src2LeftoverRegs[I]);
5921     DstLeftoverRegs.push_back(Select.getReg(0));
5922   }
5923 
5924   insertParts(DstReg, DstTy, NarrowTy, DstRegs,
5925               LeftoverTy, DstLeftoverRegs);
5926 
5927   MI.eraseFromParent();
5928   return Legalized;
5929 }
5930 
5931 LegalizerHelper::LegalizeResult
5932 LegalizerHelper::narrowScalarCTLZ(MachineInstr &MI, unsigned TypeIdx,
5933                                   LLT NarrowTy) {
5934   if (TypeIdx != 1)
5935     return UnableToLegalize;
5936 
5937   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
5938   unsigned NarrowSize = NarrowTy.getSizeInBits();
5939 
5940   if (SrcTy.isScalar() && SrcTy.getSizeInBits() == 2 * NarrowSize) {
5941     const bool IsUndef = MI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF;
5942 
5943     MachineIRBuilder &B = MIRBuilder;
5944     auto UnmergeSrc = B.buildUnmerge(NarrowTy, SrcReg);
5945     // ctlz(Hi:Lo) -> Hi == 0 ? (NarrowSize + ctlz(Lo)) : ctlz(Hi)
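    // Worked example (for illustration): with NarrowTy = s32 and a source
    // value of 0x0000'0000'0000'ffff, Hi == 0, so the result is
    // 32 + ctlz(0x0000ffff) = 32 + 16 = 48.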
5946     auto C_0 = B.buildConstant(NarrowTy, 0);
5947     auto HiIsZero = B.buildICmp(CmpInst::ICMP_EQ, LLT::scalar(1),
5948                                 UnmergeSrc.getReg(1), C_0);
5949     auto LoCTLZ = IsUndef ?
5950       B.buildCTLZ_ZERO_UNDEF(DstTy, UnmergeSrc.getReg(0)) :
5951       B.buildCTLZ(DstTy, UnmergeSrc.getReg(0));
5952     auto C_NarrowSize = B.buildConstant(DstTy, NarrowSize);
5953     auto HiIsZeroCTLZ = B.buildAdd(DstTy, LoCTLZ, C_NarrowSize);
5954     auto HiCTLZ = B.buildCTLZ_ZERO_UNDEF(DstTy, UnmergeSrc.getReg(1));
5955     B.buildSelect(DstReg, HiIsZero, HiIsZeroCTLZ, HiCTLZ);
5956 
5957     MI.eraseFromParent();
5958     return Legalized;
5959   }
5960 
5961   return UnableToLegalize;
5962 }
5963 
5964 LegalizerHelper::LegalizeResult
5965 LegalizerHelper::narrowScalarCTTZ(MachineInstr &MI, unsigned TypeIdx,
5966                                   LLT NarrowTy) {
5967   if (TypeIdx != 1)
5968     return UnableToLegalize;
5969 
5970   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
5971   unsigned NarrowSize = NarrowTy.getSizeInBits();
5972 
5973   if (SrcTy.isScalar() && SrcTy.getSizeInBits() == 2 * NarrowSize) {
5974     const bool IsUndef = MI.getOpcode() == TargetOpcode::G_CTTZ_ZERO_UNDEF;
5975 
5976     MachineIRBuilder &B = MIRBuilder;
5977     auto UnmergeSrc = B.buildUnmerge(NarrowTy, SrcReg);
5978     // cttz(Hi:Lo) -> Lo == 0 ? (cttz(Hi) + NarrowSize) : cttz(Lo)
5979     auto C_0 = B.buildConstant(NarrowTy, 0);
5980     auto LoIsZero = B.buildICmp(CmpInst::ICMP_EQ, LLT::scalar(1),
5981                                 UnmergeSrc.getReg(0), C_0);
5982     auto HiCTTZ = IsUndef ?
5983       B.buildCTTZ_ZERO_UNDEF(DstTy, UnmergeSrc.getReg(1)) :
5984       B.buildCTTZ(DstTy, UnmergeSrc.getReg(1));
5985     auto C_NarrowSize = B.buildConstant(DstTy, NarrowSize);
5986     auto LoIsZeroCTTZ = B.buildAdd(DstTy, HiCTTZ, C_NarrowSize);
5987     auto LoCTTZ = B.buildCTTZ_ZERO_UNDEF(DstTy, UnmergeSrc.getReg(0));
5988     B.buildSelect(DstReg, LoIsZero, LoIsZeroCTTZ, LoCTTZ);
5989 
5990     MI.eraseFromParent();
5991     return Legalized;
5992   }
5993 
5994   return UnableToLegalize;
5995 }
5996 
5997 LegalizerHelper::LegalizeResult
5998 LegalizerHelper::narrowScalarCTPOP(MachineInstr &MI, unsigned TypeIdx,
5999                                    LLT NarrowTy) {
6000   if (TypeIdx != 1)
6001     return UnableToLegalize;
6002 
6003   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
6004   unsigned NarrowSize = NarrowTy.getSizeInBits();
6005 
6006   if (SrcTy.isScalar() && SrcTy.getSizeInBits() == 2 * NarrowSize) {
6007     auto UnmergeSrc = MIRBuilder.buildUnmerge(NarrowTy, MI.getOperand(1));
6008 
6009     auto LoCTPOP = MIRBuilder.buildCTPOP(DstTy, UnmergeSrc.getReg(0));
6010     auto HiCTPOP = MIRBuilder.buildCTPOP(DstTy, UnmergeSrc.getReg(1));
6011     MIRBuilder.buildAdd(DstReg, HiCTPOP, LoCTPOP);
6012 
6013     MI.eraseFromParent();
6014     return Legalized;
6015   }
6016 
6017   return UnableToLegalize;
6018 }
6019 
6020 LegalizerHelper::LegalizeResult
6021 LegalizerHelper::narrowScalarFLDEXP(MachineInstr &MI, unsigned TypeIdx,
6022                                     LLT NarrowTy) {
6023   if (TypeIdx != 1)
6024     return UnableToLegalize;
6025 
6026   MachineIRBuilder &B = MIRBuilder;
6027   Register ExpReg = MI.getOperand(2).getReg();
6028   LLT ExpTy = MRI.getType(ExpReg);
6029 
6030   unsigned ClampSize = NarrowTy.getScalarSizeInBits();
6031 
6032   // Clamp the exponent to the range of the target type.
6033   auto MinExp = B.buildConstant(ExpTy, minIntN(ClampSize));
6034   auto ClampMin = B.buildSMax(ExpTy, ExpReg, MinExp);
6035   auto MaxExp = B.buildConstant(ExpTy, maxIntN(ClampSize));
6036   auto Clamp = B.buildSMin(ExpTy, ClampMin, MaxExp);
6037 
6038   auto Trunc = B.buildTrunc(NarrowTy, Clamp);
6039   Observer.changingInstr(MI);
6040   MI.getOperand(2).setReg(Trunc.getReg(0));
6041   Observer.changedInstr(MI);
6042   return Legalized;
6043 }
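// Rationale for the clamp (illustrative reasoning, assuming NarrowTy = s16):
// for any supported float type, ldexp already saturates to zero or infinity
// well before the exponent reaches -32768 or 32767, so clamping to that range
// before the truncation cannot change the result.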
6044 
6045 LegalizerHelper::LegalizeResult
6046 LegalizerHelper::lowerBitCount(MachineInstr &MI) {
6047   unsigned Opc = MI.getOpcode();
6048   const auto &TII = MIRBuilder.getTII();
6049   auto isSupported = [this](const LegalityQuery &Q) {
6050     auto QAction = LI.getAction(Q).Action;
6051     return QAction == Legal || QAction == Libcall || QAction == Custom;
6052   };
6053   switch (Opc) {
6054   default:
6055     return UnableToLegalize;
6056   case TargetOpcode::G_CTLZ_ZERO_UNDEF: {
6057     // This trivially expands to CTLZ.
6058     Observer.changingInstr(MI);
6059     MI.setDesc(TII.get(TargetOpcode::G_CTLZ));
6060     Observer.changedInstr(MI);
6061     return Legalized;
6062   }
6063   case TargetOpcode::G_CTLZ: {
6064     auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
6065     unsigned Len = SrcTy.getSizeInBits();
6066 
6067     if (isSupported({TargetOpcode::G_CTLZ_ZERO_UNDEF, {DstTy, SrcTy}})) {
6068       // If CTLZ_ZERO_UNDEF is supported, emit that and a select for zero.
6069       auto CtlzZU = MIRBuilder.buildCTLZ_ZERO_UNDEF(DstTy, SrcReg);
6070       auto ZeroSrc = MIRBuilder.buildConstant(SrcTy, 0);
6071       auto ICmp = MIRBuilder.buildICmp(
6072           CmpInst::ICMP_EQ, SrcTy.changeElementSize(1), SrcReg, ZeroSrc);
6073       auto LenConst = MIRBuilder.buildConstant(DstTy, Len);
6074       MIRBuilder.buildSelect(DstReg, ICmp, LenConst, CtlzZU);
6075       MI.eraseFromParent();
6076       return Legalized;
6077     }
6078     // for now, we do this:
6079     // NewLen = NextPowerOf2(Len);
6080     // x = x | (x >> 1);
6081     // x = x | (x >> 2);
6082     // ...
6083     // x = x | (x >> 16);
6084     // x = x | (x >> 32); // for 64-bit input
6085     // ... with shift amounts up to NewLen/2, then
6086     // return Len - popcount(x);
6087     //
6088     // Ref: "Hacker's Delight" by Henry Warren
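    // Worked illustration: for Len = 32 and x = 0x000000F0 the shifts smear
    // the leading one into every lower bit, giving x = 0x000000FF, so
    // Len - popcount(x) = 32 - 8 = 24 = ctlz(0x000000F0).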
6089     Register Op = SrcReg;
6090     unsigned NewLen = PowerOf2Ceil(Len);
6091     for (unsigned i = 0; (1U << i) <= (NewLen / 2); ++i) {
6092       auto MIBShiftAmt = MIRBuilder.buildConstant(SrcTy, 1ULL << i);
6093       auto MIBOp = MIRBuilder.buildOr(
6094           SrcTy, Op, MIRBuilder.buildLShr(SrcTy, Op, MIBShiftAmt));
6095       Op = MIBOp.getReg(0);
6096     }
6097     auto MIBPop = MIRBuilder.buildCTPOP(DstTy, Op);
6098     MIRBuilder.buildSub(MI.getOperand(0), MIRBuilder.buildConstant(DstTy, Len),
6099                         MIBPop);
6100     MI.eraseFromParent();
6101     return Legalized;
6102   }
6103   case TargetOpcode::G_CTTZ_ZERO_UNDEF: {
6104     // This trivially expands to CTTZ.
6105     Observer.changingInstr(MI);
6106     MI.setDesc(TII.get(TargetOpcode::G_CTTZ));
6107     Observer.changedInstr(MI);
6108     return Legalized;
6109   }
6110   case TargetOpcode::G_CTTZ: {
6111     auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
6112 
6113     unsigned Len = SrcTy.getSizeInBits();
6114     if (isSupported({TargetOpcode::G_CTTZ_ZERO_UNDEF, {DstTy, SrcTy}})) {
6115       // If CTTZ_ZERO_UNDEF is legal or custom, emit that and a select with
6116       // zero.
6117       auto CttzZU = MIRBuilder.buildCTTZ_ZERO_UNDEF(DstTy, SrcReg);
6118       auto Zero = MIRBuilder.buildConstant(SrcTy, 0);
6119       auto ICmp = MIRBuilder.buildICmp(
6120           CmpInst::ICMP_EQ, DstTy.changeElementSize(1), SrcReg, Zero);
6121       auto LenConst = MIRBuilder.buildConstant(DstTy, Len);
6122       MIRBuilder.buildSelect(DstReg, ICmp, LenConst, CttzZU);
6123       MI.eraseFromParent();
6124       return Legalized;
6125     }
6126     // for now, we use: { return popcount(~x & (x - 1)); }
6127     // unless the target has ctlz but not ctpop, in which case we use:
6128     // { return Len - ctlz(~x & (x - 1)); }
6129     // Ref: "Hacker's Delight" by Henry Warren
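    // Worked illustration: for x = 0b1000, ~x & (x - 1) = 0b0111 and
    // popcount(0b0111) = 3 = cttz(0b1000).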
6130     auto MIBCstNeg1 = MIRBuilder.buildConstant(SrcTy, -1);
6131     auto MIBNot = MIRBuilder.buildXor(SrcTy, SrcReg, MIBCstNeg1);
6132     auto MIBTmp = MIRBuilder.buildAnd(
6133         SrcTy, MIBNot, MIRBuilder.buildAdd(SrcTy, SrcReg, MIBCstNeg1));
6134     if (!isSupported({TargetOpcode::G_CTPOP, {SrcTy, SrcTy}}) &&
6135         isSupported({TargetOpcode::G_CTLZ, {SrcTy, SrcTy}})) {
6136       auto MIBCstLen = MIRBuilder.buildConstant(SrcTy, Len);
6137       MIRBuilder.buildSub(MI.getOperand(0), MIBCstLen,
6138                           MIRBuilder.buildCTLZ(SrcTy, MIBTmp));
6139       MI.eraseFromParent();
6140       return Legalized;
6141     }
6142     Observer.changingInstr(MI);
6143     MI.setDesc(TII.get(TargetOpcode::G_CTPOP));
6144     MI.getOperand(1).setReg(MIBTmp.getReg(0));
6145     Observer.changedInstr(MI);
6146     return Legalized;
6147   }
6148   case TargetOpcode::G_CTPOP: {
6149     Register SrcReg = MI.getOperand(1).getReg();
6150     LLT Ty = MRI.getType(SrcReg);
6151     unsigned Size = Ty.getSizeInBits();
6152     MachineIRBuilder &B = MIRBuilder;
6153 
6154     // Count set bits in blocks of 2 bits. The default approach would be
6155     // B2Count = { val & 0x55555555 } + { (val >> 1) & 0x55555555 }
6156     // We use the following formula instead:
6157     // B2Count = val - { (val >> 1) & 0x55555555 }
6158     // since it gives the same result in blocks of 2 with one instruction less.
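    // For reference, the full expansion below mirrors this C sketch for a
    // 32-bit value (illustrative only, not code that is compiled here):
    //   uint32_t ctpop32(uint32_t V) {
    //     V = V - ((V >> 1) & 0x55555555);                // blocks of 2
    //     V = (V & 0x33333333) + ((V >> 2) & 0x33333333); // blocks of 4
    //     V = (V + (V >> 4)) & 0x0F0F0F0F;                // blocks of 8
    //     return (V * 0x01010101) >> 24;                  // sum into top byte
    //   }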
6159     auto C_1 = B.buildConstant(Ty, 1);
6160     auto B2Set1LoTo1Hi = B.buildLShr(Ty, SrcReg, C_1);
6161     APInt B2Mask1HiTo0 = APInt::getSplat(Size, APInt(8, 0x55));
6162     auto C_B2Mask1HiTo0 = B.buildConstant(Ty, B2Mask1HiTo0);
6163     auto B2Count1Hi = B.buildAnd(Ty, B2Set1LoTo1Hi, C_B2Mask1HiTo0);
6164     auto B2Count = B.buildSub(Ty, SrcReg, B2Count1Hi);
6165 
6166     // In order to get count in blocks of 4 add values from adjacent block of 2.
6167     // B4Count = { B2Count & 0x33333333 } + { (B2Count >> 2) & 0x33333333 }
6168     auto C_2 = B.buildConstant(Ty, 2);
6169     auto B4Set2LoTo2Hi = B.buildLShr(Ty, B2Count, C_2);
6170     APInt B4Mask2HiTo0 = APInt::getSplat(Size, APInt(8, 0x33));
6171     auto C_B4Mask2HiTo0 = B.buildConstant(Ty, B4Mask2HiTo0);
6172     auto B4HiB2Count = B.buildAnd(Ty, B4Set2LoTo2Hi, C_B4Mask2HiTo0);
6173     auto B4LoB2Count = B.buildAnd(Ty, B2Count, C_B4Mask2HiTo0);
6174     auto B4Count = B.buildAdd(Ty, B4HiB2Count, B4LoB2Count);
6175 
6176     // For the count in blocks of 8 bits we don't have to mask the high 4 bits
6177     // before the addition, since each count is in the range {0,...,8} and 4
6178     // bits are enough to hold such values. After the addition the high 4 bits
6179     // still hold a stale block count; set them to zero to get the 8-bit result.
6180     // B8Count = { B4Count + (B4Count >> 4) } & 0x0F0F0F0F
6181     auto C_4 = B.buildConstant(Ty, 4);
6182     auto B8HiB4Count = B.buildLShr(Ty, B4Count, C_4);
6183     auto B8CountDirty4Hi = B.buildAdd(Ty, B8HiB4Count, B4Count);
6184     APInt B8Mask4HiTo0 = APInt::getSplat(Size, APInt(8, 0x0F));
6185     auto C_B8Mask4HiTo0 = B.buildConstant(Ty, B8Mask4HiTo0);
6186     auto B8Count = B.buildAnd(Ty, B8CountDirty4Hi, C_B8Mask4HiTo0);
6187 
6188     assert(Size <= 128 && "Scalar size is too large for CTPOP lower algorithm");
6189     // 8 bits can hold the CTPOP result of a 128-bit int or smaller. Multiplying
6190     // by this bitmask sets the 8 MSBs of ResTmp to the sum of all the B8Counts.
6191     auto MulMask = B.buildConstant(Ty, APInt::getSplat(Size, APInt(8, 0x01)));
6192     auto ResTmp = B.buildMul(Ty, B8Count, MulMask);
6193 
6194     // Shift count result from 8 high bits to low bits.
6195     auto C_SizeM8 = B.buildConstant(Ty, Size - 8);
6196     B.buildLShr(MI.getOperand(0).getReg(), ResTmp, C_SizeM8);
6197 
6198     MI.eraseFromParent();
6199     return Legalized;
6200   }
6201   }
6202 }
6203 
6204 // Check that (every element of) Reg is undef or not an exact multiple of BW.
6205 static bool isNonZeroModBitWidthOrUndef(const MachineRegisterInfo &MRI,
6206                                         Register Reg, unsigned BW) {
6207   return matchUnaryPredicate(
6208       MRI, Reg,
6209       [=](const Constant *C) {
6210         // Null constant here means an undef.
6211         const ConstantInt *CI = dyn_cast_or_null<ConstantInt>(C);
6212         return !CI || CI->getValue().urem(BW) != 0;
6213       },
6214       /*AllowUndefs*/ true);
6215 }
6216 
6217 LegalizerHelper::LegalizeResult
6218 LegalizerHelper::lowerFunnelShiftWithInverse(MachineInstr &MI) {
6219   auto [Dst, X, Y, Z] = MI.getFirst4Regs();
6220   LLT Ty = MRI.getType(Dst);
6221   LLT ShTy = MRI.getType(Z);
6222 
6223   unsigned BW = Ty.getScalarSizeInBits();
6224 
6225   if (!isPowerOf2_32(BW))
6226     return UnableToLegalize;
6227 
6228   const bool IsFSHL = MI.getOpcode() == TargetOpcode::G_FSHL;
6229   unsigned RevOpcode = IsFSHL ? TargetOpcode::G_FSHR : TargetOpcode::G_FSHL;
6230 
6231   if (isNonZeroModBitWidthOrUndef(MRI, Z, BW)) {
6232     // fshl X, Y, Z -> fshr X, Y, -Z
6233     // fshr X, Y, Z -> fshl X, Y, -Z
6234     auto Zero = MIRBuilder.buildConstant(ShTy, 0);
6235     Z = MIRBuilder.buildSub(ShTy, Zero, Z).getReg(0);
6236   } else {
6237     // fshl X, Y, Z -> fshr (srl X, 1), (fshr X, Y, 1), ~Z
6238     // fshr X, Y, Z -> fshl (fshl X, Y, 1), (shl Y, 1), ~Z
6239     auto One = MIRBuilder.buildConstant(ShTy, 1);
6240     if (IsFSHL) {
6241       Y = MIRBuilder.buildInstr(RevOpcode, {Ty}, {X, Y, One}).getReg(0);
6242       X = MIRBuilder.buildLShr(Ty, X, One).getReg(0);
6243     } else {
6244       X = MIRBuilder.buildInstr(RevOpcode, {Ty}, {X, Y, One}).getReg(0);
6245       Y = MIRBuilder.buildShl(Ty, Y, One).getReg(0);
6246     }
6247 
6248     Z = MIRBuilder.buildNot(ShTy, Z).getReg(0);
6249   }
6250 
6251   MIRBuilder.buildInstr(RevOpcode, {Dst}, {X, Y, Z});
6252   MI.eraseFromParent();
6253   return Legalized;
6254 }
6255 
6256 LegalizerHelper::LegalizeResult
6257 LegalizerHelper::lowerFunnelShiftAsShifts(MachineInstr &MI) {
6258   auto [Dst, X, Y, Z] = MI.getFirst4Regs();
6259   LLT Ty = MRI.getType(Dst);
6260   LLT ShTy = MRI.getType(Z);
6261 
6262   const unsigned BW = Ty.getScalarSizeInBits();
6263   const bool IsFSHL = MI.getOpcode() == TargetOpcode::G_FSHL;
6264 
6265   Register ShX, ShY;
6266   Register ShAmt, InvShAmt;
6267 
6268   // FIXME: Emit optimized urem by constant instead of letting it expand later.
6269   if (isNonZeroModBitWidthOrUndef(MRI, Z, BW)) {
6270     // fshl: X << C | Y >> (BW - C)
6271     // fshr: X << (BW - C) | Y >> C
6272     // where C = Z % BW is not zero
6273     auto BitWidthC = MIRBuilder.buildConstant(ShTy, BW);
6274     ShAmt = MIRBuilder.buildURem(ShTy, Z, BitWidthC).getReg(0);
6275     InvShAmt = MIRBuilder.buildSub(ShTy, BitWidthC, ShAmt).getReg(0);
6276     ShX = MIRBuilder.buildShl(Ty, X, IsFSHL ? ShAmt : InvShAmt).getReg(0);
6277     ShY = MIRBuilder.buildLShr(Ty, Y, IsFSHL ? InvShAmt : ShAmt).getReg(0);
6278   } else {
6279     // fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW))
6280     // fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW)
6281     auto Mask = MIRBuilder.buildConstant(ShTy, BW - 1);
6282     if (isPowerOf2_32(BW)) {
6283       // Z % BW -> Z & (BW - 1)
6284       ShAmt = MIRBuilder.buildAnd(ShTy, Z, Mask).getReg(0);
6285       // (BW - 1) - (Z % BW) -> ~Z & (BW - 1)
6286       auto NotZ = MIRBuilder.buildNot(ShTy, Z);
6287       InvShAmt = MIRBuilder.buildAnd(ShTy, NotZ, Mask).getReg(0);
6288     } else {
6289       auto BitWidthC = MIRBuilder.buildConstant(ShTy, BW);
6290       ShAmt = MIRBuilder.buildURem(ShTy, Z, BitWidthC).getReg(0);
6291       InvShAmt = MIRBuilder.buildSub(ShTy, Mask, ShAmt).getReg(0);
6292     }
6293 
6294     auto One = MIRBuilder.buildConstant(ShTy, 1);
6295     if (IsFSHL) {
6296       ShX = MIRBuilder.buildShl(Ty, X, ShAmt).getReg(0);
6297       auto ShY1 = MIRBuilder.buildLShr(Ty, Y, One);
6298       ShY = MIRBuilder.buildLShr(Ty, ShY1, InvShAmt).getReg(0);
6299     } else {
6300       auto ShX1 = MIRBuilder.buildShl(Ty, X, One);
6301       ShX = MIRBuilder.buildShl(Ty, ShX1, InvShAmt).getReg(0);
6302       ShY = MIRBuilder.buildLShr(Ty, Y, ShAmt).getReg(0);
6303     }
6304   }
6305 
6306   MIRBuilder.buildOr(Dst, ShX, ShY);
6307   MI.eraseFromParent();
6308   return Legalized;
6309 }
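// A C sketch of the two G_FSHL paths above for BW = 32 (illustrative only):
//   // Z known not to be a multiple of 32:
//   R = (X << (Z % 32)) | (Y >> (32 - (Z % 32)));
//   // General case, where the extra ">> 1" avoids an undefined shift by 32
//   // when Z % 32 == 0:
//   R = (X << (Z & 31)) | ((Y >> 1) >> (~Z & 31));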
6310 
6311 LegalizerHelper::LegalizeResult
6312 LegalizerHelper::lowerFunnelShift(MachineInstr &MI) {
6313   // These operations approximately do the following (while avoiding undefined
6314   // shifts by BW):
6315   // G_FSHL: (X << (Z % BW)) | (Y >> (BW - (Z % BW)))
6316   // G_FSHR: (X << (BW - (Z % BW))) | (Y >> (Z % BW))
6317   Register Dst = MI.getOperand(0).getReg();
6318   LLT Ty = MRI.getType(Dst);
6319   LLT ShTy = MRI.getType(MI.getOperand(3).getReg());
6320 
6321   bool IsFSHL = MI.getOpcode() == TargetOpcode::G_FSHL;
6322   unsigned RevOpcode = IsFSHL ? TargetOpcode::G_FSHR : TargetOpcode::G_FSHL;
6323 
6324   // TODO: Use smarter heuristic that accounts for vector legalization.
6325   if (LI.getAction({RevOpcode, {Ty, ShTy}}).Action == Lower)
6326     return lowerFunnelShiftAsShifts(MI);
6327 
6328   // This only works for powers of 2, fallback to shifts if it fails.
6329   LegalizerHelper::LegalizeResult Result = lowerFunnelShiftWithInverse(MI);
6330   if (Result == UnableToLegalize)
6331     return lowerFunnelShiftAsShifts(MI);
6332   return Result;
6333 }
6334 
6335 LegalizerHelper::LegalizeResult LegalizerHelper::lowerEXT(MachineInstr &MI) {
6336   auto [Dst, Src] = MI.getFirst2Regs();
6337   LLT DstTy = MRI.getType(Dst);
6338   LLT SrcTy = MRI.getType(Src);
6339 
6340   uint32_t DstTySize = DstTy.getSizeInBits();
6341   uint32_t DstTyScalarSize = DstTy.getScalarSizeInBits();
6342   uint32_t SrcTyScalarSize = SrcTy.getScalarSizeInBits();
6343 
6344   if (!isPowerOf2_32(DstTySize) || !isPowerOf2_32(DstTyScalarSize) ||
6345       !isPowerOf2_32(SrcTyScalarSize))
6346     return UnableToLegalize;
6347 
6348   // The step between the extends is too large; split it by creating an
6349   // intermediate extend instruction.
6350   if (SrcTyScalarSize * 2 < DstTyScalarSize) {
6351     LLT MidTy = SrcTy.changeElementSize(SrcTyScalarSize * 2);
6352     // If the destination type is illegal, split it into multiple statements
6353     // zext x -> zext(merge(zext(unmerge), zext(unmerge)))
6354     auto NewExt = MIRBuilder.buildInstr(MI.getOpcode(), {MidTy}, {Src});
6355     // Unmerge the vector
6356     LLT EltTy = MidTy.changeElementCount(
6357         MidTy.getElementCount().divideCoefficientBy(2));
6358     auto UnmergeSrc = MIRBuilder.buildUnmerge(EltTy, NewExt);
6359 
6360     // ZExt the vectors
6361     LLT ZExtResTy = DstTy.changeElementCount(
6362         DstTy.getElementCount().divideCoefficientBy(2));
6363     auto ZExtRes1 = MIRBuilder.buildInstr(MI.getOpcode(), {ZExtResTy},
6364                                           {UnmergeSrc.getReg(0)});
6365     auto ZExtRes2 = MIRBuilder.buildInstr(MI.getOpcode(), {ZExtResTy},
6366                                           {UnmergeSrc.getReg(1)});
6367 
6368     // Merge the ending vectors
6369     MIRBuilder.buildMergeLikeInstr(Dst, {ZExtRes1, ZExtRes2});
6370 
6371     MI.eraseFromParent();
6372     return Legalized;
6373   }
6374   return UnableToLegalize;
6375 }
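// For example, %d:_(<4 x s32>) = G_ZEXT %s:_(<4 x s8>) would be split roughly
// as follows (an illustrative sketch; register names invented):
//   %mid:_(<4 x s16>) = G_ZEXT %s:_(<4 x s8>)
//   %lo:_(<2 x s16>), %hi:_(<2 x s16>) = G_UNMERGE_VALUES %mid:_(<4 x s16>)
//   %lo32:_(<2 x s32>) = G_ZEXT %lo:_(<2 x s16>)
//   %hi32:_(<2 x s32>) = G_ZEXT %hi:_(<2 x s16>)
//   %d:_(<4 x s32>) = G_CONCAT_VECTORS %lo32:_(<2 x s32>), %hi32:_(<2 x s32>)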
6376 
6377 LegalizerHelper::LegalizeResult LegalizerHelper::lowerTRUNC(MachineInstr &MI) {
6379   MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
6380   // Similar to how operand splitting is done in SelectionDAG, we can handle
6381   // %res(v8s8) = G_TRUNC %in(v8s32) by generating:
6382   //   %inlo(<4 x s32>), %inhi(<4 x s32>) = G_UNMERGE %in(<8 x s32>)
6383   //   %lo16(<4 x s16>) = G_TRUNC %inlo
6384   //   %hi16(<4 x s16>) = G_TRUNC %inhi
6385   //   %in16(<8 x s16>) = G_CONCAT_VECTORS %lo16, %hi16
6386   //   %res(<8 x s8>) = G_TRUNC %in16
6387 
6388   assert(MI.getOpcode() == TargetOpcode::G_TRUNC);
6389 
6390   Register DstReg = MI.getOperand(0).getReg();
6391   Register SrcReg = MI.getOperand(1).getReg();
6392   LLT DstTy = MRI.getType(DstReg);
6393   LLT SrcTy = MRI.getType(SrcReg);
6394 
6395   if (DstTy.isVector() && isPowerOf2_32(DstTy.getNumElements()) &&
6396       isPowerOf2_32(DstTy.getScalarSizeInBits()) &&
6397       isPowerOf2_32(SrcTy.getNumElements()) &&
6398       isPowerOf2_32(SrcTy.getScalarSizeInBits())) {
6399     // Split input type.
6400     LLT SplitSrcTy = SrcTy.changeElementCount(
6401         SrcTy.getElementCount().divideCoefficientBy(2));
6402 
6403     // First, split the source into two smaller vectors.
6404     SmallVector<Register, 2> SplitSrcs;
6405     extractParts(SrcReg, SplitSrcTy, 2, SplitSrcs, MIRBuilder, MRI);
6406 
6407     // Truncate the splits into intermediate narrower elements.
6408     LLT InterTy;
6409     if (DstTy.getScalarSizeInBits() * 2 < SrcTy.getScalarSizeInBits())
6410       InterTy = SplitSrcTy.changeElementSize(DstTy.getScalarSizeInBits() * 2);
6411     else
6412       InterTy = SplitSrcTy.changeElementSize(DstTy.getScalarSizeInBits());
6413     for (unsigned I = 0; I < SplitSrcs.size(); ++I) {
6414       SplitSrcs[I] = MIRBuilder.buildTrunc(InterTy, SplitSrcs[I]).getReg(0);
6415     }
6416 
6417     // Combine the new truncates into one vector
6418     auto Merge = MIRBuilder.buildMergeLikeInstr(
6419         DstTy.changeElementSize(InterTy.getScalarSizeInBits()), SplitSrcs);
6420 
6421     // Truncate the new vector to the final result type
6422     if (DstTy.getScalarSizeInBits() * 2 < SrcTy.getScalarSizeInBits())
6423       MIRBuilder.buildTrunc(MI.getOperand(0).getReg(), Merge.getReg(0));
6424     else
6425       MIRBuilder.buildCopy(MI.getOperand(0).getReg(), Merge.getReg(0));
6426 
6427     MI.eraseFromParent();
6428 
6429     return Legalized;
6430   }
6431   return UnableToLegalize;
6432 }
6433 
6434 LegalizerHelper::LegalizeResult
6435 LegalizerHelper::lowerRotateWithReverseRotate(MachineInstr &MI) {
6436   auto [Dst, DstTy, Src, SrcTy, Amt, AmtTy] = MI.getFirst3RegLLTs();
6437   auto Zero = MIRBuilder.buildConstant(AmtTy, 0);
6438   bool IsLeft = MI.getOpcode() == TargetOpcode::G_ROTL;
6439   unsigned RevRot = IsLeft ? TargetOpcode::G_ROTR : TargetOpcode::G_ROTL;
6440   auto Neg = MIRBuilder.buildSub(AmtTy, Zero, Amt);
6441   MIRBuilder.buildInstr(RevRot, {Dst}, {Src, Neg});
6442   MI.eraseFromParent();
6443   return Legalized;
6444 }
6445 
6446 LegalizerHelper::LegalizeResult LegalizerHelper::lowerRotate(MachineInstr &MI) {
6447   auto [Dst, DstTy, Src, SrcTy, Amt, AmtTy] = MI.getFirst3RegLLTs();
6448 
6449   unsigned EltSizeInBits = DstTy.getScalarSizeInBits();
6450   bool IsLeft = MI.getOpcode() == TargetOpcode::G_ROTL;
6451 
6452   MIRBuilder.setInstrAndDebugLoc(MI);
6453 
6454   // If a rotate in the other direction is supported, use it.
6455   unsigned RevRot = IsLeft ? TargetOpcode::G_ROTR : TargetOpcode::G_ROTL;
6456   if (LI.isLegalOrCustom({RevRot, {DstTy, SrcTy}}) &&
6457       isPowerOf2_32(EltSizeInBits))
6458     return lowerRotateWithReverseRotate(MI);
6459 
6460   // If a funnel shift is supported, use it.
6461   unsigned FShOpc = IsLeft ? TargetOpcode::G_FSHL : TargetOpcode::G_FSHR;
6462   unsigned RevFsh = !IsLeft ? TargetOpcode::G_FSHL : TargetOpcode::G_FSHR;
6463   bool IsFShLegal = false;
6464   if ((IsFShLegal = LI.isLegalOrCustom({FShOpc, {DstTy, AmtTy}})) ||
6465       LI.isLegalOrCustom({RevFsh, {DstTy, AmtTy}})) {
6466     auto buildFunnelShift = [&](unsigned Opc, Register R1, Register R2,
6467                                 Register R3) {
6468       MIRBuilder.buildInstr(Opc, {R1}, {R2, R2, R3});
6469       MI.eraseFromParent();
6470       return Legalized;
6471     };
6472     // Prefer the same-direction funnel shift; otherwise negate the amount and use the reverse one.
6473     if (IsFShLegal) {
6474       return buildFunnelShift(FShOpc, Dst, Src, Amt);
6475     } else if (isPowerOf2_32(EltSizeInBits)) {
6476       Amt = MIRBuilder.buildNeg(DstTy, Amt).getReg(0);
6477       return buildFunnelShift(RevFsh, Dst, Src, Amt);
6478     }
6479   }
6480 
6481   auto Zero = MIRBuilder.buildConstant(AmtTy, 0);
6482   unsigned ShOpc = IsLeft ? TargetOpcode::G_SHL : TargetOpcode::G_LSHR;
6483   unsigned RevShiftOpc = IsLeft ? TargetOpcode::G_LSHR : TargetOpcode::G_SHL;
6484   auto BitWidthMinusOneC = MIRBuilder.buildConstant(AmtTy, EltSizeInBits - 1);
6485   Register ShVal;
6486   Register RevShiftVal;
6487   if (isPowerOf2_32(EltSizeInBits)) {
6488     // (rotl x, c) -> x << (c & (w - 1)) | x >> (-c & (w - 1))
6489     // (rotr x, c) -> x >> (c & (w - 1)) | x << (-c & (w - 1))
6490     auto NegAmt = MIRBuilder.buildSub(AmtTy, Zero, Amt);
6491     auto ShAmt = MIRBuilder.buildAnd(AmtTy, Amt, BitWidthMinusOneC);
6492     ShVal = MIRBuilder.buildInstr(ShOpc, {DstTy}, {Src, ShAmt}).getReg(0);
6493     auto RevAmt = MIRBuilder.buildAnd(AmtTy, NegAmt, BitWidthMinusOneC);
6494     RevShiftVal =
6495         MIRBuilder.buildInstr(RevShiftOpc, {DstTy}, {Src, RevAmt}).getReg(0);
6496   } else {
6497     // (rotl x, c) -> x << (c % w) | x >> 1 >> (w - 1 - (c % w))
6498     // (rotr x, c) -> x >> (c % w) | x << 1 << (w - 1 - (c % w))
6499     auto BitWidthC = MIRBuilder.buildConstant(AmtTy, EltSizeInBits);
6500     auto ShAmt = MIRBuilder.buildURem(AmtTy, Amt, BitWidthC);
6501     ShVal = MIRBuilder.buildInstr(ShOpc, {DstTy}, {Src, ShAmt}).getReg(0);
6502     auto RevAmt = MIRBuilder.buildSub(AmtTy, BitWidthMinusOneC, ShAmt);
6503     auto One = MIRBuilder.buildConstant(AmtTy, 1);
6504     auto Inner = MIRBuilder.buildInstr(RevShiftOpc, {DstTy}, {Src, One});
6505     RevShiftVal =
6506         MIRBuilder.buildInstr(RevShiftOpc, {DstTy}, {Inner, RevAmt}).getReg(0);
6507   }
6508   MIRBuilder.buildOr(Dst, ShVal, RevShiftVal);
6509   MI.eraseFromParent();
6510   return Legalized;
6511 }
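// A C sketch of both rotl paths above for w = 32 (illustrative only):
//   // Power-of-two width: mask the amount instead of taking it modulo w.
//   R = (X << (C & 31)) | (X >> (-C & 31));   // both shifts stay in [0,31]
//   // Generic width: the extra ">> 1" keeps the right shift in range when
//   // C % w == 0.
//   R = (X << (C % w)) | ((X >> 1) >> (w - 1 - (C % w)));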
6512 
6513 // Expand s32 = G_UITOFP s64 using bit operations to an IEEE float
6514 // representation.
6515 LegalizerHelper::LegalizeResult
6516 LegalizerHelper::lowerU64ToF32BitOps(MachineInstr &MI) {
6517   auto [Dst, Src] = MI.getFirst2Regs();
6518   const LLT S64 = LLT::scalar(64);
6519   const LLT S32 = LLT::scalar(32);
6520   const LLT S1 = LLT::scalar(1);
6521 
6522   assert(MRI.getType(Src) == S64 && MRI.getType(Dst) == S32);
6523 
6524   // unsigned cul2f(ulong u) {
6525   //   uint lz = clz(u);
6526   //   uint e = (u != 0) ? 127U + 63U - lz : 0;
6527   //   u = (u << lz) & 0x7fffffffffffffffUL;
6528   //   ulong t = u & 0xffffffffffUL;
6529   //   uint v = (e << 23) | (uint)(u >> 40);
6530   //   uint r = t > 0x8000000000UL ? 1U : (t == 0x8000000000UL ? v & 1U : 0U);
6531   //   return as_float(v + r);
6532   // }
6533 
6534   auto Zero32 = MIRBuilder.buildConstant(S32, 0);
6535   auto Zero64 = MIRBuilder.buildConstant(S64, 0);
6536 
6537   auto LZ = MIRBuilder.buildCTLZ_ZERO_UNDEF(S32, Src);
6538 
6539   auto K = MIRBuilder.buildConstant(S32, 127U + 63U);
6540   auto Sub = MIRBuilder.buildSub(S32, K, LZ);
6541 
6542   auto NotZero = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1, Src, Zero64);
6543   auto E = MIRBuilder.buildSelect(S32, NotZero, Sub, Zero32);
6544 
6545   auto Mask0 = MIRBuilder.buildConstant(S64, (-1ULL) >> 1);
6546   auto ShlLZ = MIRBuilder.buildShl(S64, Src, LZ);
6547 
6548   auto U = MIRBuilder.buildAnd(S64, ShlLZ, Mask0);
6549 
6550   auto Mask1 = MIRBuilder.buildConstant(S64, 0xffffffffffULL);
6551   auto T = MIRBuilder.buildAnd(S64, U, Mask1);
6552 
6553   auto UShl = MIRBuilder.buildLShr(S64, U, MIRBuilder.buildConstant(S64, 40));
6554   auto ShlE = MIRBuilder.buildShl(S32, E, MIRBuilder.buildConstant(S32, 23));
6555   auto V = MIRBuilder.buildOr(S32, ShlE, MIRBuilder.buildTrunc(S32, UShl));
6556 
6557   auto C = MIRBuilder.buildConstant(S64, 0x8000000000ULL);
6558   auto RCmp = MIRBuilder.buildICmp(CmpInst::ICMP_UGT, S1, T, C);
6559   auto TCmp = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, S1, T, C);
6560   auto One = MIRBuilder.buildConstant(S32, 1);
6561 
6562   auto VTrunc1 = MIRBuilder.buildAnd(S32, V, One);
6563   auto Select0 = MIRBuilder.buildSelect(S32, TCmp, VTrunc1, Zero32);
6564   auto R = MIRBuilder.buildSelect(S32, RCmp, One, Select0);
6565   MIRBuilder.buildAdd(Dst, V, R);
6566 
6567   MI.eraseFromParent();
6568   return Legalized;
6569 }
6570 
6571 LegalizerHelper::LegalizeResult LegalizerHelper::lowerUITOFP(MachineInstr &MI) {
6572   auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs();
6573 
6574   if (SrcTy == LLT::scalar(1)) {
6575     auto True = MIRBuilder.buildFConstant(DstTy, 1.0);
6576     auto False = MIRBuilder.buildFConstant(DstTy, 0.0);
6577     MIRBuilder.buildSelect(Dst, Src, True, False);
6578     MI.eraseFromParent();
6579     return Legalized;
6580   }
6581 
6582   if (SrcTy != LLT::scalar(64))
6583     return UnableToLegalize;
6584 
6585   if (DstTy == LLT::scalar(32)) {
6586     // TODO: SelectionDAG has several alternative expansions to port which may
6587     // be more reasonable depending on the available instructions. If a target
6588     // has sitofp, does not have CTLZ, or can efficiently use f64 as an
6589     // intermediate type, this is probably worse.
6590     return lowerU64ToF32BitOps(MI);
6591   }
6592 
6593   return UnableToLegalize;
6594 }
6595 
6596 LegalizerHelper::LegalizeResult LegalizerHelper::lowerSITOFP(MachineInstr &MI) {
6597   auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs();
6598 
6599   const LLT S64 = LLT::scalar(64);
6600   const LLT S32 = LLT::scalar(32);
6601   const LLT S1 = LLT::scalar(1);
6602 
6603   if (SrcTy == S1) {
6604     auto True = MIRBuilder.buildFConstant(DstTy, -1.0);
6605     auto False = MIRBuilder.buildFConstant(DstTy, 0.0);
6606     MIRBuilder.buildSelect(Dst, Src, True, False);
6607     MI.eraseFromParent();
6608     return Legalized;
6609   }
6610 
6611   if (SrcTy != S64)
6612     return UnableToLegalize;
6613 
6614   if (DstTy == S32) {
6615     // signed cl2f(long l) {
6616     //   long s = l >> 63;
6617     //   float r = cul2f((l + s) ^ s);
6618     //   return s ? -r : r;
6619     // }
6620     Register L = Src;
6621     auto SignBit = MIRBuilder.buildConstant(S64, 63);
6622     auto S = MIRBuilder.buildAShr(S64, L, SignBit);
6623 
6624     auto LPlusS = MIRBuilder.buildAdd(S64, L, S);
6625     auto Xor = MIRBuilder.buildXor(S64, LPlusS, S);
6626     auto R = MIRBuilder.buildUITOFP(S32, Xor);
6627 
6628     auto RNeg = MIRBuilder.buildFNeg(S32, R);
6629     auto SignNotZero = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1, S,
6630                                             MIRBuilder.buildConstant(S64, 0));
6631     MIRBuilder.buildSelect(Dst, SignNotZero, RNeg, R);
6632     MI.eraseFromParent();
6633     return Legalized;
6634   }
6635 
6636   return UnableToLegalize;
6637 }
6638 
6639 LegalizerHelper::LegalizeResult LegalizerHelper::lowerFPTOUI(MachineInstr &MI) {
6640   auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs();
6641   const LLT S64 = LLT::scalar(64);
6642   const LLT S32 = LLT::scalar(32);
6643 
6644   if (SrcTy != S64 && SrcTy != S32)
6645     return UnableToLegalize;
6646   if (DstTy != S32 && DstTy != S64)
6647     return UnableToLegalize;
6648 
6649   // FPTOSI gives the same result as FPTOUI for positive signed integers.
6650   // FPTOUI needs to deal with fp values that convert to unsigned integers
6651   // greater than or equal to 2^31 for float or 2^63 for double; call this 2^Exp.
6652 
6653   APInt TwoPExpInt = APInt::getSignMask(DstTy.getSizeInBits());
6654   APFloat TwoPExpFP(SrcTy.getSizeInBits() == 32 ? APFloat::IEEEsingle()
6655                                                 : APFloat::IEEEdouble(),
6656                     APInt::getZero(SrcTy.getSizeInBits()));
6657   TwoPExpFP.convertFromAPInt(TwoPExpInt, false, APFloat::rmNearestTiesToEven);
6658 
6659   MachineInstrBuilder FPTOSI = MIRBuilder.buildFPTOSI(DstTy, Src);
6660 
6661   MachineInstrBuilder Threshold = MIRBuilder.buildFConstant(SrcTy, TwoPExpFP);
6662   // For an fp value greater than or equal to Threshold (2^Exp), we use FPTOSI
6663   // on (Value - 2^Exp) and add 2^Exp back by setting the highest result bit.
6664   MachineInstrBuilder FSub = MIRBuilder.buildFSub(SrcTy, Src, Threshold);
6665   MachineInstrBuilder ResLowBits = MIRBuilder.buildFPTOSI(DstTy, FSub);
6666   MachineInstrBuilder ResHighBit = MIRBuilder.buildConstant(DstTy, TwoPExpInt);
6667   MachineInstrBuilder Res = MIRBuilder.buildXor(DstTy, ResLowBits, ResHighBit);
6668 
6669   const LLT S1 = LLT::scalar(1);
6670 
6671   MachineInstrBuilder FCMP =
6672       MIRBuilder.buildFCmp(CmpInst::FCMP_ULT, S1, Src, Threshold);
6673   MIRBuilder.buildSelect(Dst, FCMP, FPTOSI, Res);
6674 
6675   MI.eraseFromParent();
6676   return Legalized;
6677 }
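// A rough C equivalent of the expansion above, specialized to f32 -> u32
// (illustrative only; the MIR handles the s32/s64 combinations generically):
//   uint32_t fptoui(float X) {
//     if (!(X >= 0x1p31f))                 // FCMP_ULT: below 2^31, or NaN
//       return (uint32_t)(int32_t)X;       // plain FPTOSI result
//     // Subtract 2^31, convert, then add 2^31 back via the top bit.
//     return (uint32_t)(int32_t)(X - 0x1p31f) ^ 0x80000000u;
//   }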
6678 
6679 LegalizerHelper::LegalizeResult LegalizerHelper::lowerFPTOSI(MachineInstr &MI) {
6680   auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs();
6681   const LLT S64 = LLT::scalar(64);
6682   const LLT S32 = LLT::scalar(32);
6683 
6684   // FIXME: Only f32 to i64 conversions are supported.
6685   if (SrcTy.getScalarType() != S32 || DstTy.getScalarType() != S64)
6686     return UnableToLegalize;
6687 
6688   // Expand f32 -> i64 conversion
6689   // This algorithm comes from compiler-rt's implementation of fixsfdi:
6690   // https://github.com/llvm/llvm-project/blob/main/compiler-rt/lib/builtins/fixsfdi.c
6691 
6692   unsigned SrcEltBits = SrcTy.getScalarSizeInBits();
6693 
6694   auto ExponentMask = MIRBuilder.buildConstant(SrcTy, 0x7F800000);
6695   auto ExponentLoBit = MIRBuilder.buildConstant(SrcTy, 23);
6696 
6697   auto AndExpMask = MIRBuilder.buildAnd(SrcTy, Src, ExponentMask);
6698   auto ExponentBits = MIRBuilder.buildLShr(SrcTy, AndExpMask, ExponentLoBit);
6699 
6700   auto SignMask = MIRBuilder.buildConstant(SrcTy,
6701                                            APInt::getSignMask(SrcEltBits));
6702   auto AndSignMask = MIRBuilder.buildAnd(SrcTy, Src, SignMask);
6703   auto SignLowBit = MIRBuilder.buildConstant(SrcTy, SrcEltBits - 1);
6704   auto Sign = MIRBuilder.buildAShr(SrcTy, AndSignMask, SignLowBit);
6705   Sign = MIRBuilder.buildSExt(DstTy, Sign);
6706 
6707   auto MantissaMask = MIRBuilder.buildConstant(SrcTy, 0x007FFFFF);
6708   auto AndMantissaMask = MIRBuilder.buildAnd(SrcTy, Src, MantissaMask);
6709   auto K = MIRBuilder.buildConstant(SrcTy, 0x00800000);
6710 
6711   auto R = MIRBuilder.buildOr(SrcTy, AndMantissaMask, K);
6712   R = MIRBuilder.buildZExt(DstTy, R);
6713 
6714   auto Bias = MIRBuilder.buildConstant(SrcTy, 127);
6715   auto Exponent = MIRBuilder.buildSub(SrcTy, ExponentBits, Bias);
6716   auto SubExponent = MIRBuilder.buildSub(SrcTy, Exponent, ExponentLoBit);
6717   auto ExponentSub = MIRBuilder.buildSub(SrcTy, ExponentLoBit, Exponent);
6718 
6719   auto Shl = MIRBuilder.buildShl(DstTy, R, SubExponent);
6720   auto Srl = MIRBuilder.buildLShr(DstTy, R, ExponentSub);
6721 
6722   const LLT S1 = LLT::scalar(1);
6723   auto CmpGt = MIRBuilder.buildICmp(CmpInst::ICMP_SGT,
6724                                     S1, Exponent, ExponentLoBit);
6725 
6726   R = MIRBuilder.buildSelect(DstTy, CmpGt, Shl, Srl);
6727 
6728   auto XorSign = MIRBuilder.buildXor(DstTy, R, Sign);
6729   auto Ret = MIRBuilder.buildSub(DstTy, XorSign, Sign);
6730 
6731   auto ZeroSrcTy = MIRBuilder.buildConstant(SrcTy, 0);
6732 
6733   auto ExponentLt0 = MIRBuilder.buildICmp(CmpInst::ICMP_SLT,
6734                                           S1, Exponent, ZeroSrcTy);
6735 
6736   auto ZeroDstTy = MIRBuilder.buildConstant(DstTy, 0);
6737   MIRBuilder.buildSelect(Dst, ExponentLt0, ZeroDstTy, Ret);
6738 
6739   MI.eraseFromParent();
6740   return Legalized;
6741 }
6742 
6743 // f64 -> f16 conversion using round-to-nearest-even rounding mode.
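//
// A hand-checked sample (illustrative): 1.0 is 0x3FF0000000000000, so UH is
// 0x3FF00000, E = 1023 - 1023 + 15 = 15, and all mantissa bits are zero;
// after the final shift the result is 15 << 10 = 0x3C00, i.e. 1.0 in IEEE
// half precision.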
6744 LegalizerHelper::LegalizeResult
6745 LegalizerHelper::lowerFPTRUNC_F64_TO_F16(MachineInstr &MI) {
6746   const LLT S1 = LLT::scalar(1);
6747   const LLT S32 = LLT::scalar(32);
6748 
6749   auto [Dst, Src] = MI.getFirst2Regs();
6750   assert(MRI.getType(Dst).getScalarType() == LLT::scalar(16) &&
6751          MRI.getType(Src).getScalarType() == LLT::scalar(64));
6752 
6753   if (MRI.getType(Src).isVector()) // TODO: Handle vectors directly.
6754     return UnableToLegalize;
6755 
6756   if (MIRBuilder.getMF().getTarget().Options.UnsafeFPMath) {
6757     unsigned Flags = MI.getFlags();
6758     auto Src32 = MIRBuilder.buildFPTrunc(S32, Src, Flags);
6759     MIRBuilder.buildFPTrunc(Dst, Src32, Flags);
6760     MI.eraseFromParent();
6761     return Legalized;
6762   }
6763 
6764   const unsigned ExpMask = 0x7ff;
6765   const unsigned ExpBiasf64 = 1023;
6766   const unsigned ExpBiasf16 = 15;
6767 
6768   auto Unmerge = MIRBuilder.buildUnmerge(S32, Src);
6769   Register U = Unmerge.getReg(0);
6770   Register UH = Unmerge.getReg(1);
6771 
6772   auto E = MIRBuilder.buildLShr(S32, UH, MIRBuilder.buildConstant(S32, 20));
6773   E = MIRBuilder.buildAnd(S32, E, MIRBuilder.buildConstant(S32, ExpMask));
6774 
6775   // Subtract the fp64 exponent bias (1023) to get the real exponent and
6776   // add the f16 bias (15) to get the biased exponent for the f16 format.
6777   E = MIRBuilder.buildAdd(
6778     S32, E, MIRBuilder.buildConstant(S32, -ExpBiasf64 + ExpBiasf16));
6779 
6780   auto M = MIRBuilder.buildLShr(S32, UH, MIRBuilder.buildConstant(S32, 8));
6781   M = MIRBuilder.buildAnd(S32, M, MIRBuilder.buildConstant(S32, 0xffe));
6782 
6783   auto MaskedSig = MIRBuilder.buildAnd(S32, UH,
6784                                        MIRBuilder.buildConstant(S32, 0x1ff));
6785   MaskedSig = MIRBuilder.buildOr(S32, MaskedSig, U);
6786 
6787   auto Zero = MIRBuilder.buildConstant(S32, 0);
6788   auto SigCmpNE0 = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1, MaskedSig, Zero);
6789   auto Lo40Set = MIRBuilder.buildZExt(S32, SigCmpNE0);
6790   M = MIRBuilder.buildOr(S32, M, Lo40Set);
6791 
6792   // (M != 0 ? 0x0200 : 0) | 0x7c00;
6793   auto Bits0x200 = MIRBuilder.buildConstant(S32, 0x0200);
6794   auto CmpM_NE0 = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1, M, Zero);
6795   auto SelectCC = MIRBuilder.buildSelect(S32, CmpM_NE0, Bits0x200, Zero);
6796 
6797   auto Bits0x7c00 = MIRBuilder.buildConstant(S32, 0x7c00);
6798   auto I = MIRBuilder.buildOr(S32, SelectCC, Bits0x7c00);
6799 
6800   // N = M | (E << 12);
6801   auto EShl12 = MIRBuilder.buildShl(S32, E, MIRBuilder.buildConstant(S32, 12));
6802   auto N = MIRBuilder.buildOr(S32, M, EShl12);
6803 
6804   // B = clamp(1-E, 0, 13);
6805   auto One = MIRBuilder.buildConstant(S32, 1);
6806   auto OneSubExp = MIRBuilder.buildSub(S32, One, E);
6807   auto B = MIRBuilder.buildSMax(S32, OneSubExp, Zero);
6808   B = MIRBuilder.buildSMin(S32, B, MIRBuilder.buildConstant(S32, 13));
6809 
6810   auto SigSetHigh = MIRBuilder.buildOr(S32, M,
6811                                        MIRBuilder.buildConstant(S32, 0x1000));
6812 
6813   auto D = MIRBuilder.buildLShr(S32, SigSetHigh, B);
6814   auto D0 = MIRBuilder.buildShl(S32, D, B);
6815 
6816   auto D0_NE_SigSetHigh = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1,
6817                                              D0, SigSetHigh);
6818   auto D1 = MIRBuilder.buildZExt(S32, D0_NE_SigSetHigh);
6819   D = MIRBuilder.buildOr(S32, D, D1);
6820 
6821   auto CmpELtOne = MIRBuilder.buildICmp(CmpInst::ICMP_SLT, S1, E, One);
6822   auto V = MIRBuilder.buildSelect(S32, CmpELtOne, D, N);
6823 
6824   auto VLow3 = MIRBuilder.buildAnd(S32, V, MIRBuilder.buildConstant(S32, 7));
6825   V = MIRBuilder.buildLShr(S32, V, MIRBuilder.buildConstant(S32, 2));
6826 
6827   auto VLow3Eq3 = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, S1, VLow3,
6828                                        MIRBuilder.buildConstant(S32, 3));
6829   auto V0 = MIRBuilder.buildZExt(S32, VLow3Eq3);
6830 
6831   auto VLow3Gt5 = MIRBuilder.buildICmp(CmpInst::ICMP_SGT, S1, VLow3,
6832                                        MIRBuilder.buildConstant(S32, 5));
6833   auto V1 = MIRBuilder.buildZExt(S32, VLow3Gt5);
6834 
6835   V1 = MIRBuilder.buildOr(S32, V0, V1);
6836   V = MIRBuilder.buildAdd(S32, V, V1);
6837 
6838   auto CmpEGt30 = MIRBuilder.buildICmp(CmpInst::ICMP_SGT, S1,
6839                                        E, MIRBuilder.buildConstant(S32, 30));
6840   V = MIRBuilder.buildSelect(S32, CmpEGt30,
6841                              MIRBuilder.buildConstant(S32, 0x7c00), V);
6842 
6843   auto CmpEGt1039 = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, S1,
6844                                          E, MIRBuilder.buildConstant(S32, 1039));
6845   V = MIRBuilder.buildSelect(S32, CmpEGt1039, I, V);
6846 
6847   // Extract the sign bit.
6848   auto Sign = MIRBuilder.buildLShr(S32, UH, MIRBuilder.buildConstant(S32, 16));
6849   Sign = MIRBuilder.buildAnd(S32, Sign, MIRBuilder.buildConstant(S32, 0x8000));
6850 
6851   // Insert the sign bit
6852   V = MIRBuilder.buildOr(S32, Sign, V);
6853 
6854   MIRBuilder.buildTrunc(Dst, V);
6855   MI.eraseFromParent();
6856   return Legalized;
6857 }
6858 
6859 LegalizerHelper::LegalizeResult
6860 LegalizerHelper::lowerFPTRUNC(MachineInstr &MI) {
6861   auto [DstTy, SrcTy] = MI.getFirst2LLTs();
6862   const LLT S64 = LLT::scalar(64);
6863   const LLT S16 = LLT::scalar(16);
6864 
6865   if (DstTy.getScalarType() == S16 && SrcTy.getScalarType() == S64)
6866     return lowerFPTRUNC_F64_TO_F16(MI);
6867 
6868   return UnableToLegalize;
6869 }
6870 
6871 // TODO: If RHS is a constant, SelectionDAGBuilder expands this into a
6872 // multiplication tree.
6873 LegalizerHelper::LegalizeResult LegalizerHelper::lowerFPOWI(MachineInstr &MI) {
6874   auto [Dst, Src0, Src1] = MI.getFirst3Regs();
6875   LLT Ty = MRI.getType(Dst);
6876 
6877   auto CvtSrc1 = MIRBuilder.buildSITOFP(Ty, Src1);
6878   MIRBuilder.buildFPow(Dst, Src0, CvtSrc1, MI.getFlags());
6879   MI.eraseFromParent();
6880   return Legalized;
6881 }
6882 
6883 static CmpInst::Predicate minMaxToCompare(unsigned Opc) {
6884   switch (Opc) {
6885   case TargetOpcode::G_SMIN:
6886     return CmpInst::ICMP_SLT;
6887   case TargetOpcode::G_SMAX:
6888     return CmpInst::ICMP_SGT;
6889   case TargetOpcode::G_UMIN:
6890     return CmpInst::ICMP_ULT;
6891   case TargetOpcode::G_UMAX:
6892     return CmpInst::ICMP_UGT;
6893   default:
6894     llvm_unreachable("not in integer min/max");
6895   }
6896 }
6897 
6898 LegalizerHelper::LegalizeResult LegalizerHelper::lowerMinMax(MachineInstr &MI) {
6899   auto [Dst, Src0, Src1] = MI.getFirst3Regs();
6900 
6901   const CmpInst::Predicate Pred = minMaxToCompare(MI.getOpcode());
6902   LLT CmpType = MRI.getType(Dst).changeElementSize(1);
6903 
6904   auto Cmp = MIRBuilder.buildICmp(Pred, CmpType, Src0, Src1);
6905   MIRBuilder.buildSelect(Dst, Cmp, Src0, Src1);
6906 
6907   MI.eraseFromParent();
6908   return Legalized;
6909 }
6910 
6911 LegalizerHelper::LegalizeResult
6912 LegalizerHelper::lowerFCopySign(MachineInstr &MI) {
6913   auto [Dst, DstTy, Src0, Src0Ty, Src1, Src1Ty] = MI.getFirst3RegLLTs();
6914   const int Src0Size = Src0Ty.getScalarSizeInBits();
6915   const int Src1Size = Src1Ty.getScalarSizeInBits();
6916 
6917   auto SignBitMask = MIRBuilder.buildConstant(
6918     Src0Ty, APInt::getSignMask(Src0Size));
6919 
6920   auto NotSignBitMask = MIRBuilder.buildConstant(
6921     Src0Ty, APInt::getLowBitsSet(Src0Size, Src0Size - 1));
6922 
6923   Register And0 = MIRBuilder.buildAnd(Src0Ty, Src0, NotSignBitMask).getReg(0);
6924   Register And1;
6925   if (Src0Ty == Src1Ty) {
6926     And1 = MIRBuilder.buildAnd(Src1Ty, Src1, SignBitMask).getReg(0);
6927   } else if (Src0Size > Src1Size) {
6928     auto ShiftAmt = MIRBuilder.buildConstant(Src0Ty, Src0Size - Src1Size);
6929     auto Zext = MIRBuilder.buildZExt(Src0Ty, Src1);
6930     auto Shift = MIRBuilder.buildShl(Src0Ty, Zext, ShiftAmt);
6931     And1 = MIRBuilder.buildAnd(Src0Ty, Shift, SignBitMask).getReg(0);
6932   } else {
6933     auto ShiftAmt = MIRBuilder.buildConstant(Src1Ty, Src1Size - Src0Size);
6934     auto Shift = MIRBuilder.buildLShr(Src1Ty, Src1, ShiftAmt);
6935     auto Trunc = MIRBuilder.buildTrunc(Src0Ty, Shift);
6936     And1 = MIRBuilder.buildAnd(Src0Ty, Trunc, SignBitMask).getReg(0);
6937   }
6938 
6939   // Be careful about setting nsz/nnan/ninf on every instruction, since the
6940   // constants are a nan and -0.0, but the final result should preserve
6941   // everything.
6942   unsigned Flags = MI.getFlags();
6943   MIRBuilder.buildOr(Dst, And0, And1, Flags);
6944 
6945   MI.eraseFromParent();
6946   return Legalized;
6947 }
6948 
6949 LegalizerHelper::LegalizeResult
6950 LegalizerHelper::lowerFMinNumMaxNum(MachineInstr &MI) {
6951   unsigned NewOp = MI.getOpcode() == TargetOpcode::G_FMINNUM ?
6952     TargetOpcode::G_FMINNUM_IEEE : TargetOpcode::G_FMAXNUM_IEEE;
6953 
6954   auto [Dst, Src0, Src1] = MI.getFirst3Regs();
6955   LLT Ty = MRI.getType(Dst);
6956 
6957   if (!MI.getFlag(MachineInstr::FmNoNans)) {
6958     // Insert canonicalizes if it's possible we need to quiet to get correct
6959     // sNaN behavior.
6960 
6961     // Note this must be done here, and not as an optimization combine in the
6962     // absence of a dedicated quiet-sNaN instruction, as we're using an
6963     // omni-purpose G_FCANONICALIZE.
6964     if (!isKnownNeverSNaN(Src0, MRI))
6965       Src0 = MIRBuilder.buildFCanonicalize(Ty, Src0, MI.getFlags()).getReg(0);
6966 
6967     if (!isKnownNeverSNaN(Src1, MRI))
6968       Src1 = MIRBuilder.buildFCanonicalize(Ty, Src1, MI.getFlags()).getReg(0);
6969   }
6970 
6971   // If there are no nans, it's safe to simply replace this with the non-IEEE
6972   // version.
6973   MIRBuilder.buildInstr(NewOp, {Dst}, {Src0, Src1}, MI.getFlags());
6974   MI.eraseFromParent();
6975   return Legalized;
6976 }
6977 
6978 LegalizerHelper::LegalizeResult LegalizerHelper::lowerFMad(MachineInstr &MI) {
6979   // Expand G_FMAD a, b, c -> G_FADD (G_FMUL a, b), c
6980   Register DstReg = MI.getOperand(0).getReg();
6981   LLT Ty = MRI.getType(DstReg);
6982   unsigned Flags = MI.getFlags();
6983 
6984   auto Mul = MIRBuilder.buildFMul(Ty, MI.getOperand(1), MI.getOperand(2),
6985                                   Flags);
6986   MIRBuilder.buildFAdd(DstReg, Mul, MI.getOperand(3), Flags);
6987   MI.eraseFromParent();
6988   return Legalized;
6989 }
6990 
6991 LegalizerHelper::LegalizeResult
6992 LegalizerHelper::lowerIntrinsicRound(MachineInstr &MI) {
6993   auto [DstReg, X] = MI.getFirst2Regs();
6994   const unsigned Flags = MI.getFlags();
6995   const LLT Ty = MRI.getType(DstReg);
6996   const LLT CondTy = Ty.changeElementSize(1);
6997 
6998   // round(x) =>
6999   //  t = trunc(x);
7000   //  d = fabs(x - t);
7001   //  o = copysign(d >= 0.5 ? 1.0 : 0.0, x);
7002   //  return t + o;
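  //
  // Illustrative values (not exhaustive): round(2.5) has t = 2.0 and d = 0.5,
  // so o = copysign(1.0, 2.5) = 1.0 and the result is 3.0; round(-2.5) has
  // t = -2.0 and o = -1.0, giving -3.0. Ties round away from zero.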
7003 
7004   auto T = MIRBuilder.buildIntrinsicTrunc(Ty, X, Flags);
7005 
7006   auto Diff = MIRBuilder.buildFSub(Ty, X, T, Flags);
7007   auto AbsDiff = MIRBuilder.buildFAbs(Ty, Diff, Flags);
7008 
7009   auto Half = MIRBuilder.buildFConstant(Ty, 0.5);
7010   auto Cmp =
7011       MIRBuilder.buildFCmp(CmpInst::FCMP_OGE, CondTy, AbsDiff, Half, Flags);
7012 
7013   // Could emit G_UITOFP instead
7014   auto One = MIRBuilder.buildFConstant(Ty, 1.0);
7015   auto Zero = MIRBuilder.buildFConstant(Ty, 0.0);
7016   auto BoolFP = MIRBuilder.buildSelect(Ty, Cmp, One, Zero);
7017   auto SignedOffset = MIRBuilder.buildFCopysign(Ty, BoolFP, X);
7018 
7019   MIRBuilder.buildFAdd(DstReg, T, SignedOffset, Flags);
7020 
7021   MI.eraseFromParent();
7022   return Legalized;
7023 }
7024 
7025 LegalizerHelper::LegalizeResult LegalizerHelper::lowerFFloor(MachineInstr &MI) {
7026   auto [DstReg, SrcReg] = MI.getFirst2Regs();
7027   unsigned Flags = MI.getFlags();
7028   LLT Ty = MRI.getType(DstReg);
7029   const LLT CondTy = Ty.changeElementSize(1);
7030 
7031   // result = trunc(src);
7032   // if (src < 0.0 && src != result)
7033   //   result += -1.0.
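  //
  // Illustrative values: floor(-1.5) has trunc = -1.0; both compares hold, and
  // sitofp of the i1 true (sign-extended to -1) contributes -1.0, giving -2.0.
  // For floor(1.5) the condition is false and 0.0 is added, leaving 1.0.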
7034 
7035   auto Trunc = MIRBuilder.buildIntrinsicTrunc(Ty, SrcReg, Flags);
7036   auto Zero = MIRBuilder.buildFConstant(Ty, 0.0);
7037 
7038   auto Lt0 = MIRBuilder.buildFCmp(CmpInst::FCMP_OLT, CondTy,
7039                                   SrcReg, Zero, Flags);
7040   auto NeTrunc = MIRBuilder.buildFCmp(CmpInst::FCMP_ONE, CondTy,
7041                                       SrcReg, Trunc, Flags);
7042   auto And = MIRBuilder.buildAnd(CondTy, Lt0, NeTrunc);
7043   auto AddVal = MIRBuilder.buildSITOFP(Ty, And);
7044 
7045   MIRBuilder.buildFAdd(DstReg, Trunc, AddVal, Flags);
7046   MI.eraseFromParent();
7047   return Legalized;
7048 }
7049 
7050 LegalizerHelper::LegalizeResult
7051 LegalizerHelper::lowerMergeValues(MachineInstr &MI) {
7052   const unsigned NumOps = MI.getNumOperands();
7053   auto [DstReg, DstTy, Src0Reg, Src0Ty] = MI.getFirst2RegLLTs();
7054   unsigned PartSize = Src0Ty.getSizeInBits();
7055 
7056   LLT WideTy = LLT::scalar(DstTy.getSizeInBits());
7057   Register ResultReg = MIRBuilder.buildZExt(WideTy, Src0Reg).getReg(0);
7058 
7059   for (unsigned I = 2; I != NumOps; ++I) {
7060     const unsigned Offset = (I - 1) * PartSize;
7061 
7062     Register SrcReg = MI.getOperand(I).getReg();
7063     auto ZextInput = MIRBuilder.buildZExt(WideTy, SrcReg);
7064 
7065     Register NextResult = I + 1 == NumOps && WideTy == DstTy ? DstReg :
7066       MRI.createGenericVirtualRegister(WideTy);
7067 
7068     auto ShiftAmt = MIRBuilder.buildConstant(WideTy, Offset);
7069     auto Shl = MIRBuilder.buildShl(WideTy, ZextInput, ShiftAmt);
7070     MIRBuilder.buildOr(NextResult, ResultReg, Shl);
7071     ResultReg = NextResult;
7072   }
7073 
7074   if (DstTy.isPointer()) {
7075     if (MIRBuilder.getDataLayout().isNonIntegralAddressSpace(
7076           DstTy.getAddressSpace())) {
7077       LLVM_DEBUG(dbgs() << "Not casting nonintegral address space\n");
7078       return UnableToLegalize;
7079     }
7080 
7081     MIRBuilder.buildIntToPtr(DstReg, ResultReg);
7082   }
7083 
7084   MI.eraseFromParent();
7085   return Legalized;
7086 }
7087 
7088 LegalizerHelper::LegalizeResult
7089 LegalizerHelper::lowerUnmergeValues(MachineInstr &MI) {
7090   const unsigned NumDst = MI.getNumOperands() - 1;
7091   Register SrcReg = MI.getOperand(NumDst).getReg();
7092   Register Dst0Reg = MI.getOperand(0).getReg();
7093   LLT DstTy = MRI.getType(Dst0Reg);
7094   if (DstTy.isPointer())
7095     return UnableToLegalize; // TODO
7096 
7097   SrcReg = coerceToScalar(SrcReg);
7098   if (!SrcReg)
7099     return UnableToLegalize;
7100 
7101   // Expand scalarizing unmerge as bitcast to integer and shift.
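  // e.g. (illustrative) a scalarized unmerge of s64 into two s32 pieces is
  //   %dst0 = G_TRUNC %src
  //   %dst1 = G_TRUNC (G_LSHR %src, 32)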
7102   LLT IntTy = MRI.getType(SrcReg);
7103 
7104   MIRBuilder.buildTrunc(Dst0Reg, SrcReg);
7105 
7106   const unsigned DstSize = DstTy.getSizeInBits();
7107   unsigned Offset = DstSize;
7108   for (unsigned I = 1; I != NumDst; ++I, Offset += DstSize) {
7109     auto ShiftAmt = MIRBuilder.buildConstant(IntTy, Offset);
7110     auto Shift = MIRBuilder.buildLShr(IntTy, SrcReg, ShiftAmt);
7111     MIRBuilder.buildTrunc(MI.getOperand(I), Shift);
7112   }
7113 
7114   MI.eraseFromParent();
7115   return Legalized;
7116 }
7117 
7118 /// Lower a vector extract or insert by writing the vector to a stack temporary
7119 /// and reloading the element or vector.
7120 ///
7121 /// %dst = G_EXTRACT_VECTOR_ELT %vec, %idx
7122 ///  =>
7123 ///  %stack_temp = G_FRAME_INDEX
7124 ///  G_STORE %vec, %stack_temp
7125 ///  %idx = clamp(%idx, %vec.getNumElements())
7126 ///  %element_ptr = G_PTR_ADD %stack_temp, %idx
7127 ///  %dst = G_LOAD %element_ptr
7128 LegalizerHelper::LegalizeResult
7129 LegalizerHelper::lowerExtractInsertVectorElt(MachineInstr &MI) {
7130   Register DstReg = MI.getOperand(0).getReg();
7131   Register SrcVec = MI.getOperand(1).getReg();
7132   Register InsertVal;
7133   if (MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT)
7134     InsertVal = MI.getOperand(2).getReg();
7135 
7136   Register Idx = MI.getOperand(MI.getNumOperands() - 1).getReg();
7137 
7138   LLT VecTy = MRI.getType(SrcVec);
7139   LLT EltTy = VecTy.getElementType();
7140   unsigned NumElts = VecTy.getNumElements();
7141 
7142   int64_t IdxVal;
7143   if (mi_match(Idx, MRI, m_ICst(IdxVal)) && IdxVal < NumElts) {
7144     SmallVector<Register, 8> SrcRegs;
7145     extractParts(SrcVec, EltTy, NumElts, SrcRegs, MIRBuilder, MRI);
7146 
7147     if (InsertVal) {
7148       SrcRegs[IdxVal] = MI.getOperand(2).getReg();
7149       MIRBuilder.buildMergeLikeInstr(DstReg, SrcRegs);
7150     } else {
7151       MIRBuilder.buildCopy(DstReg, SrcRegs[IdxVal]);
7152     }
7153 
7154     MI.eraseFromParent();
7155     return Legalized;
7156   }
7157 
7158   if (!EltTy.isByteSized()) { // Not implemented.
7159     LLVM_DEBUG(dbgs() << "Can't handle non-byte element vectors yet\n");
7160     return UnableToLegalize;
7161   }
7162 
7163   unsigned EltBytes = EltTy.getSizeInBytes();
7164   Align VecAlign = getStackTemporaryAlignment(VecTy);
7165   Align EltAlign;
7166 
7167   MachinePointerInfo PtrInfo;
7168   auto StackTemp = createStackTemporary(
7169       TypeSize::getFixed(VecTy.getSizeInBytes()), VecAlign, PtrInfo);
7170   MIRBuilder.buildStore(SrcVec, StackTemp, PtrInfo, VecAlign);
7171 
7172   // Get the pointer to the element, and be sure not to hit undefined behavior
7173   // if the index is out of bounds.
7174   Register EltPtr = getVectorElementPointer(StackTemp.getReg(0), VecTy, Idx);
7175 
7176   if (mi_match(Idx, MRI, m_ICst(IdxVal))) {
7177     int64_t Offset = IdxVal * EltBytes;
7178     PtrInfo = PtrInfo.getWithOffset(Offset);
7179     EltAlign = commonAlignment(VecAlign, Offset);
7180   } else {
7181     // We lose information with a variable offset.
7182     EltAlign = getStackTemporaryAlignment(EltTy);
7183     PtrInfo = MachinePointerInfo(MRI.getType(EltPtr).getAddressSpace());
7184   }
7185 
7186   if (InsertVal) {
7187     // Write the inserted element
7188     MIRBuilder.buildStore(InsertVal, EltPtr, PtrInfo, EltAlign);
7189 
7190     // Reload the whole vector.
7191     MIRBuilder.buildLoad(DstReg, StackTemp, PtrInfo, VecAlign);
7192   } else {
7193     MIRBuilder.buildLoad(DstReg, EltPtr, PtrInfo, EltAlign);
7194   }
7195 
7196   MI.eraseFromParent();
7197   return Legalized;
7198 }
7199 
7200 LegalizerHelper::LegalizeResult
7201 LegalizerHelper::lowerShuffleVector(MachineInstr &MI) {
7202   auto [DstReg, DstTy, Src0Reg, Src0Ty, Src1Reg, Src1Ty] =
7203       MI.getFirst3RegLLTs();
7204   LLT IdxTy = LLT::scalar(32);
7205 
7206   ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
7207   Register Undef;
7208   SmallVector<Register, 32> BuildVec;
7209   LLT EltTy = DstTy.getScalarType();
7210 
7211   for (int Idx : Mask) {
7212     if (Idx < 0) {
7213       if (!Undef.isValid())
7214         Undef = MIRBuilder.buildUndef(EltTy).getReg(0);
7215       BuildVec.push_back(Undef);
7216       continue;
7217     }
7218 
7219     if (Src0Ty.isScalar()) {
7220       BuildVec.push_back(Idx == 0 ? Src0Reg : Src1Reg);
7221     } else {
7222       int NumElts = Src0Ty.getNumElements();
7223       Register SrcVec = Idx < NumElts ? Src0Reg : Src1Reg;
7224       int ExtractIdx = Idx < NumElts ? Idx : Idx - NumElts;
7225       auto IdxK = MIRBuilder.buildConstant(IdxTy, ExtractIdx);
7226       auto Extract = MIRBuilder.buildExtractVectorElement(EltTy, SrcVec, IdxK);
7227       BuildVec.push_back(Extract.getReg(0));
7228     }
7229   }
7230 
7231   if (DstTy.isScalar())
7232     MIRBuilder.buildCopy(DstReg, BuildVec[0]);
7233   else
7234     MIRBuilder.buildBuildVector(DstReg, BuildVec);
7235   MI.eraseFromParent();
7236   return Legalized;
7237 }
7238 
7239 Register LegalizerHelper::getDynStackAllocTargetPtr(Register SPReg,
7240                                                     Register AllocSize,
7241                                                     Align Alignment,
7242                                                     LLT PtrTy) {
7243   LLT IntPtrTy = LLT::scalar(PtrTy.getSizeInBits());
7244 
7245   auto SPTmp = MIRBuilder.buildCopy(PtrTy, SPReg);
7246   SPTmp = MIRBuilder.buildCast(IntPtrTy, SPTmp);
7247 
7248   // Subtract the final alloc from the SP. We use G_PTRTOINT here so we don't
7249   // have to generate an extra instruction to negate the alloc and then use
7250   // G_PTR_ADD to add the negative offset.
7251   auto Alloc = MIRBuilder.buildSub(IntPtrTy, SPTmp, AllocSize);
7252   if (Alignment > Align(1)) {
7253     APInt AlignMask(IntPtrTy.getSizeInBits(), Alignment.value(), true);
7254     AlignMask.negate();
7255     auto AlignCst = MIRBuilder.buildConstant(IntPtrTy, AlignMask);
7256     Alloc = MIRBuilder.buildAnd(IntPtrTy, Alloc, AlignCst);
7257   }
7258 
7259   return MIRBuilder.buildCast(PtrTy, Alloc).getReg(0);
7260 }
7261 
7262 LegalizerHelper::LegalizeResult
7263 LegalizerHelper::lowerDynStackAlloc(MachineInstr &MI) {
7264   const auto &MF = *MI.getMF();
7265   const auto &TFI = *MF.getSubtarget().getFrameLowering();
7266   if (TFI.getStackGrowthDirection() == TargetFrameLowering::StackGrowsUp)
7267     return UnableToLegalize;
7268 
7269   Register Dst = MI.getOperand(0).getReg();
7270   Register AllocSize = MI.getOperand(1).getReg();
7271   Align Alignment = assumeAligned(MI.getOperand(2).getImm());
7272 
7273   LLT PtrTy = MRI.getType(Dst);
7274   Register SPReg = TLI.getStackPointerRegisterToSaveRestore();
7275   Register SPTmp =
7276       getDynStackAllocTargetPtr(SPReg, AllocSize, Alignment, PtrTy);
7277 
7278   MIRBuilder.buildCopy(SPReg, SPTmp);
7279   MIRBuilder.buildCopy(Dst, SPTmp);
7280 
7281   MI.eraseFromParent();
7282   return Legalized;
7283 }
7284 
7285 LegalizerHelper::LegalizeResult
7286 LegalizerHelper::lowerStackSave(MachineInstr &MI) {
7287   Register StackPtr = TLI.getStackPointerRegisterToSaveRestore();
7288   if (!StackPtr)
7289     return UnableToLegalize;
7290 
7291   MIRBuilder.buildCopy(MI.getOperand(0), StackPtr);
7292   MI.eraseFromParent();
7293   return Legalized;
7294 }
7295 
7296 LegalizerHelper::LegalizeResult
7297 LegalizerHelper::lowerStackRestore(MachineInstr &MI) {
7298   Register StackPtr = TLI.getStackPointerRegisterToSaveRestore();
7299   if (!StackPtr)
7300     return UnableToLegalize;
7301 
7302   MIRBuilder.buildCopy(StackPtr, MI.getOperand(0));
7303   MI.eraseFromParent();
7304   return Legalized;
7305 }
7306 
7307 LegalizerHelper::LegalizeResult
7308 LegalizerHelper::lowerExtract(MachineInstr &MI) {
7309   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
7310   unsigned Offset = MI.getOperand(2).getImm();
7311 
7312   // Extract sub-vector or one element
7313   if (SrcTy.isVector()) {
7314     unsigned SrcEltSize = SrcTy.getElementType().getSizeInBits();
7315     unsigned DstSize = DstTy.getSizeInBits();
7316 
7317     if ((Offset % SrcEltSize == 0) && (DstSize % SrcEltSize == 0) &&
7318         (Offset + DstSize <= SrcTy.getSizeInBits())) {
7319       // Unmerge and allow access to each Src element for the artifact combiner.
7320       auto Unmerge = MIRBuilder.buildUnmerge(SrcTy.getElementType(), SrcReg);
7321 
7322       // Copy out the element(s) we need, merging them if there is more than one.
7323       SmallVector<Register, 8> SubVectorElts;
7324       for (unsigned Idx = Offset / SrcEltSize;
7325            Idx < (Offset + DstSize) / SrcEltSize; ++Idx) {
7326         SubVectorElts.push_back(Unmerge.getReg(Idx));
7327       }
7328       if (SubVectorElts.size() == 1)
7329         MIRBuilder.buildCopy(DstReg, SubVectorElts[0]);
7330       else
7331         MIRBuilder.buildMergeLikeInstr(DstReg, SubVectorElts);
7332 
7333       MI.eraseFromParent();
7334       return Legalized;
7335     }
7336   }
7337 
7338   if (DstTy.isScalar() &&
7339       (SrcTy.isScalar() ||
7340        (SrcTy.isVector() && DstTy == SrcTy.getElementType()))) {
7341     LLT SrcIntTy = SrcTy;
7342     if (!SrcTy.isScalar()) {
7343       SrcIntTy = LLT::scalar(SrcTy.getSizeInBits());
7344       SrcReg = MIRBuilder.buildBitcast(SrcIntTy, SrcReg).getReg(0);
7345     }
7346 
7347     if (Offset == 0)
7348       MIRBuilder.buildTrunc(DstReg, SrcReg);
7349     else {
7350       auto ShiftAmt = MIRBuilder.buildConstant(SrcIntTy, Offset);
7351       auto Shr = MIRBuilder.buildLShr(SrcIntTy, SrcReg, ShiftAmt);
7352       MIRBuilder.buildTrunc(DstReg, Shr);
7353     }
7354 
7355     MI.eraseFromParent();
7356     return Legalized;
7357   }
7358 
7359   return UnableToLegalize;
7360 }
7361 
7362 LegalizerHelper::LegalizeResult LegalizerHelper::lowerInsert(MachineInstr &MI) {
7363   auto [Dst, Src, InsertSrc] = MI.getFirst3Regs();
7364   uint64_t Offset = MI.getOperand(3).getImm();
7365 
7366   LLT DstTy = MRI.getType(Src);
7367   LLT InsertTy = MRI.getType(InsertSrc);
7368 
7369   // Insert sub-vector or one element
7370   if (DstTy.isVector() && !InsertTy.isPointer()) {
7371     LLT EltTy = DstTy.getElementType();
7372     unsigned EltSize = EltTy.getSizeInBits();
7373     unsigned InsertSize = InsertTy.getSizeInBits();
7374 
7375     if ((Offset % EltSize == 0) && (InsertSize % EltSize == 0) &&
7376         (Offset + InsertSize <= DstTy.getSizeInBits())) {
7377       auto UnmergeSrc = MIRBuilder.buildUnmerge(EltTy, Src);
7378       SmallVector<Register, 8> DstElts;
7379       unsigned Idx = 0;
7380       // Elements from Src before the insertion offset.
7381       for (; Idx < Offset / EltSize; ++Idx) {
7382         DstElts.push_back(UnmergeSrc.getReg(Idx));
7383       }
7384 
7385       // Replace elements in Src with elements from InsertSrc
7386       if (InsertTy.getSizeInBits() > EltSize) {
7387         auto UnmergeInsertSrc = MIRBuilder.buildUnmerge(EltTy, InsertSrc);
7388         for (unsigned i = 0; Idx < (Offset + InsertSize) / EltSize;
7389              ++Idx, ++i) {
7390           DstElts.push_back(UnmergeInsertSrc.getReg(i));
7391         }
7392       } else {
7393         DstElts.push_back(InsertSrc);
7394         ++Idx;
7395       }
7396 
7397       // Remaining elements from Src after insert
7398       for (; Idx < DstTy.getNumElements(); ++Idx) {
7399         DstElts.push_back(UnmergeSrc.getReg(Idx));
7400       }
7401 
7402       MIRBuilder.buildMergeLikeInstr(Dst, DstElts);
7403       MI.eraseFromParent();
7404       return Legalized;
7405     }
7406   }
7407 
7408   if (InsertTy.isVector() ||
7409       (DstTy.isVector() && DstTy.getElementType() != InsertTy))
7410     return UnableToLegalize;
7411 
7412   const DataLayout &DL = MIRBuilder.getDataLayout();
7413   if ((DstTy.isPointer() &&
7414        DL.isNonIntegralAddressSpace(DstTy.getAddressSpace())) ||
7415       (InsertTy.isPointer() &&
7416        DL.isNonIntegralAddressSpace(InsertTy.getAddressSpace()))) {
7417     LLVM_DEBUG(dbgs() << "Not casting non-integral address space integer\n");
7418     return UnableToLegalize;
7419   }
7420 
7421   LLT IntDstTy = DstTy;
7422 
7423   if (!DstTy.isScalar()) {
7424     IntDstTy = LLT::scalar(DstTy.getSizeInBits());
7425     Src = MIRBuilder.buildCast(IntDstTy, Src).getReg(0);
7426   }
7427 
7428   if (!InsertTy.isScalar()) {
7429     const LLT IntInsertTy = LLT::scalar(InsertTy.getSizeInBits());
7430     InsertSrc = MIRBuilder.buildPtrToInt(IntInsertTy, InsertSrc).getReg(0);
7431   }
7432 
7433   Register ExtInsSrc = MIRBuilder.buildZExt(IntDstTy, InsertSrc).getReg(0);
7434   if (Offset != 0) {
7435     auto ShiftAmt = MIRBuilder.buildConstant(IntDstTy, Offset);
7436     ExtInsSrc = MIRBuilder.buildShl(IntDstTy, ExtInsSrc, ShiftAmt).getReg(0);
7437   }
7438 
7439   APInt MaskVal = APInt::getBitsSetWithWrap(
7440       DstTy.getSizeInBits(), Offset + InsertTy.getSizeInBits(), Offset);
7441 
7442   auto Mask = MIRBuilder.buildConstant(IntDstTy, MaskVal);
7443   auto MaskedSrc = MIRBuilder.buildAnd(IntDstTy, Src, Mask);
7444   auto Or = MIRBuilder.buildOr(IntDstTy, MaskedSrc, ExtInsSrc);
7445 
7446   MIRBuilder.buildCast(Dst, Or);
7447   MI.eraseFromParent();
7448   return Legalized;
7449 }
7450 
7451 LegalizerHelper::LegalizeResult
7452 LegalizerHelper::lowerSADDO_SSUBO(MachineInstr &MI) {
7453   auto [Dst0, Dst0Ty, Dst1, Dst1Ty, LHS, LHSTy, RHS, RHSTy] =
7454       MI.getFirst4RegLLTs();
7455   const bool IsAdd = MI.getOpcode() == TargetOpcode::G_SADDO;
7456 
7457   LLT Ty = Dst0Ty;
7458   LLT BoolTy = Dst1Ty;
7459 
7460   if (IsAdd)
7461     MIRBuilder.buildAdd(Dst0, LHS, RHS);
7462   else
7463     MIRBuilder.buildSub(Dst0, LHS, RHS);
7464 
7465   // TODO: If SADDSAT/SSUBSAT is legal, compare results to detect overflow.
7466 
7467   auto Zero = MIRBuilder.buildConstant(Ty, 0);
7468 
7469   // For an addition, the result should be less than one of the operands (LHS)
7470   // if and only if the other operand (RHS) is negative, otherwise there will
7471   // be overflow.
7472   // For a subtraction, the result should be less than one of the operands
7473   // (LHS) if and only if the other operand (RHS) is (non-zero) positive,
7474   // otherwise there will be overflow.
7475   auto ResultLowerThanLHS =
7476       MIRBuilder.buildICmp(CmpInst::ICMP_SLT, BoolTy, Dst0, LHS);
7477   auto ConditionRHS = MIRBuilder.buildICmp(
7478       IsAdd ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGT, BoolTy, RHS, Zero);
7479 
7480   MIRBuilder.buildXor(Dst1, ConditionRHS, ResultLowerThanLHS);
7481   MI.eraseFromParent();
7482   return Legalized;
7483 }
7484 
7485 LegalizerHelper::LegalizeResult
7486 LegalizerHelper::lowerAddSubSatToMinMax(MachineInstr &MI) {
7487   auto [Res, LHS, RHS] = MI.getFirst3Regs();
7488   LLT Ty = MRI.getType(Res);
7489   bool IsSigned;
7490   bool IsAdd;
7491   unsigned BaseOp;
7492   switch (MI.getOpcode()) {
7493   default:
7494     llvm_unreachable("unexpected addsat/subsat opcode");
7495   case TargetOpcode::G_UADDSAT:
7496     IsSigned = false;
7497     IsAdd = true;
7498     BaseOp = TargetOpcode::G_ADD;
7499     break;
7500   case TargetOpcode::G_SADDSAT:
7501     IsSigned = true;
7502     IsAdd = true;
7503     BaseOp = TargetOpcode::G_ADD;
7504     break;
7505   case TargetOpcode::G_USUBSAT:
7506     IsSigned = false;
7507     IsAdd = false;
7508     BaseOp = TargetOpcode::G_SUB;
7509     break;
7510   case TargetOpcode::G_SSUBSAT:
7511     IsSigned = true;
7512     IsAdd = false;
7513     BaseOp = TargetOpcode::G_SUB;
7514     break;
7515   }
7516 
7517   if (IsSigned) {
7518     // sadd.sat(a, b) ->
7519     //   hi = 0x7fffffff - smax(a, 0)
7520     //   lo = 0x80000000 - smin(a, 0)
7521     //   a + smin(smax(lo, b), hi)
7522     // ssub.sat(a, b) ->
7523     //   lo = smax(a, -1) - 0x7fffffff
7524     //   hi = smin(a, -1) - 0x80000000
7525     //   a - smin(smax(lo, b), hi)
7526     // TODO: AMDGPU can use a "median of 3" instruction here:
7527     //   a +/- med3(lo, b, hi)
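    //
    // A hand-checked i8 sample (illustrative): sadd.sat(100, 50) computes
    // hi = 127 - smax(100, 0) = 27 and lo = -128 - smin(100, 0) = -128, so
    // the clamped RHS is smin(smax(-128, 50), 27) = 27 and the result is
    // 100 + 27 = 127, the saturated maximum.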
7528     uint64_t NumBits = Ty.getScalarSizeInBits();
7529     auto MaxVal =
7530         MIRBuilder.buildConstant(Ty, APInt::getSignedMaxValue(NumBits));
7531     auto MinVal =
7532         MIRBuilder.buildConstant(Ty, APInt::getSignedMinValue(NumBits));
7533     MachineInstrBuilder Hi, Lo;
7534     if (IsAdd) {
7535       auto Zero = MIRBuilder.buildConstant(Ty, 0);
7536       Hi = MIRBuilder.buildSub(Ty, MaxVal, MIRBuilder.buildSMax(Ty, LHS, Zero));
7537       Lo = MIRBuilder.buildSub(Ty, MinVal, MIRBuilder.buildSMin(Ty, LHS, Zero));
7538     } else {
7539       auto NegOne = MIRBuilder.buildConstant(Ty, -1);
7540       Lo = MIRBuilder.buildSub(Ty, MIRBuilder.buildSMax(Ty, LHS, NegOne),
7541                                MaxVal);
7542       Hi = MIRBuilder.buildSub(Ty, MIRBuilder.buildSMin(Ty, LHS, NegOne),
7543                                MinVal);
7544     }
7545     auto RHSClamped =
7546         MIRBuilder.buildSMin(Ty, MIRBuilder.buildSMax(Ty, Lo, RHS), Hi);
7547     MIRBuilder.buildInstr(BaseOp, {Res}, {LHS, RHSClamped});
7548   } else {
7549     // uadd.sat(a, b) -> a + umin(~a, b)
7550     // usub.sat(a, b) -> a - umin(a, b)
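    //
    // Illustrative i8 sample: uadd.sat(200, 100) computes ~200 = 55 and
    // umin(55, 100) = 55, so the result saturates at 200 + 55 = 255;
    // usub.sat(100, 200) computes umin(100, 200) = 100, giving 0.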
7551     Register Not = IsAdd ? MIRBuilder.buildNot(Ty, LHS).getReg(0) : LHS;
7552     auto Min = MIRBuilder.buildUMin(Ty, Not, RHS);
7553     MIRBuilder.buildInstr(BaseOp, {Res}, {LHS, Min});
7554   }
7555 
7556   MI.eraseFromParent();
7557   return Legalized;
7558 }
7559 
7560 LegalizerHelper::LegalizeResult
7561 LegalizerHelper::lowerAddSubSatToAddoSubo(MachineInstr &MI) {
7562   auto [Res, LHS, RHS] = MI.getFirst3Regs();
7563   LLT Ty = MRI.getType(Res);
7564   LLT BoolTy = Ty.changeElementSize(1);
7565   bool IsSigned;
7566   bool IsAdd;
7567   unsigned OverflowOp;
7568   switch (MI.getOpcode()) {
7569   default:
7570     llvm_unreachable("unexpected addsat/subsat opcode");
7571   case TargetOpcode::G_UADDSAT:
7572     IsSigned = false;
7573     IsAdd = true;
7574     OverflowOp = TargetOpcode::G_UADDO;
7575     break;
7576   case TargetOpcode::G_SADDSAT:
7577     IsSigned = true;
7578     IsAdd = true;
7579     OverflowOp = TargetOpcode::G_SADDO;
7580     break;
7581   case TargetOpcode::G_USUBSAT:
7582     IsSigned = false;
7583     IsAdd = false;
7584     OverflowOp = TargetOpcode::G_USUBO;
7585     break;
7586   case TargetOpcode::G_SSUBSAT:
7587     IsSigned = true;
7588     IsAdd = false;
7589     OverflowOp = TargetOpcode::G_SSUBO;
7590     break;
7591   }
7592 
7593   auto OverflowRes =
7594       MIRBuilder.buildInstr(OverflowOp, {Ty, BoolTy}, {LHS, RHS});
7595   Register Tmp = OverflowRes.getReg(0);
7596   Register Ov = OverflowRes.getReg(1);
7597   MachineInstrBuilder Clamp;
7598   if (IsSigned) {
7599     // sadd.sat(a, b) ->
7600     //   {tmp, ov} = saddo(a, b)
7601     //   ov ? (tmp >>s 31) + 0x80000000 : tmp
7602     // ssub.sat(a, b) ->
7603     //   {tmp, ov} = ssubo(a, b)
7604     //   ov ? (tmp >>s 31) + 0x80000000 : tmp
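    //
    // Illustrative i8 sample: saddo(100, 50) wraps to tmp = -106 with ov set;
    // the clamp is (tmp >>s 7) + (-128) = -1 + -128, which wraps to 127, the
    // saturated value selected when ov is true.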
7605     uint64_t NumBits = Ty.getScalarSizeInBits();
7606     auto ShiftAmount = MIRBuilder.buildConstant(Ty, NumBits - 1);
7607     auto Sign = MIRBuilder.buildAShr(Ty, Tmp, ShiftAmount);
7608     auto MinVal =
7609         MIRBuilder.buildConstant(Ty, APInt::getSignedMinValue(NumBits));
7610     Clamp = MIRBuilder.buildAdd(Ty, Sign, MinVal);
7611   } else {
7612     // uadd.sat(a, b) ->
7613     //   {tmp, ov} = uaddo(a, b)
7614     //   ov ? 0xffffffff : tmp
7615     // usub.sat(a, b) ->
7616     //   {tmp, ov} = usubo(a, b)
7617     //   ov ? 0 : tmp
7618     Clamp = MIRBuilder.buildConstant(Ty, IsAdd ? -1 : 0);
7619   }
7620   MIRBuilder.buildSelect(Res, Ov, Clamp, Tmp);
7621 
7622   MI.eraseFromParent();
7623   return Legalized;
7624 }
7625 
7626 LegalizerHelper::LegalizeResult
7627 LegalizerHelper::lowerShlSat(MachineInstr &MI) {
7628   assert((MI.getOpcode() == TargetOpcode::G_SSHLSAT ||
7629           MI.getOpcode() == TargetOpcode::G_USHLSAT) &&
7630          "Expected shlsat opcode!");
7631   bool IsSigned = MI.getOpcode() == TargetOpcode::G_SSHLSAT;
7632   auto [Res, LHS, RHS] = MI.getFirst3Regs();
7633   LLT Ty = MRI.getType(Res);
7634   LLT BoolTy = Ty.changeElementSize(1);
7635 
7636   unsigned BW = Ty.getScalarSizeInBits();
7637   auto Result = MIRBuilder.buildShl(Ty, LHS, RHS);
7638   auto Orig = IsSigned ? MIRBuilder.buildAShr(Ty, Result, RHS)
7639                        : MIRBuilder.buildLShr(Ty, Result, RHS);
7640 
7641   MachineInstrBuilder SatVal;
7642   if (IsSigned) {
7643     auto SatMin = MIRBuilder.buildConstant(Ty, APInt::getSignedMinValue(BW));
7644     auto SatMax = MIRBuilder.buildConstant(Ty, APInt::getSignedMaxValue(BW));
7645     auto Cmp = MIRBuilder.buildICmp(CmpInst::ICMP_SLT, BoolTy, LHS,
7646                                     MIRBuilder.buildConstant(Ty, 0));
7647     SatVal = MIRBuilder.buildSelect(Ty, Cmp, SatMin, SatMax);
7648   } else {
7649     SatVal = MIRBuilder.buildConstant(Ty, APInt::getMaxValue(BW));
7650   }
7651   auto Ov = MIRBuilder.buildICmp(CmpInst::ICMP_NE, BoolTy, LHS, Orig);
7652   MIRBuilder.buildSelect(Res, Ov, SatVal, Result);
7653 
7654   MI.eraseFromParent();
7655   return Legalized;
7656 }
7657 
7658 LegalizerHelper::LegalizeResult LegalizerHelper::lowerBswap(MachineInstr &MI) {
7659   auto [Dst, Src] = MI.getFirst2Regs();
7660   const LLT Ty = MRI.getType(Src);
7661   unsigned SizeInBytes = (Ty.getScalarSizeInBits() + 7) / 8;
7662   unsigned BaseShiftAmt = (SizeInBytes - 1) * 8;
7663 
7664   // Swap most and least significant byte, set remaining bytes in Res to zero.
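  // Illustrative 32-bit sample: for Src = 0xAABBCCDD the two shifts by 24
  // give 0xDD000000 and 0x000000AA, so Res starts as 0xDD0000AA; the single
  // loop iteration below (i = 1) then moves 0xCC and 0xBB into place,
  // yielding 0xDDCCBBAA.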
7665   auto ShiftAmt = MIRBuilder.buildConstant(Ty, BaseShiftAmt);
7666   auto LSByteShiftedLeft = MIRBuilder.buildShl(Ty, Src, ShiftAmt);
7667   auto MSByteShiftedRight = MIRBuilder.buildLShr(Ty, Src, ShiftAmt);
7668   auto Res = MIRBuilder.buildOr(Ty, MSByteShiftedRight, LSByteShiftedLeft);
7669 
7670   // Set i-th high/low byte in Res to i-th low/high byte from Src.
7671   for (unsigned i = 1; i < SizeInBytes / 2; ++i) {
7672     // AND with Mask leaves byte i unchanged and sets remaining bytes to 0.
7673     APInt APMask(SizeInBytes * 8, 0xFF << (i * 8));
7674     auto Mask = MIRBuilder.buildConstant(Ty, APMask);
7675     auto ShiftAmt = MIRBuilder.buildConstant(Ty, BaseShiftAmt - 16 * i);
7676     // Low byte shifted left to place of high byte: (Src & Mask) << ShiftAmt.
7677     auto LoByte = MIRBuilder.buildAnd(Ty, Src, Mask);
7678     auto LoShiftedLeft = MIRBuilder.buildShl(Ty, LoByte, ShiftAmt);
7679     Res = MIRBuilder.buildOr(Ty, Res, LoShiftedLeft);
7680     // High byte shifted right to place of low byte: (Src >> ShiftAmt) & Mask.
7681     auto SrcShiftedRight = MIRBuilder.buildLShr(Ty, Src, ShiftAmt);
7682     auto HiShiftedRight = MIRBuilder.buildAnd(Ty, SrcShiftedRight, Mask);
7683     Res = MIRBuilder.buildOr(Ty, Res, HiShiftedRight);
7684   }
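  // Rewrite the def of the final OR so that it writes directly to Dst rather
  // than the temporary register it was built with.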
7685   Res.getInstr()->getOperand(0).setReg(Dst);
7686 
7687   MI.eraseFromParent();
7688   return Legalized;
7689 }
7690 
7691 // { (Src & Mask) >> N } | { (Src << N) & Mask }
7692 static MachineInstrBuilder SwapN(unsigned N, DstOp Dst, MachineIRBuilder &B,
7693                                  MachineInstrBuilder Src, APInt Mask) {
7694   const LLT Ty = Dst.getLLTTy(*B.getMRI());
7695   MachineInstrBuilder C_N = B.buildConstant(Ty, N);
7696   MachineInstrBuilder MaskLoNTo0 = B.buildConstant(Ty, Mask);
7697   auto LHS = B.buildLShr(Ty, B.buildAnd(Ty, Src, MaskLoNTo0), C_N);
7698   auto RHS = B.buildAnd(Ty, B.buildShl(Ty, Src, C_N), MaskLoNTo0);
7699   return B.buildOr(Dst, LHS, RHS);
7700 }
7701 
7702 LegalizerHelper::LegalizeResult
7703 LegalizerHelper::lowerBitreverse(MachineInstr &MI) {
7704   auto [Dst, Src] = MI.getFirst2Regs();
7705   const LLT Ty = MRI.getType(Src);
7706   unsigned Size = Ty.getSizeInBits();
7707 
7708   MachineInstrBuilder BSWAP =
7709       MIRBuilder.buildInstr(TargetOpcode::G_BSWAP, {Ty}, {Src});
7710 
7711   // swap high and low 4 bits in 8 bit blocks 7654|3210 -> 3210|7654
7712   //    [(val & 0xF0F0F0F0) >> 4] | [(val & 0x0F0F0F0F) << 4]
7713   // -> [(val & 0xF0F0F0F0) >> 4] | [(val << 4) & 0xF0F0F0F0]
7714   MachineInstrBuilder Swap4 =
7715       SwapN(4, Ty, MIRBuilder, BSWAP, APInt::getSplat(Size, APInt(8, 0xF0)));
7716 
7717   // swap high and low 2 bits in 4 bit blocks 32|10 76|54 -> 10|32 54|76
7718   //    [(val & 0xCCCCCCCC) >> 2] & [(val & 0x33333333) << 2]
7719   // -> [(val & 0xCCCCCCCC) >> 2] & [(val << 2) & 0xCCCCCCCC]
7720   MachineInstrBuilder Swap2 =
7721       SwapN(2, Ty, MIRBuilder, Swap4, APInt::getSplat(Size, APInt(8, 0xCC)));
7722 
7723   // swap high and low 1 bit in 2 bit blocks 1|0 3|2 5|4 7|6 -> 0|1 2|3 4|5 6|7
7724   //    [(val & 0xAAAAAAAA) >> 1] & [(val & 0x55555555) << 1]
7725   // -> [(val & 0xAAAAAAAA) >> 1] & [(val << 1) & 0xAAAAAAAA]
7726   SwapN(1, Dst, MIRBuilder, Swap2, APInt::getSplat(Size, APInt(8, 0xAA)));
7727 
7728   MI.eraseFromParent();
7729   return Legalized;
7730 }
7731 
7732 LegalizerHelper::LegalizeResult
7733 LegalizerHelper::lowerReadWriteRegister(MachineInstr &MI) {
7734   MachineFunction &MF = MIRBuilder.getMF();
7735 
7736   bool IsRead = MI.getOpcode() == TargetOpcode::G_READ_REGISTER;
7737   int NameOpIdx = IsRead ? 1 : 0;
7738   int ValRegIndex = IsRead ? 0 : 1;
7739 
7740   Register ValReg = MI.getOperand(ValRegIndex).getReg();
7741   const LLT Ty = MRI.getType(ValReg);
7742   const MDString *RegStr = cast<MDString>(
7743     cast<MDNode>(MI.getOperand(NameOpIdx).getMetadata())->getOperand(0));
7744 
7745   Register PhysReg = TLI.getRegisterByName(RegStr->getString().data(), Ty, MF);
7746   if (!PhysReg.isValid())
7747     return UnableToLegalize;
7748 
7749   if (IsRead)
7750     MIRBuilder.buildCopy(ValReg, PhysReg);
7751   else
7752     MIRBuilder.buildCopy(PhysReg, ValReg);
7753 
7754   MI.eraseFromParent();
7755   return Legalized;
7756 }
7757 
7758 LegalizerHelper::LegalizeResult
7759 LegalizerHelper::lowerSMULH_UMULH(MachineInstr &MI) {
7760   bool IsSigned = MI.getOpcode() == TargetOpcode::G_SMULH;
7761   unsigned ExtOp = IsSigned ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
7762   Register Result = MI.getOperand(0).getReg();
7763   LLT OrigTy = MRI.getType(Result);
7764   auto SizeInBits = OrigTy.getScalarSizeInBits();
7765   LLT WideTy = OrigTy.changeElementSize(SizeInBits * 2);
7766 
7767   auto LHS = MIRBuilder.buildInstr(ExtOp, {WideTy}, {MI.getOperand(1)});
7768   auto RHS = MIRBuilder.buildInstr(ExtOp, {WideTy}, {MI.getOperand(2)});
7769   auto Mul = MIRBuilder.buildMul(WideTy, LHS, RHS);
7770   unsigned ShiftOp = IsSigned ? TargetOpcode::G_ASHR : TargetOpcode::G_LSHR;
7771 
7772   auto ShiftAmt = MIRBuilder.buildConstant(WideTy, SizeInBits);
7773   auto Shifted = MIRBuilder.buildInstr(ShiftOp, {WideTy}, {Mul, ShiftAmt});
7774   MIRBuilder.buildTrunc(Result, Shifted);
7775 
7776   MI.eraseFromParent();
7777   return Legalized;
7778 }
7779 
7780 LegalizerHelper::LegalizeResult
7781 LegalizerHelper::lowerISFPCLASS(MachineInstr &MI) {
7782   auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
7783   FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(2).getImm());
7784 
7785   if (Mask == fcNone) {
7786     MIRBuilder.buildConstant(DstReg, 0);
7787     MI.eraseFromParent();
7788     return Legalized;
7789   }
7790   if (Mask == fcAllFlags) {
7791     MIRBuilder.buildConstant(DstReg, 1);
7792     MI.eraseFromParent();
7793     return Legalized;
7794   }
7795 
7796   // TODO: Try inverting the test with getInvertedFPClassTest like the DAG
7797   // version
7798 
7799   unsigned BitSize = SrcTy.getScalarSizeInBits();
7800   const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());
7801 
7802   LLT IntTy = LLT::scalar(BitSize);
7803   if (SrcTy.isVector())
7804     IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
7805   auto AsInt = MIRBuilder.buildCopy(IntTy, SrcReg);
7806 
7807   // Various masks.
7808   APInt SignBit = APInt::getSignMask(BitSize);
7809   APInt ValueMask = APInt::getSignedMaxValue(BitSize);     // All bits but sign.
7810   APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
7811   APInt ExpMask = Inf;
7812   APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
7813   APInt QNaNBitMask =
7814       APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
7815   APInt InversionMask = APInt::getAllOnes(DstTy.getScalarSizeInBits());
7816 
7817   auto SignBitC = MIRBuilder.buildConstant(IntTy, SignBit);
7818   auto ValueMaskC = MIRBuilder.buildConstant(IntTy, ValueMask);
7819   auto InfC = MIRBuilder.buildConstant(IntTy, Inf);
7820   auto ExpMaskC = MIRBuilder.buildConstant(IntTy, ExpMask);
7821   auto ZeroC = MIRBuilder.buildConstant(IntTy, 0);
7822 
7823   auto Abs = MIRBuilder.buildAnd(IntTy, AsInt, ValueMaskC);
7824   auto Sign =
7825       MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs);
7826 
7827   auto Res = MIRBuilder.buildConstant(DstTy, 0);
7828   // Clang doesn't support capture of structured bindings:
7829   LLT DstTyCopy = DstTy;
7830   const auto appendToRes = [&](MachineInstrBuilder ToAppend) {
7831     Res = MIRBuilder.buildOr(DstTyCopy, Res, ToAppend);
7832   };
7833 
7834   // Tests that involve more than one class should be processed first.
7835   if ((Mask & fcFinite) == fcFinite) {
7836     // finite(V) ==> abs(V) u< exp_mask
7837     appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
7838                                      ExpMaskC));
7839     Mask &= ~fcFinite;
7840   } else if ((Mask & fcFinite) == fcPosFinite) {
7841     // finite(V) && V > 0 ==> V u< exp_mask
7842     appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, AsInt,
7843                                      ExpMaskC));
7844     Mask &= ~fcPosFinite;
7845   } else if ((Mask & fcFinite) == fcNegFinite) {
7846     // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
7847     auto Cmp = MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
7848                                     ExpMaskC);
7849     auto And = MIRBuilder.buildAnd(DstTy, Cmp, Sign);
7850     appendToRes(And);
7851     Mask &= ~fcNegFinite;
7852   }
7853 
7854   if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) {
7855     // fcZero | fcSubnormal => test all exponent bits are 0
7856     // TODO: Handle sign bit specific cases
7857     // TODO: Handle inverted case
7858     if (PartialCheck == (fcZero | fcSubnormal)) {
7859       auto ExpBits = MIRBuilder.buildAnd(IntTy, AsInt, ExpMaskC);
7860       appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
7861                                        ExpBits, ZeroC));
7862       Mask &= ~PartialCheck;
7863     }
7864   }
7865 
7866   // Check for individual classes.
7867   if (FPClassTest PartialCheck = Mask & fcZero) {
7868     if (PartialCheck == fcPosZero)
7869       appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
7870                                        AsInt, ZeroC));
7871     else if (PartialCheck == fcZero)
7872       appendToRes(
7873           MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC));
7874     else // fcNegZero
7875       appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
7876                                        AsInt, SignBitC));
7877   }
7878 
7879   if (FPClassTest PartialCheck = Mask & fcSubnormal) {
7880     // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
7881     // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
7882     auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
7883     auto OneC = MIRBuilder.buildConstant(IntTy, 1);
7884     auto VMinusOne = MIRBuilder.buildSub(IntTy, V, OneC);
7885     auto SubnormalRes =
7886         MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne,
7887                              MIRBuilder.buildConstant(IntTy, AllOneMantissa));
7888     if (PartialCheck == fcNegSubnormal)
7889       SubnormalRes = MIRBuilder.buildAnd(DstTy, SubnormalRes, Sign);
7890     appendToRes(SubnormalRes);
7891   }
7892 
7893   if (FPClassTest PartialCheck = Mask & fcInf) {
7894     if (PartialCheck == fcPosInf)
7895       appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
7896                                        AsInt, InfC));
7897     else if (PartialCheck == fcInf)
7898       appendToRes(
7899           MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC));
7900     else { // fcNegInf
7901       APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
7902       auto NegInfC = MIRBuilder.buildConstant(IntTy, NegInf);
7903       appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
7904                                        AsInt, NegInfC));
7905     }
7906   }
7907 
7908   if (FPClassTest PartialCheck = Mask & fcNan) {
7909     auto InfWithQnanBitC = MIRBuilder.buildConstant(IntTy, Inf | QNaNBitMask);
7910     if (PartialCheck == fcNan) {
7911       // isnan(V) ==> abs(V) u> int(inf)
7912       appendToRes(
7913           MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
7914     } else if (PartialCheck == fcQNan) {
7915       // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
7916       appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGE, DstTy, Abs,
7917                                        InfWithQnanBitC));
7918     } else { // fcSNan
7919       // issignaling(V) ==> abs(V) u> unsigned(Inf) &&
7920       //                    abs(V) u< (unsigned(Inf) | quiet_bit)
7921       auto IsNan =
7922           MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC);
7923       auto IsNotQnan = MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy,
7924                                             Abs, InfWithQnanBitC);
7925       appendToRes(MIRBuilder.buildAnd(DstTy, IsNan, IsNotQnan));
7926     }
7927   }
7928 
7929   if (FPClassTest PartialCheck = Mask & fcNormal) {
7930     // isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u<
7931     // (max_exp-1))
7932     APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
7933     auto ExpMinusOne = MIRBuilder.buildSub(
7934         IntTy, Abs, MIRBuilder.buildConstant(IntTy, ExpLSB));
7935     APInt MaxExpMinusOne = ExpMask - ExpLSB;
7936     auto NormalRes =
7937         MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne,
7938                              MIRBuilder.buildConstant(IntTy, MaxExpMinusOne));
7939     if (PartialCheck == fcNegNormal)
7940       NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, Sign);
7941     else if (PartialCheck == fcPosNormal) {
7942       auto PosSign = MIRBuilder.buildXor(
7943           DstTy, Sign, MIRBuilder.buildConstant(DstTy, InversionMask));
7944       NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, PosSign);
7945     }
7946     appendToRes(NormalRes);
7947   }
7948 
7949   MIRBuilder.buildCopy(DstReg, Res);
7950   MI.eraseFromParent();
7951   return Legalized;
7952 }
7953 
7954 LegalizerHelper::LegalizeResult LegalizerHelper::lowerSelect(MachineInstr &MI) {
7955   // Implement vector G_SELECT in terms of XOR, AND, OR.
7956   auto [DstReg, DstTy, MaskReg, MaskTy, Op1Reg, Op1Ty, Op2Reg, Op2Ty] =
7957       MI.getFirst4RegLLTs();
7958   if (!DstTy.isVector())
7959     return UnableToLegalize;
7960 
7961   bool IsEltPtr = DstTy.getElementType().isPointer();
7962   if (IsEltPtr) {
7963     LLT ScalarPtrTy = LLT::scalar(DstTy.getScalarSizeInBits());
7964     LLT NewTy = DstTy.changeElementType(ScalarPtrTy);
7965     Op1Reg = MIRBuilder.buildPtrToInt(NewTy, Op1Reg).getReg(0);
7966     Op2Reg = MIRBuilder.buildPtrToInt(NewTy, Op2Reg).getReg(0);
7967     DstTy = NewTy;
7968   }
7969 
7970   if (MaskTy.isScalar()) {
7971     // Turn the scalar condition into a vector condition mask.
7972 
7973     Register MaskElt = MaskReg;
7974 
7975     // The condition was potentially zero extended before, but we want a sign
7976     // extended boolean.
7977     if (MaskTy != LLT::scalar(1))
7978       MaskElt = MIRBuilder.buildSExtInReg(MaskTy, MaskElt, 1).getReg(0);
7979 
7980     // Continue the sign extension (or truncate) to match the data type.
7981     MaskElt = MIRBuilder.buildSExtOrTrunc(DstTy.getElementType(),
7982                                           MaskElt).getReg(0);
7983 
7984     // Generate a vector splat idiom.
7985     auto ShufSplat = MIRBuilder.buildShuffleSplat(DstTy, MaskElt);
7986     MaskReg = ShufSplat.getReg(0);
7987     MaskTy = DstTy;
7988   }
7989 
7990   if (MaskTy.getSizeInBits() != DstTy.getSizeInBits()) {
7991     return UnableToLegalize;
7992   }
7993 
7994   auto NotMask = MIRBuilder.buildNot(MaskTy, MaskReg);
7995   auto NewOp1 = MIRBuilder.buildAnd(MaskTy, Op1Reg, MaskReg);
7996   auto NewOp2 = MIRBuilder.buildAnd(MaskTy, Op2Reg, NotMask);
7997   if (IsEltPtr) {
7998     auto Or = MIRBuilder.buildOr(DstTy, NewOp1, NewOp2);
7999     MIRBuilder.buildIntToPtr(DstReg, Or);
8000   } else {
8001     MIRBuilder.buildOr(DstReg, NewOp1, NewOp2);
8002   }
8003   MI.eraseFromParent();
8004   return Legalized;
8005 }
8006 
8007 LegalizerHelper::LegalizeResult LegalizerHelper::lowerDIVREM(MachineInstr &MI) {
8008   // Split DIVREM into individual instructions.
8009   unsigned Opcode = MI.getOpcode();
8010 
8011   MIRBuilder.buildInstr(
8012       Opcode == TargetOpcode::G_SDIVREM ? TargetOpcode::G_SDIV
8013                                         : TargetOpcode::G_UDIV,
8014       {MI.getOperand(0).getReg()}, {MI.getOperand(2), MI.getOperand(3)});
8015   MIRBuilder.buildInstr(
8016       Opcode == TargetOpcode::G_SDIVREM ? TargetOpcode::G_SREM
8017                                         : TargetOpcode::G_UREM,
8018       {MI.getOperand(1).getReg()}, {MI.getOperand(2), MI.getOperand(3)});
8019   MI.eraseFromParent();
8020   return Legalized;
8021 }
8022 
8023 LegalizerHelper::LegalizeResult
8024 LegalizerHelper::lowerAbsToAddXor(MachineInstr &MI) {
8025   // Expand %res = G_ABS %a into:
8026   // %v1 = G_ASHR %a, scalar_size-1
8027   // %v2 = G_ADD %a, %v1
8028   // %res = G_XOR %v2, %v1
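  // Illustrative i32 sample: for %a = -5, %v1 = -1 (all ones), %v2 = -6, and
  // -6 ^ -1 = 5; for non-negative %a, %v1 = 0 and the add/xor are no-ops.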
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  Register OpReg = MI.getOperand(1).getReg();
  auto ShiftAmt =
      MIRBuilder.buildConstant(DstTy, DstTy.getScalarSizeInBits() - 1);
  auto Shift = MIRBuilder.buildAShr(DstTy, OpReg, ShiftAmt);
  auto Add = MIRBuilder.buildAdd(DstTy, OpReg, Shift);
  MIRBuilder.buildXor(MI.getOperand(0).getReg(), Add, Shift);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerAbsToMaxNeg(MachineInstr &MI) {
  // Expand %res = G_ABS %a into:
  // %v1 = G_CONSTANT 0
  // %v2 = G_SUB %v1, %a
  // %res = G_SMAX %a, %v2
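  //
  // max(%a, 0 - %a) equals |%a| for every input except INT_MIN, which wraps
  // back to INT_MIN, matching the add/xor expansion above.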
  Register SrcReg = MI.getOperand(1).getReg();
  LLT Ty = MRI.getType(SrcReg);
  auto Zero = MIRBuilder.buildConstant(Ty, 0).getReg(0);
  auto Sub = MIRBuilder.buildSub(Ty, Zero, SrcReg).getReg(0);
  MIRBuilder.buildSMax(MI.getOperand(0), SrcReg, Sub);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerVectorReduction(MachineInstr &MI) {
  Register SrcReg = MI.getOperand(1).getReg();
  LLT SrcTy = MRI.getType(SrcReg);
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  // The source could be a scalar if the IR type was <1 x sN>.
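  // For example, %res:_(s32) = G_VECREDUCE_ADD %v:_(s32), produced from a
  // <1 x i32> reduction in the IR, is just a copy of the lone element.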
  if (SrcTy.isScalar()) {
    if (DstTy.getSizeInBits() > SrcTy.getSizeInBits())
      return UnableToLegalize; // FIXME: handle extension.
    // This can be just a plain copy.
    Observer.changingInstr(MI);
    MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::COPY));
    Observer.changedInstr(MI);
    return Legalized;
  }
  return UnableToLegalize;
}

static Type *getTypeForLLT(LLT Ty, LLVMContext &C);

LegalizerHelper::LegalizeResult LegalizerHelper::lowerVAArg(MachineInstr &MI) {
  MachineFunction &MF = *MI.getMF();
  const DataLayout &DL = MIRBuilder.getDataLayout();
  LLVMContext &Ctx = MF.getFunction().getContext();
  Register ListPtr = MI.getOperand(1).getReg();
  LLT PtrTy = MRI.getType(ListPtr);

  // ListPtr is a pointer to the head of the list. Load the address of the
  // head of the list.
  Align PtrAlignment = DL.getABITypeAlign(getTypeForLLT(PtrTy, Ctx));
  MachineMemOperand *PtrLoadMMO = MF.getMachineMemOperand(
      MachinePointerInfo(), MachineMemOperand::MOLoad, PtrTy, PtrAlignment);
  auto VAList = MIRBuilder.buildLoad(PtrTy, ListPtr, *PtrLoadMMO).getReg(0);

  const Align A(MI.getOperand(2).getImm());
  LLT PtrTyAsScalarTy = LLT::scalar(PtrTy.getSizeInBits());
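  // If this argument requires more alignment than the minimum stack argument
  // alignment, realign the list pointer first; e.g. with A = 16 this computes
  //   VAList = (VAList + 15) & ~15.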
  if (A > TLI.getMinStackArgumentAlignment()) {
    Register AlignAmt =
        MIRBuilder.buildConstant(PtrTyAsScalarTy, A.value() - 1).getReg(0);
    auto AddDst = MIRBuilder.buildPtrAdd(PtrTy, VAList, AlignAmt);
    auto AndDst = MIRBuilder.buildMaskLowPtrBits(PtrTy, AddDst, Log2(A));
    VAList = AndDst.getReg(0);
  }

  // Increment the pointer, VAList, to the next vaarg.
  // The list should be bumped by the allocation size of the element currently
  // at the head of the list.
  Register Dst = MI.getOperand(0).getReg();
  LLT LLTTy = MRI.getType(Dst);
  Type *Ty = getTypeForLLT(LLTTy, Ctx);
  auto IncAmt =
      MIRBuilder.buildConstant(PtrTyAsScalarTy, DL.getTypeAllocSize(Ty));
  auto Succ = MIRBuilder.buildPtrAdd(PtrTy, VAList, IncAmt);

  // Store the incremented VAList back through the legalized pointer.
  MachineMemOperand *StoreMMO = MF.getMachineMemOperand(
      MachinePointerInfo(), MachineMemOperand::MOStore, PtrTy, PtrAlignment);
  MIRBuilder.buildStore(Succ, ListPtr, *StoreMMO);
  // Load the actual argument out of the pointer VAList.
  Align EltAlignment = DL.getABITypeAlign(Ty);
  MachineMemOperand *EltLoadMMO = MF.getMachineMemOperand(
      MachinePointerInfo(), MachineMemOperand::MOLoad, LLTTy, EltAlignment);
  MIRBuilder.buildLoad(Dst, VAList, *EltLoadMMO);

  MI.eraseFromParent();
  return Legalized;
}

static bool shouldLowerMemFuncForSize(const MachineFunction &MF) {
  // On Darwin, -Os means optimize for size without hurting performance, so
  // only really optimize for size when -Oz (MinSize) is used.
  if (MF.getTarget().getTargetTriple().isOSDarwin())
    return MF.getFunction().hasMinSize();
  return MF.getFunction().hasOptSize();
}

// Returns a list of types to use for memory op lowering in MemOps. A partial
// port of findOptimalMemOpLowering in TargetLowering.
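// For example, a 15-byte operation on a target whose optimal type is s64 and
// which permits overlapping accesses may yield MemOps = {s64, s64}, with the
// second access later shifted back to cover the final 8 bytes.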
static bool findGISelOptimalMemOpLowering(std::vector<LLT> &MemOps,
                                          unsigned Limit, const MemOp &Op,
                                          unsigned DstAS, unsigned SrcAS,
                                          const AttributeList &FuncAttributes,
                                          const TargetLowering &TLI) {
  if (Op.isMemcpyWithFixedDstAlign() && Op.getSrcAlign() < Op.getDstAlign())
    return false;

  LLT Ty = TLI.getOptimalMemOpLLT(Op, FuncAttributes);

  if (Ty == LLT()) {
    // Use the largest scalar type whose alignment constraints are satisfied.
    // We only need to check DstAlign here, as SrcAlign is always greater than
    // or equal to DstAlign (or zero).
    Ty = LLT::scalar(64);
    if (Op.isFixedDstAlign())
      while (Op.getDstAlign() < Ty.getSizeInBytes() &&
             !TLI.allowsMisalignedMemoryAccesses(Ty, DstAS, Op.getDstAlign()))
        Ty = LLT::scalar(Ty.getSizeInBytes());
    assert(Ty.getSizeInBits() > 0 && "Could not find valid type");
    // FIXME: check for the largest legal type we can load/store to.
  }

  unsigned NumMemOps = 0;
  uint64_t Size = Op.size();
  while (Size) {
    unsigned TySize = Ty.getSizeInBytes();
    while (TySize > Size) {
      // For now, only use non-vector loads / stores for the left-over pieces.
      LLT NewTy = Ty;
      // FIXME: check for mem op safety and legality of the types. Not all
      // SDAG-isms map cleanly to GISel concepts.
      if (NewTy.isVector())
        NewTy = NewTy.getSizeInBits() > 64 ? LLT::scalar(64) : LLT::scalar(32);
      NewTy = LLT::scalar(llvm::bit_floor(NewTy.getSizeInBits() - 1));
      unsigned NewTySize = NewTy.getSizeInBytes();
      assert(NewTySize > 0 && "Could not find appropriate type");

      // If the new LLT cannot cover all of the remaining bits, then consider
      // issuing one (or a pair of) unaligned and overlapping load / store ops.
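      // For example, with 7 bytes remaining and Ty = s32, two overlapping
      // s32 accesses at offsets 0 and 3 can beat an s32 + s16 + s8 sequence.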
      unsigned Fast;
      // Need to get a VT equivalent for allowsMisalignedMemoryAccesses().
      MVT VT = getMVTForLLT(Ty);
      if (NumMemOps && Op.allowOverlap() && NewTySize < Size &&
          TLI.allowsMisalignedMemoryAccesses(
              VT, DstAS, Op.isFixedDstAlign() ? Op.getDstAlign() : Align(1),
              MachineMemOperand::MONone, &Fast) &&
          Fast)
        TySize = Size;
      else {
        Ty = NewTy;
        TySize = NewTySize;
      }
    }

    if (++NumMemOps > Limit)
      return false;

    MemOps.push_back(Ty);
    Size -= TySize;
  }

  return true;
}

static Type *getTypeForLLT(LLT Ty, LLVMContext &C) {
  if (Ty.isVector())
    return FixedVectorType::get(IntegerType::get(C, Ty.getScalarSizeInBits()),
                                Ty.getNumElements());
  return IntegerType::get(C, Ty.getSizeInBits());
}

// Get a vectorized representation of the memset value operand, GISel edition.
static Register getMemsetValue(Register Val, LLT Ty, MachineIRBuilder &MIB) {
  MachineRegisterInfo &MRI = *MIB.getMRI();
  unsigned NumBits = Ty.getScalarSizeInBits();
  auto ValVRegAndVal = getIConstantVRegValWithLookThrough(Val, MRI);
  if (!Ty.isVector() && ValVRegAndVal) {
    APInt Scalar = ValVRegAndVal->Value.trunc(8);
    APInt SplatVal = APInt::getSplat(NumBits, Scalar);
    return MIB.buildConstant(Ty, SplatVal).getReg(0);
  }

  // Extend the byte value to the larger type, and then multiply by a magic
  // value 0x010101... in order to replicate it across every byte.
  // Unless it's zero, in which case just emit a larger G_CONSTANT 0.
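  //
  // For example, for NumBits = 32 and byte value 0xAB:
  //   0x000000AB * 0x01010101 = 0xABABABAB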
  if (ValVRegAndVal && ValVRegAndVal->Value == 0)
    return MIB.buildConstant(Ty, 0).getReg(0);

  LLT ExtType = Ty.getScalarType();
  auto ZExt = MIB.buildZExtOrTrunc(ExtType, Val);
  if (NumBits > 8) {
    APInt Magic = APInt::getSplat(NumBits, APInt(8, 0x01));
    auto MagicMI = MIB.buildConstant(ExtType, Magic);
    Val = MIB.buildMul(ExtType, ZExt, MagicMI).getReg(0);
  }

  // For vector types create a G_BUILD_VECTOR.
  if (Ty.isVector())
    Val = MIB.buildSplatVector(Ty, Val).getReg(0);

  return Val;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemset(MachineInstr &MI, Register Dst, Register Val,
                             uint64_t KnownLen, Align Alignment,
                             bool IsVolatile) {
  auto &MF = *MI.getParent()->getParent();
  const auto &TLI = *MF.getSubtarget().getTargetLowering();
  auto &DL = MF.getDataLayout();
  LLVMContext &C = MF.getFunction().getContext();

  assert(KnownLen != 0 && "Have a zero length memset length!");

  bool DstAlignCanChange = false;
  MachineFrameInfo &MFI = MF.getFrameInfo();
  bool OptSize = shouldLowerMemFuncForSize(MF);

  MachineInstr *FIDef = getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Dst, MRI);
  if (FIDef && !MFI.isFixedObjectIndex(FIDef->getOperand(1).getIndex()))
    DstAlignCanChange = true;

  unsigned Limit = TLI.getMaxStoresPerMemset(OptSize);
  std::vector<LLT> MemOps;

  const auto &DstMMO = **MI.memoperands_begin();
  MachinePointerInfo DstPtrInfo = DstMMO.getPointerInfo();

  auto ValVRegAndVal = getIConstantVRegValWithLookThrough(Val, MRI);
  bool IsZeroVal = ValVRegAndVal && ValVRegAndVal->Value == 0;

  if (!findGISelOptimalMemOpLowering(MemOps, Limit,
                                     MemOp::Set(KnownLen, DstAlignCanChange,
                                                Alignment,
                                                /*IsZeroMemset=*/IsZeroVal,
                                                /*IsVolatile=*/IsVolatile),
                                     DstPtrInfo.getAddrSpace(), ~0u,
                                     MF.getFunction().getAttributes(), TLI))
    return UnableToLegalize;

  if (DstAlignCanChange) {
    // Get an estimate of the type from the LLT.
    Type *IRTy = getTypeForLLT(MemOps[0], C);
    Align NewAlign = DL.getABITypeAlign(IRTy);
    if (NewAlign > Alignment) {
      Alignment = NewAlign;
      unsigned FI = FIDef->getOperand(1).getIndex();
      // Give the stack frame object a larger alignment if needed.
      if (MFI.getObjectAlign(FI) < Alignment)
        MFI.setObjectAlignment(FI, Alignment);
    }
  }

  MachineIRBuilder MIB(MI);
  // Find the largest store and generate the bit pattern for it.
  LLT LargestTy = MemOps[0];
  for (unsigned i = 1; i < MemOps.size(); i++)
    if (MemOps[i].getSizeInBits() > LargestTy.getSizeInBits())
      LargestTy = MemOps[i];

  // The memset stored value is always defined as an s8, so in order to make it
  // work with larger store types we need to repeat the bit pattern across the
  // wider type.
  Register MemSetValue = getMemsetValue(Val, LargestTy, MIB);

  if (!MemSetValue)
    return UnableToLegalize;

  // Generate the stores. For each store type in the list, we generate the
  // matching store of that type to the destination address.
  LLT PtrTy = MRI.getType(Dst);
  unsigned DstOff = 0;
  unsigned Size = KnownLen;
  for (unsigned I = 0; I < MemOps.size(); I++) {
    LLT Ty = MemOps[I];
    unsigned TySize = Ty.getSizeInBytes();
    if (TySize > Size) {
      // Issuing an unaligned load / store pair that overlaps with the previous
      // pair. Adjust the offset accordingly.
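      // For example, for KnownLen = 7 with MemOps = {s32, s32}, the second
      // store moves back from offset 4 to offset 3 so the two stores exactly
      // cover bytes [0, 7).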
      assert(I == MemOps.size() - 1 && I != 0);
      DstOff -= TySize - Size;
    }

    // If this store is smaller than the largest store, see whether we can get
    // the smaller value for free with a truncate.
    Register Value = MemSetValue;
    if (Ty.getSizeInBits() < LargestTy.getSizeInBits()) {
      MVT VT = getMVTForLLT(Ty);
      MVT LargestVT = getMVTForLLT(LargestTy);
      if (!LargestTy.isVector() && !Ty.isVector() &&
          TLI.isTruncateFree(LargestVT, VT))
        Value = MIB.buildTrunc(Ty, MemSetValue).getReg(0);
      else
        Value = getMemsetValue(Val, Ty, MIB);
      if (!Value)
        return UnableToLegalize;
    }

    auto *StoreMMO = MF.getMachineMemOperand(&DstMMO, DstOff, Ty);

    Register Ptr = Dst;
    if (DstOff != 0) {
      auto Offset =
          MIB.buildConstant(LLT::scalar(PtrTy.getSizeInBits()), DstOff);
      Ptr = MIB.buildPtrAdd(PtrTy, Dst, Offset).getReg(0);
    }

    MIB.buildStore(Value, Ptr, *StoreMMO);
    DstOff += Ty.getSizeInBytes();
    Size -= TySize;
  }

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemcpyInline(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_MEMCPY_INLINE);

  auto [Dst, Src, Len] = MI.getFirst3Regs();

  const auto *MMOIt = MI.memoperands_begin();
  const MachineMemOperand *MemOp = *MMOIt;
  bool IsVolatile = MemOp->isVolatile();

  // See if this is a constant length copy.
  auto LenVRegAndVal = getIConstantVRegValWithLookThrough(Len, MRI);
  // FIXME: support dynamically sized G_MEMCPY_INLINE
  assert(LenVRegAndVal &&
         "inline memcpy with dynamic size is not yet supported");
  uint64_t KnownLen = LenVRegAndVal->Value.getZExtValue();
  if (KnownLen == 0) {
    MI.eraseFromParent();
    return Legalized;
  }

  const auto &DstMMO = **MI.memoperands_begin();
  const auto &SrcMMO = **std::next(MI.memoperands_begin());
  Align DstAlign = DstMMO.getBaseAlign();
  Align SrcAlign = SrcMMO.getBaseAlign();

  return lowerMemcpyInline(MI, Dst, Src, KnownLen, DstAlign, SrcAlign,
                           IsVolatile);
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemcpyInline(MachineInstr &MI, Register Dst, Register Src,
                                   uint64_t KnownLen, Align DstAlign,
                                   Align SrcAlign, bool IsVolatile) {
  assert(MI.getOpcode() == TargetOpcode::G_MEMCPY_INLINE);
  return lowerMemcpy(MI, Dst, Src, KnownLen,
                     std::numeric_limits<uint64_t>::max(), DstAlign, SrcAlign,
                     IsVolatile);
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemcpy(MachineInstr &MI, Register Dst, Register Src,
                             uint64_t KnownLen, uint64_t Limit, Align DstAlign,
                             Align SrcAlign, bool IsVolatile) {
  auto &MF = *MI.getParent()->getParent();
  const auto &TLI = *MF.getSubtarget().getTargetLowering();
  auto &DL = MF.getDataLayout();
  LLVMContext &C = MF.getFunction().getContext();

  assert(KnownLen != 0 && "Have a zero length memcpy length!");

  bool DstAlignCanChange = false;
  MachineFrameInfo &MFI = MF.getFrameInfo();
  Align Alignment = std::min(DstAlign, SrcAlign);

  MachineInstr *FIDef = getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Dst, MRI);
  if (FIDef && !MFI.isFixedObjectIndex(FIDef->getOperand(1).getIndex()))
    DstAlignCanChange = true;

  // FIXME: infer better src pointer alignment like SelectionDAG does here.
  // FIXME: also use the equivalent of isMemSrcFromConstant and alwaysinlining
  // if the memcpy is in a tail call position.

  std::vector<LLT> MemOps;

  const auto &DstMMO = **MI.memoperands_begin();
  const auto &SrcMMO = **std::next(MI.memoperands_begin());
  MachinePointerInfo DstPtrInfo = DstMMO.getPointerInfo();
  MachinePointerInfo SrcPtrInfo = SrcMMO.getPointerInfo();

  if (!findGISelOptimalMemOpLowering(
          MemOps, Limit,
          MemOp::Copy(KnownLen, DstAlignCanChange, Alignment, SrcAlign,
                      IsVolatile),
          DstPtrInfo.getAddrSpace(), SrcPtrInfo.getAddrSpace(),
          MF.getFunction().getAttributes(), TLI))
    return UnableToLegalize;

  if (DstAlignCanChange) {
    // Get an estimate of the type from the LLT.
    Type *IRTy = getTypeForLLT(MemOps[0], C);
    Align NewAlign = DL.getABITypeAlign(IRTy);

    // Don't promote to an alignment that would require dynamic stack
    // realignment.
    const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
    if (!TRI->hasStackRealignment(MF))
      while (NewAlign > Alignment && DL.exceedsNaturalStackAlignment(NewAlign))
        NewAlign = NewAlign.previous();

    if (NewAlign > Alignment) {
      Alignment = NewAlign;
      unsigned FI = FIDef->getOperand(1).getIndex();
      // Give the stack frame object a larger alignment if needed.
      if (MFI.getObjectAlign(FI) < Alignment)
        MFI.setObjectAlignment(FI, Alignment);
    }
  }

  LLVM_DEBUG(dbgs() << "Inlining memcpy: " << MI << " into loads & stores\n");

  MachineIRBuilder MIB(MI);
  // Now we need to emit a load and store pair for each of the types we've
  // collected: for each type, generate a load of that width from the source
  // pointer, then a corresponding store of the loaded value to the
  // destination buffer. This can result in a sequence of loads and stores of
  // mixed types, depending on what the target specifies as good types to use.
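  //
  // For example, a 15-byte copy might become an s64 access at offset 0, an
  // s32 at offset 8, an s16 at offset 12, and an s8 at offset 14, depending
  // on what the target reports as optimal.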
  unsigned CurrOffset = 0;
  unsigned Size = KnownLen;
  for (auto CopyTy : MemOps) {
    // Issuing an unaligned load / store pair that overlaps with the previous
    // pair. Adjust the offset accordingly.
    if (CopyTy.getSizeInBytes() > Size)
      CurrOffset -= CopyTy.getSizeInBytes() - Size;

    // Construct MMOs for the accesses.
    auto *LoadMMO =
        MF.getMachineMemOperand(&SrcMMO, CurrOffset, CopyTy.getSizeInBytes());
    auto *StoreMMO =
        MF.getMachineMemOperand(&DstMMO, CurrOffset, CopyTy.getSizeInBytes());

    // Create the load.
    Register LoadPtr = Src;
    Register Offset;
    if (CurrOffset != 0) {
      LLT SrcTy = MRI.getType(Src);
      Offset = MIB.buildConstant(LLT::scalar(SrcTy.getSizeInBits()), CurrOffset)
                   .getReg(0);
      LoadPtr = MIB.buildPtrAdd(SrcTy, Src, Offset).getReg(0);
    }
    auto LdVal = MIB.buildLoad(CopyTy, LoadPtr, *LoadMMO);

    // Create the store.
    Register StorePtr = Dst;
    if (CurrOffset != 0) {
      LLT DstTy = MRI.getType(Dst);
      StorePtr = MIB.buildPtrAdd(DstTy, Dst, Offset).getReg(0);
    }
    MIB.buildStore(LdVal, StorePtr, *StoreMMO);
    CurrOffset += CopyTy.getSizeInBytes();
    Size -= CopyTy.getSizeInBytes();
  }

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemmove(MachineInstr &MI, Register Dst, Register Src,
                              uint64_t KnownLen, Align DstAlign, Align SrcAlign,
                              bool IsVolatile) {
  auto &MF = *MI.getParent()->getParent();
  const auto &TLI = *MF.getSubtarget().getTargetLowering();
  auto &DL = MF.getDataLayout();
  LLVMContext &C = MF.getFunction().getContext();

  assert(KnownLen != 0 && "Have a zero length memmove length!");

  bool DstAlignCanChange = false;
  MachineFrameInfo &MFI = MF.getFrameInfo();
  bool OptSize = shouldLowerMemFuncForSize(MF);
  Align Alignment = std::min(DstAlign, SrcAlign);

  MachineInstr *FIDef = getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Dst, MRI);
  if (FIDef && !MFI.isFixedObjectIndex(FIDef->getOperand(1).getIndex()))
    DstAlignCanChange = true;

  unsigned Limit = TLI.getMaxStoresPerMemmove(OptSize);
  std::vector<LLT> MemOps;

  const auto &DstMMO = **MI.memoperands_begin();
  const auto &SrcMMO = **std::next(MI.memoperands_begin());
  MachinePointerInfo DstPtrInfo = DstMMO.getPointerInfo();
  MachinePointerInfo SrcPtrInfo = SrcMMO.getPointerInfo();

  // FIXME: SelectionDAG always passes false for 'AllowOverlap', apparently due
  // to a bug in its findOptimalMemOpLowering implementation. For now do the
  // same thing here.
  if (!findGISelOptimalMemOpLowering(
          MemOps, Limit,
          MemOp::Copy(KnownLen, DstAlignCanChange, Alignment, SrcAlign,
                      /*IsVolatile*/ true),
          DstPtrInfo.getAddrSpace(), SrcPtrInfo.getAddrSpace(),
          MF.getFunction().getAttributes(), TLI))
    return UnableToLegalize;

  if (DstAlignCanChange) {
    // Get an estimate of the type from the LLT.
    Type *IRTy = getTypeForLLT(MemOps[0], C);
    Align NewAlign = DL.getABITypeAlign(IRTy);

    // Don't promote to an alignment that would require dynamic stack
    // realignment.
    const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
    if (!TRI->hasStackRealignment(MF))
      while (NewAlign > Alignment && DL.exceedsNaturalStackAlignment(NewAlign))
        NewAlign = NewAlign.previous();

    if (NewAlign > Alignment) {
      Alignment = NewAlign;
      unsigned FI = FIDef->getOperand(1).getIndex();
      // Give the stack frame object a larger alignment if needed.
      if (MFI.getObjectAlign(FI) < Alignment)
        MFI.setObjectAlignment(FI, Alignment);
    }
  }

  LLVM_DEBUG(dbgs() << "Inlining memmove: " << MI << " into loads & stores\n");

  MachineIRBuilder MIB(MI);
  // Memmove requires that all of the loads are emitted before any of the
  // stores. Apart from that, this loop is pretty much doing the same thing as
  // the memcpy codegen function.
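  //
  // For example, when copying bytes [0, 16) to [8, 24), storing the first
  // chunk immediately would clobber source bytes [8, 16) before they are
  // read; buffering all loaded values first makes the overlap safe.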
  unsigned CurrOffset = 0;
  SmallVector<Register, 16> LoadVals;
  for (auto CopyTy : MemOps) {
    // Construct MMO for the load.
    auto *LoadMMO =
        MF.getMachineMemOperand(&SrcMMO, CurrOffset, CopyTy.getSizeInBytes());

    // Create the load.
    Register LoadPtr = Src;
    if (CurrOffset != 0) {
      LLT SrcTy = MRI.getType(Src);
      auto Offset =
          MIB.buildConstant(LLT::scalar(SrcTy.getSizeInBits()), CurrOffset);
      LoadPtr = MIB.buildPtrAdd(SrcTy, Src, Offset).getReg(0);
    }
    LoadVals.push_back(MIB.buildLoad(CopyTy, LoadPtr, *LoadMMO).getReg(0));
    CurrOffset += CopyTy.getSizeInBytes();
  }

  CurrOffset = 0;
  for (unsigned I = 0; I < MemOps.size(); ++I) {
    LLT CopyTy = MemOps[I];
    // Now store the values loaded.
    auto *StoreMMO =
        MF.getMachineMemOperand(&DstMMO, CurrOffset, CopyTy.getSizeInBytes());

    Register StorePtr = Dst;
    if (CurrOffset != 0) {
      LLT DstTy = MRI.getType(Dst);
      auto Offset =
          MIB.buildConstant(LLT::scalar(DstTy.getSizeInBits()), CurrOffset);
      StorePtr = MIB.buildPtrAdd(DstTy, Dst, Offset).getReg(0);
    }
    MIB.buildStore(LoadVals[I], StorePtr, *StoreMMO);
    CurrOffset += CopyTy.getSizeInBytes();
  }
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemCpyFamily(MachineInstr &MI, unsigned MaxLen) {
  const unsigned Opc = MI.getOpcode();
  // This combine is fairly complex, so it's not written with a separate
  // matcher function.
  assert((Opc == TargetOpcode::G_MEMCPY || Opc == TargetOpcode::G_MEMMOVE ||
          Opc == TargetOpcode::G_MEMSET) &&
         "Expected memcpy like instruction");

  auto MMOIt = MI.memoperands_begin();
  const MachineMemOperand *MemOp = *MMOIt;

  Align DstAlign = MemOp->getBaseAlign();
  Align SrcAlign;
  auto [Dst, Src, Len] = MI.getFirst3Regs();

  if (Opc != TargetOpcode::G_MEMSET) {
    assert(MMOIt != MI.memoperands_end() && "Expected a second MMO on MI");
    MemOp = *(++MMOIt);
    SrcAlign = MemOp->getBaseAlign();
  }

  // See if this is a constant length copy.
  auto LenVRegAndVal = getIConstantVRegValWithLookThrough(Len, MRI);
  if (!LenVRegAndVal)
    return UnableToLegalize;
  uint64_t KnownLen = LenVRegAndVal->Value.getZExtValue();

  if (KnownLen == 0) {
    MI.eraseFromParent();
    return Legalized;
  }

  bool IsVolatile = MemOp->isVolatile();
  if (Opc == TargetOpcode::G_MEMCPY_INLINE)
    return lowerMemcpyInline(MI, Dst, Src, KnownLen, DstAlign, SrcAlign,
                             IsVolatile);

  // Don't try to optimize volatile.
  if (IsVolatile)
    return UnableToLegalize;

  if (MaxLen && KnownLen > MaxLen)
    return UnableToLegalize;

  if (Opc == TargetOpcode::G_MEMCPY) {
    auto &MF = *MI.getParent()->getParent();
    const auto &TLI = *MF.getSubtarget().getTargetLowering();
    bool OptSize = shouldLowerMemFuncForSize(MF);
    uint64_t Limit = TLI.getMaxStoresPerMemcpy(OptSize);
    return lowerMemcpy(MI, Dst, Src, KnownLen, Limit, DstAlign, SrcAlign,
                       IsVolatile);
  }
  if (Opc == TargetOpcode::G_MEMMOVE)
    return lowerMemmove(MI, Dst, Src, KnownLen, DstAlign, SrcAlign, IsVolatile);
  if (Opc == TargetOpcode::G_MEMSET)
    return lowerMemset(MI, Dst, Src, KnownLen, DstAlign, IsVolatile);
  return UnableToLegalize;
}