//===-- llvm/CodeGen/GlobalISel/LegalizerHelper.cpp -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This file implements the LegalizerHelper class to legalize
/// individual instructions and the LegalizeMachineIR wrapper pass for the
/// primary legalization.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/CallLowering.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
#include "llvm/CodeGen/GlobalISel/LostDebugLocObserver.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetFrameLowering.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include <numeric>
#include <optional>

#define DEBUG_TYPE "legalizer"

using namespace llvm;
using namespace LegalizeActions;
using namespace MIPatternMatch;

/// Try to break down \p OrigTy into \p NarrowTy sized pieces.
///
/// Returns the number of \p NarrowTy elements needed to reconstruct \p OrigTy,
/// with any leftover piece as type \p LeftoverTy.
///
/// Returns {-1, -1} if the breakdown is not satisfiable.
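///
/// For example, breaking s88 into s64 pieces yields {1, 1}: one s64 part plus
/// one s24 leftover piece.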
static std::pair<int, int>
getNarrowTypeBreakDown(LLT OrigTy, LLT NarrowTy, LLT &LeftoverTy) {
  assert(!LeftoverTy.isValid() && "this is an out argument");

  unsigned Size = OrigTy.getSizeInBits();
  unsigned NarrowSize = NarrowTy.getSizeInBits();
  unsigned NumParts = Size / NarrowSize;
  unsigned LeftoverSize = Size - NumParts * NarrowSize;
  assert(Size > NarrowSize);

  if (LeftoverSize == 0)
    return {NumParts, 0};

  if (NarrowTy.isVector()) {
    unsigned EltSize = OrigTy.getScalarSizeInBits();
    if (LeftoverSize % EltSize != 0)
      return {-1, -1};
    LeftoverTy = LLT::scalarOrVector(
        ElementCount::getFixed(LeftoverSize / EltSize), EltSize);
  } else {
    LeftoverTy = LLT::scalar(LeftoverSize);
  }

  int NumLeftover = LeftoverSize / LeftoverTy.getSizeInBits();
  return std::make_pair(NumParts, NumLeftover);
}

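/// Return the IR floating-point type whose bit width matches the scalar LLT
/// \p Ty, or null if there is no such type.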
static Type *getFloatTypeForLLT(LLVMContext &Ctx, LLT Ty) {

  if (!Ty.isScalar())
    return nullptr;

  switch (Ty.getSizeInBits()) {
  case 16:
    return Type::getHalfTy(Ctx);
  case 32:
    return Type::getFloatTy(Ctx);
  case 64:
    return Type::getDoubleTy(Ctx);
  case 80:
    return Type::getX86_FP80Ty(Ctx);
  case 128:
    return Type::getFP128Ty(Ctx);
  default:
    return nullptr;
  }
}

LegalizerHelper::LegalizerHelper(MachineFunction &MF,
                                 GISelChangeObserver &Observer,
                                 MachineIRBuilder &Builder)
    : MIRBuilder(Builder), Observer(Observer), MRI(MF.getRegInfo()),
      LI(*MF.getSubtarget().getLegalizerInfo()),
      TLI(*MF.getSubtarget().getTargetLowering()) {}

LegalizerHelper::LegalizerHelper(MachineFunction &MF, const LegalizerInfo &LI,
                                 GISelChangeObserver &Observer,
                                 MachineIRBuilder &B)
    : MIRBuilder(B), Observer(Observer), MRI(MF.getRegInfo()), LI(LI),
      TLI(*MF.getSubtarget().getTargetLowering()) {}

LegalizerHelper::LegalizeResult
LegalizerHelper::legalizeInstrStep(MachineInstr &MI,
                                   LostDebugLocObserver &LocObserver) {
  LLVM_DEBUG(dbgs() << "Legalizing: " << MI);

  MIRBuilder.setInstrAndDebugLoc(MI);

  if (MI.getOpcode() == TargetOpcode::G_INTRINSIC ||
      MI.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
    return LI.legalizeIntrinsic(*this, MI) ? Legalized : UnableToLegalize;
  auto Step = LI.getAction(MI, MRI);
  switch (Step.Action) {
  case Legal:
    LLVM_DEBUG(dbgs() << ".. Already legal\n");
    return AlreadyLegal;
  case Libcall:
    LLVM_DEBUG(dbgs() << ".. Convert to libcall\n");
    return libcall(MI, LocObserver);
  case NarrowScalar:
    LLVM_DEBUG(dbgs() << ".. Narrow scalar\n");
    return narrowScalar(MI, Step.TypeIdx, Step.NewType);
  case WidenScalar:
    LLVM_DEBUG(dbgs() << ".. Widen scalar\n");
    return widenScalar(MI, Step.TypeIdx, Step.NewType);
  case Bitcast:
    LLVM_DEBUG(dbgs() << ".. Bitcast type\n");
    return bitcast(MI, Step.TypeIdx, Step.NewType);
  case Lower:
    LLVM_DEBUG(dbgs() << ".. Lower\n");
    return lower(MI, Step.TypeIdx, Step.NewType);
  case FewerElements:
    LLVM_DEBUG(dbgs() << ".. Reduce number of elements\n");
    return fewerElementsVector(MI, Step.TypeIdx, Step.NewType);
  case MoreElements:
    LLVM_DEBUG(dbgs() << ".. Increase number of elements\n");
    return moreElementsVector(MI, Step.TypeIdx, Step.NewType);
  case Custom:
    LLVM_DEBUG(dbgs() << ".. Custom legalization\n");
    return LI.legalizeCustom(*this, MI) ? Legalized : UnableToLegalize;
  default:
    LLVM_DEBUG(dbgs() << ".. Unable to legalize\n");
    return UnableToLegalize;
  }
}

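/// Unmerge \p Reg into \p NumParts new registers of type \p Ty, appending
/// them to \p VRegs.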
void LegalizerHelper::extractParts(Register Reg, LLT Ty, int NumParts,
                                   SmallVectorImpl<Register> &VRegs) {
  for (int i = 0; i < NumParts; ++i)
    VRegs.push_back(MRI.createGenericVirtualRegister(Ty));
  MIRBuilder.buildUnmerge(VRegs, Reg);
}

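/// Split \p Reg of type \p RegTy into \p MainTy-sized pieces, appending them
/// to \p VRegs; any remainder goes to \p LeftoverRegs, with its type reported
/// through the out argument \p LeftoverTy.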
bool LegalizerHelper::extractParts(Register Reg, LLT RegTy,
                                   LLT MainTy, LLT &LeftoverTy,
                                   SmallVectorImpl<Register> &VRegs,
                                   SmallVectorImpl<Register> &LeftoverRegs) {
  assert(!LeftoverTy.isValid() && "this is an out argument");

  unsigned RegSize = RegTy.getSizeInBits();
  unsigned MainSize = MainTy.getSizeInBits();
  unsigned NumParts = RegSize / MainSize;
  unsigned LeftoverSize = RegSize - NumParts * MainSize;

  // Use an unmerge when possible.
  if (LeftoverSize == 0) {
    for (unsigned I = 0; I < NumParts; ++I)
      VRegs.push_back(MRI.createGenericVirtualRegister(MainTy));
    MIRBuilder.buildUnmerge(VRegs, Reg);
    return true;
  }

  // Perform irregular split. Leftover is last element of RegPieces.
  if (MainTy.isVector()) {
    SmallVector<Register, 8> RegPieces;
    extractVectorParts(Reg, MainTy.getNumElements(), RegPieces);
    for (unsigned i = 0; i < RegPieces.size() - 1; ++i)
      VRegs.push_back(RegPieces[i]);
    LeftoverRegs.push_back(RegPieces[RegPieces.size() - 1]);
    LeftoverTy = MRI.getType(LeftoverRegs[0]);
    return true;
  }

  LeftoverTy = LLT::scalar(LeftoverSize);
  // For irregular sizes, extract the individual parts.
  for (unsigned I = 0; I != NumParts; ++I) {
    Register NewReg = MRI.createGenericVirtualRegister(MainTy);
    VRegs.push_back(NewReg);
    MIRBuilder.buildExtract(NewReg, Reg, MainSize * I);
  }

  for (unsigned Offset = MainSize * NumParts; Offset < RegSize;
       Offset += LeftoverSize) {
    Register NewReg = MRI.createGenericVirtualRegister(LeftoverTy);
    LeftoverRegs.push_back(NewReg);
    MIRBuilder.buildExtract(NewReg, Reg, Offset);
  }

  return true;
}

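/// Unmerge the vector \p Reg into sub-vectors of \p NumElts elements,
/// appending the pieces to \p VRegs; a final, smaller piece covers any
/// leftover elements.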
void LegalizerHelper::extractVectorParts(Register Reg, unsigned NumElts,
                                         SmallVectorImpl<Register> &VRegs) {
  LLT RegTy = MRI.getType(Reg);
  assert(RegTy.isVector() && "Expected a vector type");

  LLT EltTy = RegTy.getElementType();
  LLT NarrowTy = (NumElts == 1) ? EltTy : LLT::fixed_vector(NumElts, EltTy);
  unsigned RegNumElts = RegTy.getNumElements();
  unsigned LeftoverNumElts = RegNumElts % NumElts;
  unsigned NumNarrowTyPieces = RegNumElts / NumElts;

  // Perfect split without leftover
  if (LeftoverNumElts == 0)
    return extractParts(Reg, NarrowTy, NumNarrowTyPieces, VRegs);

  // Irregular split. Unmerge to individual elements so the artifact combiner
  // has direct access to them, then build sub-vectors with NumElts elements.
  // The remaining element(s) form the leftover piece.
  SmallVector<Register, 8> Elts;
  extractParts(Reg, EltTy, RegNumElts, Elts);

  unsigned Offset = 0;
  // Requested sub-vectors of NarrowTy.
  for (unsigned i = 0; i < NumNarrowTyPieces; ++i, Offset += NumElts) {
    ArrayRef<Register> Pieces(&Elts[Offset], NumElts);
    VRegs.push_back(MIRBuilder.buildMergeLikeInstr(NarrowTy, Pieces).getReg(0));
  }

  // Leftover element(s).
  if (LeftoverNumElts == 1) {
    VRegs.push_back(Elts[Offset]);
  } else {
    LLT LeftoverTy = LLT::fixed_vector(LeftoverNumElts, EltTy);
    ArrayRef<Register> Pieces(&Elts[Offset], LeftoverNumElts);
    VRegs.push_back(
        MIRBuilder.buildMergeLikeInstr(LeftoverTy, Pieces).getReg(0));
  }
}

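/// Merge \p PartRegs of type \p PartTy, plus any \p LeftoverRegs of type
/// \p LeftoverTy, back into \p DstReg of type \p ResultTy.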
void LegalizerHelper::insertParts(Register DstReg,
                                  LLT ResultTy, LLT PartTy,
                                  ArrayRef<Register> PartRegs,
                                  LLT LeftoverTy,
                                  ArrayRef<Register> LeftoverRegs) {
  if (!LeftoverTy.isValid()) {
    assert(LeftoverRegs.empty());

    if (!ResultTy.isVector()) {
      MIRBuilder.buildMergeLikeInstr(DstReg, PartRegs);
      return;
    }

    if (PartTy.isVector())
      MIRBuilder.buildConcatVectors(DstReg, PartRegs);
    else
      MIRBuilder.buildBuildVector(DstReg, PartRegs);
    return;
  }

  // Merge sub-vectors with different numbers of elements and insert into
  // DstReg.
  if (ResultTy.isVector()) {
    assert(LeftoverRegs.size() == 1 && "Expected one leftover register");
    SmallVector<Register, 8> AllRegs;
    for (auto Reg : concat<const Register>(PartRegs, LeftoverRegs))
      AllRegs.push_back(Reg);
    return mergeMixedSubvectors(DstReg, AllRegs);
  }

  SmallVector<Register> GCDRegs;
  LLT GCDTy = getGCDType(getGCDType(ResultTy, LeftoverTy), PartTy);
  for (auto PartReg : concat<const Register>(PartRegs, LeftoverRegs))
    extractGCDType(GCDRegs, GCDTy, PartReg);
  LLT ResultLCMTy = buildLCMMergePieces(ResultTy, LeftoverTy, GCDTy, GCDRegs);
  buildWidenedRemergeToDst(DstReg, ResultLCMTy, GCDRegs);
}

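/// Unmerge \p Reg into its scalar elements and append them to \p Elts.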
void LegalizerHelper::appendVectorElts(SmallVectorImpl<Register> &Elts,
                                       Register Reg) {
  LLT Ty = MRI.getType(Reg);
  SmallVector<Register, 8> RegElts;
  extractParts(Reg, Ty.getScalarType(), Ty.getNumElements(), RegElts);
  Elts.append(RegElts);
}

/// Merge \p PartRegs with different types into \p DstReg.
void LegalizerHelper::mergeMixedSubvectors(Register DstReg,
                                           ArrayRef<Register> PartRegs) {
  SmallVector<Register, 8> AllElts;
  for (unsigned i = 0; i < PartRegs.size() - 1; ++i)
    appendVectorElts(AllElts, PartRegs[i]);

  Register Leftover = PartRegs[PartRegs.size() - 1];
  if (MRI.getType(Leftover).isScalar())
    AllElts.push_back(Leftover);
  else
    appendVectorElts(AllElts, Leftover);

  MIRBuilder.buildMergeLikeInstr(DstReg, AllElts);
}

/// Append the result registers of G_UNMERGE_VALUES \p MI to \p Regs.
static void getUnmergeResults(SmallVectorImpl<Register> &Regs,
                              const MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES);

  const int StartIdx = Regs.size();
  const int NumResults = MI.getNumOperands() - 1;
  Regs.resize(Regs.size() + NumResults);
  for (int I = 0; I != NumResults; ++I)
    Regs[StartIdx + I] = MI.getOperand(I).getReg();
}

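/// Split \p SrcReg into \p GCDTy-typed pieces and append them to \p Parts,
/// forwarding the register unchanged if it is already of that type.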
void LegalizerHelper::extractGCDType(SmallVectorImpl<Register> &Parts,
                                     LLT GCDTy, Register SrcReg) {
  LLT SrcTy = MRI.getType(SrcReg);
  if (SrcTy == GCDTy) {
    // If the source already evenly divides the result type, we don't need to
    // do anything.
    Parts.push_back(SrcReg);
  } else {
    // Need to split into common type sized pieces.
    auto Unmerge = MIRBuilder.buildUnmerge(GCDTy, SrcReg);
    getUnmergeResults(Parts, *Unmerge);
  }
}

LLT LegalizerHelper::extractGCDType(SmallVectorImpl<Register> &Parts, LLT DstTy,
                                    LLT NarrowTy, Register SrcReg) {
  LLT SrcTy = MRI.getType(SrcReg);
  LLT GCDTy = getGCDType(getGCDType(SrcTy, NarrowTy), DstTy);
  extractGCDType(Parts, GCDTy, SrcReg);
  return GCDTy;
}

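/// Merge the \p GCDTy-typed registers in \p VRegs into \p NarrowTy-sized
/// pieces that cover the least common multiple of \p DstTy and \p NarrowTy,
/// padding the tail according to \p PadStrategy. On return \p VRegs holds the
/// merged pieces, and the LCM type is returned.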
LLT LegalizerHelper::buildLCMMergePieces(LLT DstTy, LLT NarrowTy, LLT GCDTy,
                                         SmallVectorImpl<Register> &VRegs,
                                         unsigned PadStrategy) {
  LLT LCMTy = getLCMType(DstTy, NarrowTy);

  int NumParts = LCMTy.getSizeInBits() / NarrowTy.getSizeInBits();
  int NumSubParts = NarrowTy.getSizeInBits() / GCDTy.getSizeInBits();
  int NumOrigSrc = VRegs.size();

  Register PadReg;

  // Get a value we can use to pad the source value if the sources won't evenly
  // cover the result type.
  if (NumOrigSrc < NumParts * NumSubParts) {
    if (PadStrategy == TargetOpcode::G_ZEXT)
      PadReg = MIRBuilder.buildConstant(GCDTy, 0).getReg(0);
    else if (PadStrategy == TargetOpcode::G_ANYEXT)
      PadReg = MIRBuilder.buildUndef(GCDTy).getReg(0);
    else {
      assert(PadStrategy == TargetOpcode::G_SEXT);

      // Shift the sign bit of the low register through the high register.
      auto ShiftAmt =
        MIRBuilder.buildConstant(LLT::scalar(64), GCDTy.getSizeInBits() - 1);
      PadReg = MIRBuilder.buildAShr(GCDTy, VRegs.back(), ShiftAmt).getReg(0);
    }
  }

  // Registers for the final merge to be produced.
  SmallVector<Register, 4> Remerge(NumParts);

  // Registers needed for intermediate merges, which will be merged into a
  // source for Remerge.
  SmallVector<Register, 4> SubMerge(NumSubParts);

  // Once we've fully read off the end of the original source bits, we can
  // reuse the same high bits for remaining padding elements.
  Register AllPadReg;

  // Build merges to the LCM type to cover the original result type.
  for (int I = 0; I != NumParts; ++I) {
    bool AllMergePartsArePadding = true;

    // Build the requested merges to the requested type.
    for (int J = 0; J != NumSubParts; ++J) {
      int Idx = I * NumSubParts + J;
      if (Idx >= NumOrigSrc) {
        SubMerge[J] = PadReg;
        continue;
      }

      SubMerge[J] = VRegs[Idx];

      // There are meaningful bits here we can't reuse later.
      AllMergePartsArePadding = false;
    }

    // If we've filled up a complete piece with padding bits, we can directly
    // emit the natural sized constant if applicable, rather than a merge of
    // smaller constants.
    if (AllMergePartsArePadding && !AllPadReg) {
      if (PadStrategy == TargetOpcode::G_ANYEXT)
        AllPadReg = MIRBuilder.buildUndef(NarrowTy).getReg(0);
      else if (PadStrategy == TargetOpcode::G_ZEXT)
        AllPadReg = MIRBuilder.buildConstant(NarrowTy, 0).getReg(0);

      // If this is a sign extension, we can't materialize a trivial constant
      // with the right type and have to produce a merge.
    }

    if (AllPadReg) {
      // Avoid creating additional instructions if we're just adding additional
      // copies of padding bits.
      Remerge[I] = AllPadReg;
      continue;
    }

    if (NumSubParts == 1)
      Remerge[I] = SubMerge[0];
    else
      Remerge[I] = MIRBuilder.buildMergeLikeInstr(NarrowTy, SubMerge).getReg(0);

    // In the sign extend padding case, re-use the first all-signbit merge.
    if (AllMergePartsArePadding && !AllPadReg)
      AllPadReg = Remerge[I];
  }

  VRegs = std::move(Remerge);
  return LCMTy;
}

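/// Merge \p RemergeRegs into a single value of the wide type \p LCMTy, then
/// truncate or unmerge the relevant bits into \p DstReg.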
void LegalizerHelper::buildWidenedRemergeToDst(Register DstReg, LLT LCMTy,
                                               ArrayRef<Register> RemergeRegs) {
  LLT DstTy = MRI.getType(DstReg);

  // Create the merge to the widened source, and extract the relevant bits into
  // the result.

  if (DstTy == LCMTy) {
    MIRBuilder.buildMergeLikeInstr(DstReg, RemergeRegs);
    return;
  }

  auto Remerge = MIRBuilder.buildMergeLikeInstr(LCMTy, RemergeRegs);
  if (DstTy.isScalar() && LCMTy.isScalar()) {
    MIRBuilder.buildTrunc(DstReg, Remerge);
    return;
  }

  if (LCMTy.isVector()) {
    unsigned NumDefs = LCMTy.getSizeInBits() / DstTy.getSizeInBits();
    SmallVector<Register, 8> UnmergeDefs(NumDefs);
    UnmergeDefs[0] = DstReg;
    for (unsigned I = 1; I != NumDefs; ++I)
      UnmergeDefs[I] = MRI.createGenericVirtualRegister(DstTy);

    MIRBuilder.buildUnmerge(UnmergeDefs,
                            MIRBuilder.buildMergeLikeInstr(LCMTy, RemergeRegs));
    return;
  }

  llvm_unreachable("unhandled case");
}

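/// Map a generic opcode and operand size to the corresponding runtime-library
/// call.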
static RTLIB::Libcall getRTLibDesc(unsigned Opcode, unsigned Size) {
#define RTLIBCASE_INT(LibcallPrefix)                                           \
  do {                                                                         \
    switch (Size) {                                                            \
    case 32:                                                                   \
      return RTLIB::LibcallPrefix##32;                                         \
    case 64:                                                                   \
      return RTLIB::LibcallPrefix##64;                                         \
    case 128:                                                                  \
      return RTLIB::LibcallPrefix##128;                                        \
    default:                                                                   \
      llvm_unreachable("unexpected size");                                     \
    }                                                                          \
  } while (0)

#define RTLIBCASE(LibcallPrefix)                                               \
  do {                                                                         \
    switch (Size) {                                                            \
    case 32:                                                                   \
      return RTLIB::LibcallPrefix##32;                                         \
    case 64:                                                                   \
      return RTLIB::LibcallPrefix##64;                                         \
    case 80:                                                                   \
      return RTLIB::LibcallPrefix##80;                                         \
    case 128:                                                                  \
      return RTLIB::LibcallPrefix##128;                                        \
    default:                                                                   \
      llvm_unreachable("unexpected size");                                     \
    }                                                                          \
  } while (0)

  switch (Opcode) {
  case TargetOpcode::G_MUL:
    RTLIBCASE_INT(MUL_I);
  case TargetOpcode::G_SDIV:
    RTLIBCASE_INT(SDIV_I);
  case TargetOpcode::G_UDIV:
    RTLIBCASE_INT(UDIV_I);
  case TargetOpcode::G_SREM:
    RTLIBCASE_INT(SREM_I);
  case TargetOpcode::G_UREM:
    RTLIBCASE_INT(UREM_I);
  case TargetOpcode::G_CTLZ_ZERO_UNDEF:
    RTLIBCASE_INT(CTLZ_I);
  case TargetOpcode::G_FADD:
    RTLIBCASE(ADD_F);
  case TargetOpcode::G_FSUB:
    RTLIBCASE(SUB_F);
  case TargetOpcode::G_FMUL:
    RTLIBCASE(MUL_F);
  case TargetOpcode::G_FDIV:
    RTLIBCASE(DIV_F);
  case TargetOpcode::G_FEXP:
    RTLIBCASE(EXP_F);
  case TargetOpcode::G_FEXP2:
    RTLIBCASE(EXP2_F);
  case TargetOpcode::G_FREM:
    RTLIBCASE(REM_F);
  case TargetOpcode::G_FPOW:
    RTLIBCASE(POW_F);
  case TargetOpcode::G_FMA:
    RTLIBCASE(FMA_F);
  case TargetOpcode::G_FSIN:
    RTLIBCASE(SIN_F);
  case TargetOpcode::G_FCOS:
    RTLIBCASE(COS_F);
  case TargetOpcode::G_FLOG10:
    RTLIBCASE(LOG10_F);
  case TargetOpcode::G_FLOG:
    RTLIBCASE(LOG_F);
  case TargetOpcode::G_FLOG2:
    RTLIBCASE(LOG2_F);
  case TargetOpcode::G_FCEIL:
    RTLIBCASE(CEIL_F);
  case TargetOpcode::G_FFLOOR:
    RTLIBCASE(FLOOR_F);
  case TargetOpcode::G_FMINNUM:
    RTLIBCASE(FMIN_F);
  case TargetOpcode::G_FMAXNUM:
    RTLIBCASE(FMAX_F);
  case TargetOpcode::G_FSQRT:
    RTLIBCASE(SQRT_F);
  case TargetOpcode::G_FRINT:
    RTLIBCASE(RINT_F);
  case TargetOpcode::G_FNEARBYINT:
    RTLIBCASE(NEARBYINT_F);
  case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
    RTLIBCASE(ROUNDEVEN_F);
  }
  llvm_unreachable("Unknown libcall function");
}

/// True if an instruction is in tail position in its caller. Intended for
/// legalizing libcalls as tail calls when possible.
static bool isLibCallInTailPosition(MachineInstr &MI,
                                    const TargetInstrInfo &TII,
                                    MachineRegisterInfo &MRI) {
  MachineBasicBlock &MBB = *MI.getParent();
  const Function &F = MBB.getParent()->getFunction();

  // Conservatively require the attributes of the call to match those of
  // the return. Ignore NoAlias and NonNull because they don't affect the
  // call sequence.
  AttributeList CallerAttrs = F.getAttributes();
  if (AttrBuilder(F.getContext(), CallerAttrs.getRetAttrs())
          .removeAttribute(Attribute::NoAlias)
          .removeAttribute(Attribute::NonNull)
          .hasAttributes())
    return false;

  // It's not safe to eliminate the sign / zero extension of the return value.
  if (CallerAttrs.hasRetAttr(Attribute::ZExt) ||
      CallerAttrs.hasRetAttr(Attribute::SExt))
    return false;

  // Only tail call if the following instruction is a standard return or if we
  // have a `thisreturn` callee, and a sequence like:
  //
  //   G_MEMCPY %0, %1, %2
  //   $x0 = COPY %0
  //   RET_ReallyLR implicit $x0
  auto Next = next_nodbg(MI.getIterator(), MBB.instr_end());
  if (Next != MBB.instr_end() && Next->isCopy()) {
    switch (MI.getOpcode()) {
    default:
      llvm_unreachable("unsupported opcode");
    case TargetOpcode::G_BZERO:
      return false;
    case TargetOpcode::G_MEMCPY:
    case TargetOpcode::G_MEMMOVE:
    case TargetOpcode::G_MEMSET:
      break;
    }

    Register VReg = MI.getOperand(0).getReg();
    if (!VReg.isVirtual() || VReg != Next->getOperand(1).getReg())
      return false;

    Register PReg = Next->getOperand(0).getReg();
    if (!PReg.isPhysical())
      return false;

    auto Ret = next_nodbg(Next, MBB.instr_end());
    if (Ret == MBB.instr_end() || !Ret->isReturn())
      return false;

    if (Ret->getNumImplicitOperands() != 1)
      return false;

    if (PReg != Ret->getOperand(0).getReg())
      return false;

    // Skip over the COPY that we just validated.
    Next = Ret;
  }

  if (Next == MBB.instr_end() || TII.isTailCall(*Next) || !Next->isReturn())
    return false;

  return true;
}

LegalizerHelper::LegalizeResult
llvm::createLibcall(MachineIRBuilder &MIRBuilder, const char *Name,
                    const CallLowering::ArgInfo &Result,
                    ArrayRef<CallLowering::ArgInfo> Args,
                    const CallingConv::ID CC) {
  auto &CLI = *MIRBuilder.getMF().getSubtarget().getCallLowering();

  CallLowering::CallLoweringInfo Info;
  Info.CallConv = CC;
  Info.Callee = MachineOperand::CreateES(Name);
  Info.OrigRet = Result;
  std::copy(Args.begin(), Args.end(), std::back_inserter(Info.OrigArgs));
  if (!CLI.lowerCall(MIRBuilder, Info))
    return LegalizerHelper::UnableToLegalize;

  return LegalizerHelper::Legalized;
}

LegalizerHelper::LegalizeResult
llvm::createLibcall(MachineIRBuilder &MIRBuilder, RTLIB::Libcall Libcall,
                    const CallLowering::ArgInfo &Result,
                    ArrayRef<CallLowering::ArgInfo> Args) {
  auto &TLI = *MIRBuilder.getMF().getSubtarget().getTargetLowering();
  const char *Name = TLI.getLibcallName(Libcall);
  const CallingConv::ID CC = TLI.getLibcallCallingConv(Libcall);
  return createLibcall(MIRBuilder, Name, Result, Args, CC);
}

// Useful for libcalls where all operands have the same type.
static LegalizerHelper::LegalizeResult
simpleLibcall(MachineInstr &MI, MachineIRBuilder &MIRBuilder, unsigned Size,
              Type *OpType) {
  auto Libcall = getRTLibDesc(MI.getOpcode(), Size);

  // FIXME: What does the original arg index mean here?
  SmallVector<CallLowering::ArgInfo, 3> Args;
  for (const MachineOperand &MO : llvm::drop_begin(MI.operands()))
    Args.push_back({MO.getReg(), OpType, 0});
  return createLibcall(MIRBuilder, Libcall,
                       {MI.getOperand(0).getReg(), OpType, 0}, Args);
}

LegalizerHelper::LegalizeResult
llvm::createMemLibcall(MachineIRBuilder &MIRBuilder, MachineRegisterInfo &MRI,
                       MachineInstr &MI, LostDebugLocObserver &LocObserver) {
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();

  SmallVector<CallLowering::ArgInfo, 3> Args;
  // Add all the args, except for the last which is an imm denoting 'tail'.
  for (unsigned i = 0; i < MI.getNumOperands() - 1; ++i) {
    Register Reg = MI.getOperand(i).getReg();

    // We need to derive an IR type for call lowering.
    LLT OpLLT = MRI.getType(Reg);
    Type *OpTy = nullptr;
    if (OpLLT.isPointer())
      OpTy = Type::getInt8PtrTy(Ctx, OpLLT.getAddressSpace());
    else
      OpTy = IntegerType::get(Ctx, OpLLT.getSizeInBits());
    Args.push_back({Reg, OpTy, 0});
  }

  auto &CLI = *MIRBuilder.getMF().getSubtarget().getCallLowering();
  auto &TLI = *MIRBuilder.getMF().getSubtarget().getTargetLowering();
  RTLIB::Libcall RTLibcall;
  unsigned Opc = MI.getOpcode();
  switch (Opc) {
  case TargetOpcode::G_BZERO:
    RTLibcall = RTLIB::BZERO;
    break;
  case TargetOpcode::G_MEMCPY:
    RTLibcall = RTLIB::MEMCPY;
    Args[0].Flags[0].setReturned();
    break;
  case TargetOpcode::G_MEMMOVE:
    RTLibcall = RTLIB::MEMMOVE;
    Args[0].Flags[0].setReturned();
    break;
  case TargetOpcode::G_MEMSET:
    RTLibcall = RTLIB::MEMSET;
    Args[0].Flags[0].setReturned();
    break;
  default:
    llvm_unreachable("unsupported opcode");
  }
  const char *Name = TLI.getLibcallName(RTLibcall);

  // Unsupported libcall on the target.
  if (!Name) {
    LLVM_DEBUG(dbgs() << ".. .. Could not find libcall name for "
                      << MIRBuilder.getTII().getName(Opc) << "\n");
    return LegalizerHelper::UnableToLegalize;
  }

  CallLowering::CallLoweringInfo Info;
  Info.CallConv = TLI.getLibcallCallingConv(RTLibcall);
  Info.Callee = MachineOperand::CreateES(Name);
  Info.OrigRet = CallLowering::ArgInfo({0}, Type::getVoidTy(Ctx), 0);
  Info.IsTailCall = MI.getOperand(MI.getNumOperands() - 1).getImm() &&
                    isLibCallInTailPosition(MI, MIRBuilder.getTII(), MRI);

  std::copy(Args.begin(), Args.end(), std::back_inserter(Info.OrigArgs));
  if (!CLI.lowerCall(MIRBuilder, Info))
    return LegalizerHelper::UnableToLegalize;

  if (Info.LoweredTailCall) {
    assert(Info.IsTailCall && "Lowered tail call when it wasn't a tail call?");

    // Check debug locations before removing the return.
    LocObserver.checkpoint(true);

    // We must have a return following the call (or debug insts) to get past
    // isLibCallInTailPosition.
    do {
      MachineInstr *Next = MI.getNextNode();
      assert(Next &&
             (Next->isCopy() || Next->isReturn() || Next->isDebugInstr()) &&
             "Expected instr following MI to be return or debug inst?");
      // We lowered a tail call, so the call is now the return from the block.
      // Delete the old return.
      Next->eraseFromParent();
    } while (MI.getNextNode());

    // We expect to lose the debug location from the return.
    LocObserver.checkpoint(false);
  }

  return LegalizerHelper::Legalized;
}

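/// Map a conversion opcode and its source and destination IR types to the
/// corresponding runtime-library call.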
static RTLIB::Libcall getConvRTLibDesc(unsigned Opcode, Type *ToType,
                                       Type *FromType) {
  auto ToMVT = MVT::getVT(ToType);
  auto FromMVT = MVT::getVT(FromType);

  switch (Opcode) {
  case TargetOpcode::G_FPEXT:
    return RTLIB::getFPEXT(FromMVT, ToMVT);
  case TargetOpcode::G_FPTRUNC:
    return RTLIB::getFPROUND(FromMVT, ToMVT);
  case TargetOpcode::G_FPTOSI:
    return RTLIB::getFPTOSINT(FromMVT, ToMVT);
  case TargetOpcode::G_FPTOUI:
    return RTLIB::getFPTOUINT(FromMVT, ToMVT);
  case TargetOpcode::G_SITOFP:
    return RTLIB::getSINTTOFP(FromMVT, ToMVT);
  case TargetOpcode::G_UITOFP:
    return RTLIB::getUINTTOFP(FromMVT, ToMVT);
  }
  llvm_unreachable("Unsupported libcall function");
}

static LegalizerHelper::LegalizeResult
conversionLibcall(MachineInstr &MI, MachineIRBuilder &MIRBuilder, Type *ToType,
                  Type *FromType) {
  RTLIB::Libcall Libcall = getConvRTLibDesc(MI.getOpcode(), ToType, FromType);
  return createLibcall(MIRBuilder, Libcall,
                       {MI.getOperand(0).getReg(), ToType, 0},
                       {{MI.getOperand(1).getReg(), FromType, 0}});
}

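/// Lower \p MI to a runtime-library call, erasing the original instruction on
/// success.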
LegalizerHelper::LegalizeResult
LegalizerHelper::libcall(MachineInstr &MI, LostDebugLocObserver &LocObserver) {
  LLT LLTy = MRI.getType(MI.getOperand(0).getReg());
  unsigned Size = LLTy.getSizeInBits();
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();

  switch (MI.getOpcode()) {
  default:
    return UnableToLegalize;
  case TargetOpcode::G_MUL:
  case TargetOpcode::G_SDIV:
  case TargetOpcode::G_UDIV:
  case TargetOpcode::G_SREM:
  case TargetOpcode::G_UREM:
  case TargetOpcode::G_CTLZ_ZERO_UNDEF: {
    Type *HLTy = IntegerType::get(Ctx, Size);
    auto Status = simpleLibcall(MI, MIRBuilder, Size, HLTy);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_FADD:
  case TargetOpcode::G_FSUB:
  case TargetOpcode::G_FMUL:
  case TargetOpcode::G_FDIV:
  case TargetOpcode::G_FMA:
  case TargetOpcode::G_FPOW:
  case TargetOpcode::G_FREM:
  case TargetOpcode::G_FCOS:
  case TargetOpcode::G_FSIN:
  case TargetOpcode::G_FLOG10:
  case TargetOpcode::G_FLOG:
  case TargetOpcode::G_FLOG2:
  case TargetOpcode::G_FEXP:
  case TargetOpcode::G_FEXP2:
  case TargetOpcode::G_FCEIL:
  case TargetOpcode::G_FFLOOR:
  case TargetOpcode::G_FMINNUM:
  case TargetOpcode::G_FMAXNUM:
  case TargetOpcode::G_FSQRT:
  case TargetOpcode::G_FRINT:
  case TargetOpcode::G_FNEARBYINT:
  case TargetOpcode::G_INTRINSIC_ROUNDEVEN: {
    Type *HLTy = getFloatTypeForLLT(Ctx, LLTy);
    if (!HLTy || (Size != 32 && Size != 64 && Size != 80 && Size != 128)) {
      LLVM_DEBUG(dbgs() << "No libcall available for type " << LLTy << ".\n");
      return UnableToLegalize;
    }
    auto Status = simpleLibcall(MI, MIRBuilder, Size, HLTy);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_FPEXT:
  case TargetOpcode::G_FPTRUNC: {
    Type *FromTy =
        getFloatTypeForLLT(Ctx, MRI.getType(MI.getOperand(1).getReg()));
    Type *ToTy =
        getFloatTypeForLLT(Ctx, MRI.getType(MI.getOperand(0).getReg()));
    if (!FromTy || !ToTy)
      return UnableToLegalize;
    LegalizeResult Status = conversionLibcall(MI, MIRBuilder, ToTy, FromTy);
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_FPTOSI:
  case TargetOpcode::G_FPTOUI: {
    // FIXME: Support other types
    unsigned FromSize = MRI.getType(MI.getOperand(1).getReg()).getSizeInBits();
    unsigned ToSize = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
    if ((ToSize != 32 && ToSize != 64) || (FromSize != 32 && FromSize != 64))
      return UnableToLegalize;
    LegalizeResult Status = conversionLibcall(
        MI, MIRBuilder,
        ToSize == 32 ? Type::getInt32Ty(Ctx) : Type::getInt64Ty(Ctx),
        FromSize == 64 ? Type::getDoubleTy(Ctx) : Type::getFloatTy(Ctx));
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_SITOFP:
  case TargetOpcode::G_UITOFP: {
    // FIXME: Support other types
    unsigned FromSize = MRI.getType(MI.getOperand(1).getReg()).getSizeInBits();
    unsigned ToSize = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
    if ((FromSize != 32 && FromSize != 64) || (ToSize != 32 && ToSize != 64))
      return UnableToLegalize;
    LegalizeResult Status = conversionLibcall(
        MI, MIRBuilder,
        ToSize == 64 ? Type::getDoubleTy(Ctx) : Type::getFloatTy(Ctx),
        FromSize == 32 ? Type::getInt32Ty(Ctx) : Type::getInt64Ty(Ctx));
    if (Status != Legalized)
      return Status;
    break;
  }
  case TargetOpcode::G_BZERO:
  case TargetOpcode::G_MEMCPY:
  case TargetOpcode::G_MEMMOVE:
  case TargetOpcode::G_MEMSET: {
    LegalizeResult Result =
        createMemLibcall(MIRBuilder, *MIRBuilder.getMRI(), MI, LocObserver);
    if (Result != Legalized)
      return Result;
    MI.eraseFromParent();
    return Result;
  }
  }

  MI.eraseFromParent();
  return Legalized;
}

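/// Legalize \p MI by splitting the value at type index \p TypeIdx into
/// \p NarrowTy-sized pieces.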
LegalizerHelper::LegalizeResult LegalizerHelper::narrowScalar(MachineInstr &MI,
                                                              unsigned TypeIdx,
                                                              LLT NarrowTy) {
  uint64_t SizeOp0 = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
  uint64_t NarrowSize = NarrowTy.getSizeInBits();

  switch (MI.getOpcode()) {
  default:
    return UnableToLegalize;
  case TargetOpcode::G_IMPLICIT_DEF: {
    Register DstReg = MI.getOperand(0).getReg();
    LLT DstTy = MRI.getType(DstReg);

    // If SizeOp0 is not an exact multiple of NarrowSize, emit
    // G_ANYEXT(G_IMPLICIT_DEF). Cast result to vector if needed.
    // FIXME: Although this would also be legal for the general case, it causes
    //  a lot of regressions in the emitted code (superfluous COPYs, artifact
    //  combines not being hit). This seems to be a problem related to the
    //  artifact combiner.
    if (SizeOp0 % NarrowSize != 0) {
      LLT ImplicitTy = NarrowTy;
      if (DstTy.isVector())
        ImplicitTy = LLT::vector(DstTy.getElementCount(), ImplicitTy);

      Register ImplicitReg = MIRBuilder.buildUndef(ImplicitTy).getReg(0);
      MIRBuilder.buildAnyExt(DstReg, ImplicitReg);

      MI.eraseFromParent();
      return Legalized;
    }

    int NumParts = SizeOp0 / NarrowSize;

    SmallVector<Register, 2> DstRegs;
    for (int i = 0; i < NumParts; ++i)
      DstRegs.push_back(MIRBuilder.buildUndef(NarrowTy).getReg(0));

    if (DstTy.isVector())
      MIRBuilder.buildBuildVector(DstReg, DstRegs);
    else
      MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_CONSTANT: {
    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
    const APInt &Val = MI.getOperand(1).getCImm()->getValue();
    unsigned TotalSize = Ty.getSizeInBits();
    unsigned NarrowSize = NarrowTy.getSizeInBits();
    int NumParts = TotalSize / NarrowSize;

    SmallVector<Register, 4> PartRegs;
    for (int I = 0; I != NumParts; ++I) {
      unsigned Offset = I * NarrowSize;
      auto K = MIRBuilder.buildConstant(NarrowTy,
                                        Val.lshr(Offset).trunc(NarrowSize));
      PartRegs.push_back(K.getReg(0));
    }

    LLT LeftoverTy;
    unsigned LeftoverBits = TotalSize - NumParts * NarrowSize;
    SmallVector<Register, 1> LeftoverRegs;
    if (LeftoverBits != 0) {
      LeftoverTy = LLT::scalar(LeftoverBits);
      auto K = MIRBuilder.buildConstant(
        LeftoverTy,
        Val.lshr(NumParts * NarrowSize).trunc(LeftoverBits));
      LeftoverRegs.push_back(K.getReg(0));
    }

    insertParts(MI.getOperand(0).getReg(),
                Ty, NarrowTy, PartRegs, LeftoverTy, LeftoverRegs);

    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_SEXT:
  case TargetOpcode::G_ZEXT:
  case TargetOpcode::G_ANYEXT:
    return narrowScalarExt(MI, TypeIdx, NarrowTy);
  case TargetOpcode::G_TRUNC: {
    if (TypeIdx != 1)
      return UnableToLegalize;

    uint64_t SizeOp1 = MRI.getType(MI.getOperand(1).getReg()).getSizeInBits();
    if (NarrowTy.getSizeInBits() * 2 != SizeOp1) {
      LLVM_DEBUG(dbgs() << "Can't narrow trunc to type " << NarrowTy << "\n");
      return UnableToLegalize;
    }

    auto Unmerge = MIRBuilder.buildUnmerge(NarrowTy, MI.getOperand(1));
    MIRBuilder.buildCopy(MI.getOperand(0), Unmerge.getReg(0));
    MI.eraseFromParent();
    return Legalized;
  }

  case TargetOpcode::G_FREEZE: {
    if (TypeIdx != 0)
      return UnableToLegalize;

    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
    // Should widen scalar first
    if (Ty.getSizeInBits() % NarrowTy.getSizeInBits() != 0)
      return UnableToLegalize;

    auto Unmerge = MIRBuilder.buildUnmerge(NarrowTy, MI.getOperand(1).getReg());
    SmallVector<Register, 8> Parts;
    for (unsigned i = 0; i < Unmerge->getNumDefs(); ++i) {
      Parts.push_back(
          MIRBuilder.buildFreeze(NarrowTy, Unmerge.getReg(i)).getReg(0));
    }

    MIRBuilder.buildMergeLikeInstr(MI.getOperand(0).getReg(), Parts);
    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_ADD:
  case TargetOpcode::G_SUB:
  case TargetOpcode::G_SADDO:
  case TargetOpcode::G_SSUBO:
  case TargetOpcode::G_SADDE:
  case TargetOpcode::G_SSUBE:
  case TargetOpcode::G_UADDO:
  case TargetOpcode::G_USUBO:
  case TargetOpcode::G_UADDE:
  case TargetOpcode::G_USUBE:
    return narrowScalarAddSub(MI, TypeIdx, NarrowTy);
  case TargetOpcode::G_MUL:
  case TargetOpcode::G_UMULH:
    return narrowScalarMul(MI, NarrowTy);
  case TargetOpcode::G_EXTRACT:
    return narrowScalarExtract(MI, TypeIdx, NarrowTy);
  case TargetOpcode::G_INSERT:
    return narrowScalarInsert(MI, TypeIdx, NarrowTy);
  case TargetOpcode::G_LOAD: {
    auto &LoadMI = cast<GLoad>(MI);
    Register DstReg = LoadMI.getDstReg();
    LLT DstTy = MRI.getType(DstReg);
    if (DstTy.isVector())
      return UnableToLegalize;

    if (8 * LoadMI.getMemSize() != DstTy.getSizeInBits()) {
      Register TmpReg = MRI.createGenericVirtualRegister(NarrowTy);
      MIRBuilder.buildLoad(TmpReg, LoadMI.getPointerReg(), LoadMI.getMMO());
      MIRBuilder.buildAnyExt(DstReg, TmpReg);
      LoadMI.eraseFromParent();
      return Legalized;
    }

    return reduceLoadStoreWidth(LoadMI, TypeIdx, NarrowTy);
  }
  case TargetOpcode::G_ZEXTLOAD:
  case TargetOpcode::G_SEXTLOAD: {
    auto &LoadMI = cast<GExtLoad>(MI);
    Register DstReg = LoadMI.getDstReg();
    Register PtrReg = LoadMI.getPointerReg();

    Register TmpReg = MRI.createGenericVirtualRegister(NarrowTy);
    auto &MMO = LoadMI.getMMO();
    unsigned MemSize = MMO.getSizeInBits();

    if (MemSize == NarrowSize) {
      MIRBuilder.buildLoad(TmpReg, PtrReg, MMO);
    } else if (MemSize < NarrowSize) {
      MIRBuilder.buildLoadInstr(LoadMI.getOpcode(), TmpReg, PtrReg, MMO);
    } else if (MemSize > NarrowSize) {
      // FIXME: Need to split the load.
      return UnableToLegalize;
    }

    if (isa<GZExtLoad>(LoadMI))
      MIRBuilder.buildZExt(DstReg, TmpReg);
    else
      MIRBuilder.buildSExt(DstReg, TmpReg);

    LoadMI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_STORE: {
    auto &StoreMI = cast<GStore>(MI);

    Register SrcReg = StoreMI.getValueReg();
    LLT SrcTy = MRI.getType(SrcReg);
    if (SrcTy.isVector())
      return UnableToLegalize;

    int NumParts = SizeOp0 / NarrowSize;
    unsigned HandledSize = NumParts * NarrowTy.getSizeInBits();
    unsigned LeftoverBits = SrcTy.getSizeInBits() - HandledSize;
    if (SrcTy.isVector() && LeftoverBits != 0)
      return UnableToLegalize;

    if (8 * StoreMI.getMemSize() != SrcTy.getSizeInBits()) {
      Register TmpReg = MRI.createGenericVirtualRegister(NarrowTy);
      MIRBuilder.buildTrunc(TmpReg, SrcReg);
      MIRBuilder.buildStore(TmpReg, StoreMI.getPointerReg(), StoreMI.getMMO());
      StoreMI.eraseFromParent();
      return Legalized;
    }

    return reduceLoadStoreWidth(StoreMI, 0, NarrowTy);
  }
  case TargetOpcode::G_SELECT:
    return narrowScalarSelect(MI, TypeIdx, NarrowTy);
  case TargetOpcode::G_AND:
  case TargetOpcode::G_OR:
  case TargetOpcode::G_XOR: {
    // Legalize bitwise operation:
    // A = BinOp<Ty> B, C
    // into:
    // B1, ..., BN = G_UNMERGE_VALUES B
    // C1, ..., CN = G_UNMERGE_VALUES C
    // A1 = BinOp<Ty/N> B1, C1
    // ...
    // AN = BinOp<Ty/N> BN, CN
    // A = G_MERGE_VALUES A1, ..., AN
    return narrowScalarBasic(MI, TypeIdx, NarrowTy);
  }
  case TargetOpcode::G_SHL:
  case TargetOpcode::G_LSHR:
  case TargetOpcode::G_ASHR:
    return narrowScalarShift(MI, TypeIdx, NarrowTy);
  case TargetOpcode::G_CTLZ:
  case TargetOpcode::G_CTLZ_ZERO_UNDEF:
  case TargetOpcode::G_CTTZ:
  case TargetOpcode::G_CTTZ_ZERO_UNDEF:
  case TargetOpcode::G_CTPOP:
    if (TypeIdx == 1)
      switch (MI.getOpcode()) {
      case TargetOpcode::G_CTLZ:
      case TargetOpcode::G_CTLZ_ZERO_UNDEF:
        return narrowScalarCTLZ(MI, TypeIdx, NarrowTy);
      case TargetOpcode::G_CTTZ:
      case TargetOpcode::G_CTTZ_ZERO_UNDEF:
        return narrowScalarCTTZ(MI, TypeIdx, NarrowTy);
      case TargetOpcode::G_CTPOP:
        return narrowScalarCTPOP(MI, TypeIdx, NarrowTy);
      default:
        return UnableToLegalize;
      }

    Observer.changingInstr(MI);
    narrowScalarDst(MI, NarrowTy, 0, TargetOpcode::G_ZEXT);
    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_INTTOPTR:
    if (TypeIdx != 1)
      return UnableToLegalize;

    Observer.changingInstr(MI);
    narrowScalarSrc(MI, NarrowTy, 1);
    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_PTRTOINT:
    if (TypeIdx != 0)
      return UnableToLegalize;

    Observer.changingInstr(MI);
    narrowScalarDst(MI, NarrowTy, 0, TargetOpcode::G_ZEXT);
    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_PHI: {
    // FIXME: add support for when SizeOp0 isn't an exact multiple of
    // NarrowSize.
    if (SizeOp0 % NarrowSize != 0)
      return UnableToLegalize;

    unsigned NumParts = SizeOp0 / NarrowSize;
    SmallVector<Register, 2> DstRegs(NumParts);
    SmallVector<SmallVector<Register, 2>, 2> SrcRegs(MI.getNumOperands() / 2);
    Observer.changingInstr(MI);
    for (unsigned i = 1; i < MI.getNumOperands(); i += 2) {
      MachineBasicBlock &OpMBB = *MI.getOperand(i + 1).getMBB();
      MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminatorForward());
      extractParts(MI.getOperand(i).getReg(), NarrowTy, NumParts,
                   SrcRegs[i / 2]);
    }
    MachineBasicBlock &MBB = *MI.getParent();
    MIRBuilder.setInsertPt(MBB, MI);
    for (unsigned i = 0; i < NumParts; ++i) {
      DstRegs[i] = MRI.createGenericVirtualRegister(NarrowTy);
      MachineInstrBuilder MIB =
          MIRBuilder.buildInstr(TargetOpcode::G_PHI).addDef(DstRegs[i]);
      for (unsigned j = 1; j < MI.getNumOperands(); j += 2)
        MIB.addUse(SrcRegs[j / 2][i]).add(MI.getOperand(j + 1));
    }
    MIRBuilder.setInsertPt(MBB, MBB.getFirstNonPHI());
    MIRBuilder.buildMergeLikeInstr(MI.getOperand(0), DstRegs);
    Observer.changedInstr(MI);
    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_EXTRACT_VECTOR_ELT:
  case TargetOpcode::G_INSERT_VECTOR_ELT: {
    if (TypeIdx != 2)
      return UnableToLegalize;

    int OpIdx = MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3;
    Observer.changingInstr(MI);
    narrowScalarSrc(MI, NarrowTy, OpIdx);
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_ICMP: {
    Register LHS = MI.getOperand(2).getReg();
    LLT SrcTy = MRI.getType(LHS);
    uint64_t SrcSize = SrcTy.getSizeInBits();
    CmpInst::Predicate Pred =
        static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());

    // TODO: Handle the non-equality case for weird sizes.
    if (NarrowSize * 2 != SrcSize && !ICmpInst::isEquality(Pred))
      return UnableToLegalize;

    LLT LeftoverTy; // Example: s88 -> s64 (NarrowTy) + s24 (leftover)
    SmallVector<Register, 4> LHSPartRegs, LHSLeftoverRegs;
    if (!extractParts(LHS, SrcTy, NarrowTy, LeftoverTy, LHSPartRegs,
                      LHSLeftoverRegs))
      return UnableToLegalize;

    LLT Unused; // Matches LeftoverTy; G_ICMP LHS and RHS are the same type.
    SmallVector<Register, 4> RHSPartRegs, RHSLeftoverRegs;
    if (!extractParts(MI.getOperand(3).getReg(), SrcTy, NarrowTy, Unused,
                      RHSPartRegs, RHSLeftoverRegs))
      return UnableToLegalize;

    // We now have the LHS and RHS of the compare split into narrow-type
    // registers, plus potentially some leftover type.
    Register Dst = MI.getOperand(0).getReg();
    LLT ResTy = MRI.getType(Dst);
    if (ICmpInst::isEquality(Pred)) {
      // For each part on the LHS and RHS, keep track of the result of XOR-ing
      // them together. For each equal part, the result should be all 0s. For
      // each non-equal part, we'll get at least one 1.
      auto Zero = MIRBuilder.buildConstant(NarrowTy, 0);
      SmallVector<Register, 4> Xors;
      for (auto LHSAndRHS : zip(LHSPartRegs, RHSPartRegs)) {
        auto LHS = std::get<0>(LHSAndRHS);
        auto RHS = std::get<1>(LHSAndRHS);
        auto Xor = MIRBuilder.buildXor(NarrowTy, LHS, RHS).getReg(0);
        Xors.push_back(Xor);
      }

      // Build a G_XOR for each leftover register. Each G_XOR must be widened
      // to the desired narrow type so that we can OR them together later.
      SmallVector<Register, 4> WidenedXors;
      for (auto LHSAndRHS : zip(LHSLeftoverRegs, RHSLeftoverRegs)) {
        auto LHS = std::get<0>(LHSAndRHS);
        auto RHS = std::get<1>(LHSAndRHS);
        auto Xor = MIRBuilder.buildXor(LeftoverTy, LHS, RHS).getReg(0);
        LLT GCDTy = extractGCDType(WidenedXors, NarrowTy, LeftoverTy, Xor);
        buildLCMMergePieces(LeftoverTy, NarrowTy, GCDTy, WidenedXors,
                            /* PadStrategy = */ TargetOpcode::G_ZEXT);
        Xors.insert(Xors.end(), WidenedXors.begin(), WidenedXors.end());
      }

      // Now, for each part we broke up, we know if they are equal/not equal
      // based off the G_XOR. We can OR these all together and compare against
      // 0 to get the result.
      assert(Xors.size() >= 2 && "Should have gotten at least two Xors?");
      auto Or = MIRBuilder.buildOr(NarrowTy, Xors[0], Xors[1]);
      for (unsigned I = 2, E = Xors.size(); I < E; ++I)
        Or = MIRBuilder.buildOr(NarrowTy, Or, Xors[I]);
      MIRBuilder.buildICmp(Pred, Dst, Or, Zero);
    } else {
      // TODO: Handle non-power-of-two types.
      assert(LHSPartRegs.size() == 2 && "Expected exactly 2 LHS part regs?");
      assert(RHSPartRegs.size() == 2 && "Expected exactly 2 RHS part regs?");
      Register LHSL = LHSPartRegs[0];
      Register LHSH = LHSPartRegs[1];
      Register RHSL = RHSPartRegs[0];
      Register RHSH = RHSPartRegs[1];
      MachineInstrBuilder CmpH = MIRBuilder.buildICmp(Pred, ResTy, LHSH, RHSH);
      MachineInstrBuilder CmpHEQ =
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, ResTy, LHSH, RHSH);
      MachineInstrBuilder CmpLU = MIRBuilder.buildICmp(
          ICmpInst::getUnsignedPredicate(Pred), ResTy, LHSL, RHSL);
      MIRBuilder.buildSelect(Dst, CmpHEQ, CmpLU, CmpH);
    }
    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_SEXT_INREG: {
    if (TypeIdx != 0)
      return UnableToLegalize;

    int64_t SizeInBits = MI.getOperand(2).getImm();

    // So long as the new type has more bits than the bits we're extending, we
    // don't need to break it apart.
    if (NarrowTy.getScalarSizeInBits() >= SizeInBits) {
      Observer.changingInstr(MI);
      // We don't lose any non-extension bits by truncating the src and
      // sign-extending the dst.
      MachineOperand &MO1 = MI.getOperand(1);
      auto TruncMIB = MIRBuilder.buildTrunc(NarrowTy, MO1);
      MO1.setReg(TruncMIB.getReg(0));

      MachineOperand &MO2 = MI.getOperand(0);
      Register DstExt = MRI.createGenericVirtualRegister(NarrowTy);
      MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
      MIRBuilder.buildSExt(MO2, DstExt);
      MO2.setReg(DstExt);
      Observer.changedInstr(MI);
      return Legalized;
    }

    // Break it apart. Components below the extension point are unmodified. The
    // component containing the extension point becomes a narrower SEXT_INREG.
    // Components above it are ashr'd from the component containing the
    // extension point.
    if (SizeOp0 % NarrowSize != 0)
      return UnableToLegalize;
    int NumParts = SizeOp0 / NarrowSize;

    // List the registers where the destination will be scattered.
    SmallVector<Register, 2> DstRegs;
    // List the registers where the source will be split.
    SmallVector<Register, 2> SrcRegs;

    // Create all the temporary registers.
    for (int i = 0; i < NumParts; ++i) {
      Register SrcReg = MRI.createGenericVirtualRegister(NarrowTy);

      SrcRegs.push_back(SrcReg);
    }

    // Explode the big arguments into smaller chunks.
    MIRBuilder.buildUnmerge(SrcRegs, MI.getOperand(1));

    Register AshrCstReg =
        MIRBuilder.buildConstant(NarrowTy, NarrowTy.getScalarSizeInBits() - 1)
            .getReg(0);
    Register FullExtensionReg = 0;
    Register PartialExtensionReg = 0;

    // Do the operation on each small part.
    for (int i = 0; i < NumParts; ++i) {
      if ((i + 1) * NarrowTy.getScalarSizeInBits() < SizeInBits)
        DstRegs.push_back(SrcRegs[i]);
      else if (i * NarrowTy.getScalarSizeInBits() > SizeInBits) {
        assert(PartialExtensionReg &&
               "Expected to visit partial extension before full");
        if (FullExtensionReg) {
          DstRegs.push_back(FullExtensionReg);
          continue;
        }
        DstRegs.push_back(
            MIRBuilder.buildAShr(NarrowTy, PartialExtensionReg, AshrCstReg)
                .getReg(0));
        FullExtensionReg = DstRegs.back();
      } else {
        DstRegs.push_back(
            MIRBuilder
                .buildInstr(
                    TargetOpcode::G_SEXT_INREG, {NarrowTy},
                    {SrcRegs[i], SizeInBits % NarrowTy.getScalarSizeInBits()})
                .getReg(0));
        PartialExtensionReg = DstRegs.back();
      }
    }

    // Gather the destination registers into the final destination.
    Register DstReg = MI.getOperand(0).getReg();
    MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
    MI.eraseFromParent();
    return Legalized;
  }
1373   case TargetOpcode::G_BSWAP:
1374   case TargetOpcode::G_BITREVERSE: {
1375     if (SizeOp0 % NarrowSize != 0)
1376       return UnableToLegalize;
1377 
1378     Observer.changingInstr(MI);
1379     SmallVector<Register, 2> SrcRegs, DstRegs;
1380     unsigned NumParts = SizeOp0 / NarrowSize;
1381     extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs);
1382 
1383     for (unsigned i = 0; i < NumParts; ++i) {
1384       auto DstPart = MIRBuilder.buildInstr(MI.getOpcode(), {NarrowTy},
1385                                            {SrcRegs[NumParts - 1 - i]});
1386       DstRegs.push_back(DstPart.getReg(0));
1387     }
1388 
1389     MIRBuilder.buildMergeLikeInstr(MI.getOperand(0), DstRegs);
1390 
1391     Observer.changedInstr(MI);
1392     MI.eraseFromParent();
1393     return Legalized;
1394   }
1395   case TargetOpcode::G_PTR_ADD:
1396   case TargetOpcode::G_PTRMASK: {
1397     if (TypeIdx != 1)
1398       return UnableToLegalize;
1399     Observer.changingInstr(MI);
1400     narrowScalarSrc(MI, NarrowTy, 2);
1401     Observer.changedInstr(MI);
1402     return Legalized;
1403   }
1404   case TargetOpcode::G_FPTOUI:
1405   case TargetOpcode::G_FPTOSI:
1406     return narrowScalarFPTOI(MI, TypeIdx, NarrowTy);
1407   case TargetOpcode::G_FPEXT:
1408     if (TypeIdx != 0)
1409       return UnableToLegalize;
1410     Observer.changingInstr(MI);
1411     narrowScalarDst(MI, NarrowTy, 0, TargetOpcode::G_FPEXT);
1412     Observer.changedInstr(MI);
1413     return Legalized;
1414   }
1415 }
1416 
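/// Coerce \p Val to an equivalently sized scalar: pointers are converted with
/// G_PTRTOINT (failing with an invalid Register for non-integral address
/// spaces) and vectors are converted with G_BITCAST, going through G_PTRTOINT
/// first if the element type is a pointer.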
Register LegalizerHelper::coerceToScalar(Register Val) {
  LLT Ty = MRI.getType(Val);
  if (Ty.isScalar())
    return Val;

  const DataLayout &DL = MIRBuilder.getDataLayout();
  LLT NewTy = LLT::scalar(Ty.getSizeInBits());
  if (Ty.isPointer()) {
    if (DL.isNonIntegralAddressSpace(Ty.getAddressSpace()))
      return Register();
    return MIRBuilder.buildPtrToInt(NewTy, Val).getReg(0);
  }

  Register NewVal = Val;

  assert(Ty.isVector());
  LLT EltTy = Ty.getElementType();
  if (EltTy.isPointer())
    NewVal = MIRBuilder.buildPtrToInt(NewTy, NewVal).getReg(0);
  return MIRBuilder.buildBitcast(NewTy, NewVal).getReg(0);
}

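/// For example (a sketch; G_FOO and the register names are hypothetical),
/// widenScalarSrc(MI, s32, 1, G_ANYEXT) on
///   %dst:_(s16) = G_FOO %src:_(s16)
/// rewrites the use operand in place:
///   %ext:_(s32) = G_ANYEXT %src:_(s16)
///   %dst:_(s16) = G_FOO %ext:_(s32)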
void LegalizerHelper::widenScalarSrc(MachineInstr &MI, LLT WideTy,
                                     unsigned OpIdx, unsigned ExtOpcode) {
  MachineOperand &MO = MI.getOperand(OpIdx);
  auto ExtB = MIRBuilder.buildInstr(ExtOpcode, {WideTy}, {MO});
  MO.setReg(ExtB.getReg(0));
}

void LegalizerHelper::narrowScalarSrc(MachineInstr &MI, LLT NarrowTy,
                                      unsigned OpIdx) {
  MachineOperand &MO = MI.getOperand(OpIdx);
  auto ExtB = MIRBuilder.buildTrunc(NarrowTy, MO);
  MO.setReg(ExtB.getReg(0));
}

void LegalizerHelper::widenScalarDst(MachineInstr &MI, LLT WideTy,
                                     unsigned OpIdx, unsigned TruncOpcode) {
  MachineOperand &MO = MI.getOperand(OpIdx);
  Register DstExt = MRI.createGenericVirtualRegister(WideTy);
  MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
  MIRBuilder.buildInstr(TruncOpcode, {MO}, {DstExt});
  MO.setReg(DstExt);
}

void LegalizerHelper::narrowScalarDst(MachineInstr &MI, LLT NarrowTy,
                                      unsigned OpIdx, unsigned ExtOpcode) {
  MachineOperand &MO = MI.getOperand(OpIdx);
  Register DstTrunc = MRI.createGenericVirtualRegister(NarrowTy);
  MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
  MIRBuilder.buildInstr(ExtOpcode, {MO}, {DstTrunc});
  MO.setReg(DstTrunc);
}

void LegalizerHelper::moreElementsVectorDst(MachineInstr &MI, LLT WideTy,
                                            unsigned OpIdx) {
  MachineOperand &MO = MI.getOperand(OpIdx);
  MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
  Register Dst = MO.getReg();
  Register DstExt = MRI.createGenericVirtualRegister(WideTy);
  MO.setReg(DstExt);
  MIRBuilder.buildDeleteTrailingVectorElements(Dst, DstExt);
}

void LegalizerHelper::moreElementsVectorSrc(MachineInstr &MI, LLT MoreTy,
                                            unsigned OpIdx) {
  MachineOperand &MO = MI.getOperand(OpIdx);
  MO.setReg(MIRBuilder.buildPadVectorWithUndefElements(MoreTy, MO).getReg(0));
}

void LegalizerHelper::bitcastSrc(MachineInstr &MI, LLT CastTy, unsigned OpIdx) {
  MachineOperand &Op = MI.getOperand(OpIdx);
  Op.setReg(MIRBuilder.buildBitcast(CastTy, Op).getReg(0));
}

void LegalizerHelper::bitcastDst(MachineInstr &MI, LLT CastTy, unsigned OpIdx) {
  MachineOperand &MO = MI.getOperand(OpIdx);
  Register CastDst = MRI.createGenericVirtualRegister(CastTy);
  MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());
  MIRBuilder.buildBitcast(MO, CastDst);
  MO.setReg(CastDst);
}

LegalizerHelper::LegalizeResult
LegalizerHelper::widenScalarMergeValues(MachineInstr &MI, unsigned TypeIdx,
                                        LLT WideTy) {
  if (TypeIdx != 1)
    return UnableToLegalize;

  Register DstReg = MI.getOperand(0).getReg();
  LLT DstTy = MRI.getType(DstReg);
  if (DstTy.isVector())
    return UnableToLegalize;

  Register Src1 = MI.getOperand(1).getReg();
  LLT SrcTy = MRI.getType(Src1);
  const int DstSize = DstTy.getSizeInBits();
  const int SrcSize = SrcTy.getSizeInBits();
  const int WideSize = WideTy.getSizeInBits();
  const int NumMerge = (DstSize + WideSize - 1) / WideSize;

  unsigned NumOps = MI.getNumOperands();
  unsigned NumSrc = MI.getNumOperands() - 1;
  unsigned PartSize = DstTy.getSizeInBits() / NumSrc;

  if (WideSize >= DstSize) {
    // Directly pack the bits in the target type.
    Register ResultReg = MIRBuilder.buildZExt(WideTy, Src1).getReg(0);

    for (unsigned I = 2; I != NumOps; ++I) {
      const unsigned Offset = (I - 1) * PartSize;

      Register SrcReg = MI.getOperand(I).getReg();
      assert(MRI.getType(SrcReg) == LLT::scalar(PartSize));

      auto ZextInput = MIRBuilder.buildZExt(WideTy, SrcReg);

      Register NextResult = I + 1 == NumOps && WideTy == DstTy ? DstReg :
        MRI.createGenericVirtualRegister(WideTy);

      auto ShiftAmt = MIRBuilder.buildConstant(WideTy, Offset);
      auto Shl = MIRBuilder.buildShl(WideTy, ZextInput, ShiftAmt);
      MIRBuilder.buildOr(NextResult, ResultReg, Shl);
      ResultReg = NextResult;
    }

    if (WideSize > DstSize)
      MIRBuilder.buildTrunc(DstReg, ResultReg);
    else if (DstTy.isPointer())
      MIRBuilder.buildIntToPtr(DstReg, ResultReg);

    MI.eraseFromParent();
    return Legalized;
  }

  // Unmerge the original values to the GCD type, and recombine to the next
  // multiple greater than the original type.
  //
  // %3:_(s12) = G_MERGE_VALUES %0:_(s4), %1:_(s4), %2:_(s4) -> s6
  // %4:_(s2), %5:_(s2) = G_UNMERGE_VALUES %0
  // %6:_(s2), %7:_(s2) = G_UNMERGE_VALUES %1
  // %8:_(s2), %9:_(s2) = G_UNMERGE_VALUES %2
  // %10:_(s6) = G_MERGE_VALUES %4, %5, %6
  // %11:_(s6) = G_MERGE_VALUES %7, %8, %9
  // %12:_(s12) = G_MERGE_VALUES %10, %11
  //
  // Padding with undef if necessary:
  //
  // %2:_(s8) = G_MERGE_VALUES %0:_(s4), %1:_(s4) -> s6
  // %3:_(s2), %4:_(s2) = G_UNMERGE_VALUES %0
  // %5:_(s2), %6:_(s2) = G_UNMERGE_VALUES %1
  // %7:_(s2) = G_IMPLICIT_DEF
  // %8:_(s6) = G_MERGE_VALUES %3, %4, %5
  // %9:_(s6) = G_MERGE_VALUES %6, %7, %7
  // %10:_(s12) = G_MERGE_VALUES %8, %9

  const int GCD = std::gcd(SrcSize, WideSize);
  LLT GCDTy = LLT::scalar(GCD);

  SmallVector<Register, 8> NewMergeRegs;
  SmallVector<Register, 8> Unmerges;
  LLT WideDstTy = LLT::scalar(NumMerge * WideSize);

  // Decompose the original operands if they don't evenly divide.
  for (const MachineOperand &MO : llvm::drop_begin(MI.operands())) {
    Register SrcReg = MO.getReg();
    if (GCD == SrcSize) {
      Unmerges.push_back(SrcReg);
    } else {
      auto Unmerge = MIRBuilder.buildUnmerge(GCDTy, SrcReg);
      for (int J = 0, JE = Unmerge->getNumOperands() - 1; J != JE; ++J)
        Unmerges.push_back(Unmerge.getReg(J));
    }
  }

  // Pad with undef to the next size that is a multiple of the requested size.
  if (static_cast<int>(Unmerges.size()) != NumMerge * WideSize) {
    Register UndefReg = MIRBuilder.buildUndef(GCDTy).getReg(0);
    for (int I = Unmerges.size(); I != NumMerge * WideSize; ++I)
      Unmerges.push_back(UndefReg);
  }

  const int PartsPerGCD = WideSize / GCD;

  // Build merges of each piece.
  ArrayRef<Register> Slicer(Unmerges);
  for (int I = 0; I != NumMerge; ++I, Slicer = Slicer.drop_front(PartsPerGCD)) {
    auto Merge =
        MIRBuilder.buildMergeLikeInstr(WideTy, Slicer.take_front(PartsPerGCD));
    NewMergeRegs.push_back(Merge.getReg(0));
  }

  // A truncate may be necessary if the requested type doesn't evenly divide the
  // original result type.
  if (DstTy.getSizeInBits() == WideDstTy.getSizeInBits()) {
    MIRBuilder.buildMergeLikeInstr(DstReg, NewMergeRegs);
  } else {
    auto FinalMerge = MIRBuilder.buildMergeLikeInstr(WideDstTy, NewMergeRegs);
    MIRBuilder.buildTrunc(DstReg, FinalMerge.getReg(0));
  }

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::widenScalarUnmergeValues(MachineInstr &MI, unsigned TypeIdx,
                                          LLT WideTy) {
  if (TypeIdx != 0)
    return UnableToLegalize;

  int NumDst = MI.getNumOperands() - 1;
  Register SrcReg = MI.getOperand(NumDst).getReg();
  LLT SrcTy = MRI.getType(SrcReg);
  if (SrcTy.isVector())
    return UnableToLegalize;

  Register Dst0Reg = MI.getOperand(0).getReg();
  LLT DstTy = MRI.getType(Dst0Reg);
  if (!DstTy.isScalar())
    return UnableToLegalize;

  if (WideTy.getSizeInBits() >= SrcTy.getSizeInBits()) {
    if (SrcTy.isPointer()) {
      const DataLayout &DL = MIRBuilder.getDataLayout();
      if (DL.isNonIntegralAddressSpace(SrcTy.getAddressSpace())) {
        LLVM_DEBUG(
            dbgs() << "Not casting non-integral address space integer\n");
        return UnableToLegalize;
      }

      SrcTy = LLT::scalar(SrcTy.getSizeInBits());
      SrcReg = MIRBuilder.buildPtrToInt(SrcTy, SrcReg).getReg(0);
    }

    // Widen SrcTy to WideTy. This does not affect the result, but since the
    // user requested this size, it is probably better handled than SrcTy and
    // should reduce the total number of legalization artifacts.
    if (WideTy.getSizeInBits() > SrcTy.getSizeInBits()) {
      SrcTy = WideTy;
      SrcReg = MIRBuilder.buildAnyExt(WideTy, SrcReg).getReg(0);
    }

    // There's no unmerge type to target. Directly extract the bits from the
    // source type.
    unsigned DstSize = DstTy.getSizeInBits();

    MIRBuilder.buildTrunc(Dst0Reg, SrcReg);
    for (int I = 1; I != NumDst; ++I) {
      auto ShiftAmt = MIRBuilder.buildConstant(SrcTy, DstSize * I);
      auto Shr = MIRBuilder.buildLShr(SrcTy, SrcReg, ShiftAmt);
      MIRBuilder.buildTrunc(MI.getOperand(I), Shr);
    }

    MI.eraseFromParent();
    return Legalized;
  }

  // Extend the source to a wider type.
  LLT LCMTy = getLCMType(SrcTy, WideTy);

  Register WideSrc = SrcReg;
  if (LCMTy.getSizeInBits() != SrcTy.getSizeInBits()) {
    // TODO: If this is an integral address space, cast to integer and anyext.
    if (SrcTy.isPointer()) {
      LLVM_DEBUG(dbgs() << "Widening pointer source types not implemented\n");
      return UnableToLegalize;
    }

    WideSrc = MIRBuilder.buildAnyExt(LCMTy, WideSrc).getReg(0);
  }

  auto Unmerge = MIRBuilder.buildUnmerge(WideTy, WideSrc);

  // Create a sequence of unmerges and merges to the original results. Since we
  // may have widened the source, we will need to pad the results with dead defs
  // to cover the source register.
  // e.g. widen s48 to s64:
  // %1:_(s48), %2:_(s48) = G_UNMERGE_VALUES %0:_(s96)
  //
  // =>
  //  %4:_(s192) = G_ANYEXT %0:_(s96)
  //  %5:_(s64), %6, %7 = G_UNMERGE_VALUES %4 ; Requested unmerge
  //  ; unpack to GCD type, with extra dead defs
  //  %8:_(s16), %9, %10, %11 = G_UNMERGE_VALUES %5:_(s64)
  //  %12:_(s16), %13, dead %14, dead %15 = G_UNMERGE_VALUES %6:_(s64)
  //  dead %16:_(s16), dead %17, dead %18, dead %19 = G_UNMERGE_VALUES %7:_(s64)
  //  %1:_(s48) = G_MERGE_VALUES %8:_(s16), %9, %10   ; Remerge to destination
  //  %2:_(s48) = G_MERGE_VALUES %11:_(s16), %12, %13 ; Remerge to destination
  const LLT GCDTy = getGCDType(WideTy, DstTy);
  const int NumUnmerge = Unmerge->getNumOperands() - 1;
  const int PartsPerRemerge = DstTy.getSizeInBits() / GCDTy.getSizeInBits();

  // Directly unmerge to the destination without going through a GCD type
  // if possible.
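  // E.g. (an illustrative sketch) four s16 results unmerged from an s64 source
  // with WideTy = s32 take this direct path:
  //   %a:_(s32), %b:_(s32) = G_UNMERGE_VALUES %src:_(s64)
  //   %d0:_(s16), %d1:_(s16) = G_UNMERGE_VALUES %a
  //   %d2:_(s16), %d3:_(s16) = G_UNMERGE_VALUES %b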
  if (PartsPerRemerge == 1) {
    const int PartsPerUnmerge = WideTy.getSizeInBits() / DstTy.getSizeInBits();

    for (int I = 0; I != NumUnmerge; ++I) {
      auto MIB = MIRBuilder.buildInstr(TargetOpcode::G_UNMERGE_VALUES);

      for (int J = 0; J != PartsPerUnmerge; ++J) {
        int Idx = I * PartsPerUnmerge + J;
        if (Idx < NumDst)
          MIB.addDef(MI.getOperand(Idx).getReg());
        else {
          // Create dead def for excess components.
          MIB.addDef(MRI.createGenericVirtualRegister(DstTy));
        }
      }

      MIB.addUse(Unmerge.getReg(I));
    }
  } else {
    SmallVector<Register, 16> Parts;
    for (int J = 0; J != NumUnmerge; ++J)
      extractGCDType(Parts, GCDTy, Unmerge.getReg(J));

    SmallVector<Register, 8> RemergeParts;
    for (int I = 0; I != NumDst; ++I) {
      for (int J = 0; J < PartsPerRemerge; ++J) {
        const int Idx = I * PartsPerRemerge + J;
        RemergeParts.emplace_back(Parts[Idx]);
      }

      MIRBuilder.buildMergeLikeInstr(MI.getOperand(I).getReg(), RemergeParts);
      RemergeParts.clear();
    }
  }

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::widenScalarExtract(MachineInstr &MI, unsigned TypeIdx,
                                    LLT WideTy) {
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  LLT SrcTy = MRI.getType(SrcReg);

  LLT DstTy = MRI.getType(DstReg);
  unsigned Offset = MI.getOperand(2).getImm();

  if (TypeIdx == 0) {
    if (SrcTy.isVector() || DstTy.isVector())
      return UnableToLegalize;

    SrcOp Src(SrcReg);
    if (SrcTy.isPointer()) {
      // Extracts from pointers can be handled only if they are really just
      // simple integers.
      const DataLayout &DL = MIRBuilder.getDataLayout();
      if (DL.isNonIntegralAddressSpace(SrcTy.getAddressSpace()))
        return UnableToLegalize;

      LLT SrcAsIntTy = LLT::scalar(SrcTy.getSizeInBits());
      Src = MIRBuilder.buildPtrToInt(SrcAsIntTy, Src);
      SrcTy = SrcAsIntTy;
    }

    if (DstTy.isPointer())
      return UnableToLegalize;

    if (Offset == 0) {
      // Avoid a shift in the degenerate case.
      MIRBuilder.buildTrunc(DstReg,
                            MIRBuilder.buildAnyExtOrTrunc(WideTy, Src));
      MI.eraseFromParent();
      return Legalized;
    }

    // Do a shift in the source type.
    LLT ShiftTy = SrcTy;
    if (WideTy.getSizeInBits() > SrcTy.getSizeInBits()) {
      Src = MIRBuilder.buildAnyExt(WideTy, Src);
      ShiftTy = WideTy;
    }

    auto LShr = MIRBuilder.buildLShr(
      ShiftTy, Src, MIRBuilder.buildConstant(ShiftTy, Offset));
    MIRBuilder.buildTrunc(DstReg, LShr);
    MI.eraseFromParent();
    return Legalized;
  }

  if (SrcTy.isScalar()) {
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
    Observer.changedInstr(MI);
    return Legalized;
  }

  if (!SrcTy.isVector())
    return UnableToLegalize;

  if (DstTy != SrcTy.getElementType())
    return UnableToLegalize;

  if (Offset % SrcTy.getScalarSizeInBits() != 0)
    return UnableToLegalize;

  Observer.changingInstr(MI);
  widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);

  MI.getOperand(2).setImm((WideTy.getSizeInBits() / SrcTy.getSizeInBits()) *
                          Offset);
  widenScalarDst(MI, WideTy.getScalarType(), 0);
  Observer.changedInstr(MI);
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::widenScalarInsert(MachineInstr &MI, unsigned TypeIdx,
                                   LLT WideTy) {
  if (TypeIdx != 0 || WideTy.isVector())
    return UnableToLegalize;
  Observer.changingInstr(MI);
  widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
  widenScalarDst(MI, WideTy);
  Observer.changedInstr(MI);
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::widenScalarAddSubOverflow(MachineInstr &MI, unsigned TypeIdx,
                                           LLT WideTy) {
  unsigned Opcode;
  unsigned ExtOpcode;
  std::optional<Register> CarryIn;
  switch (MI.getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case TargetOpcode::G_SADDO:
    Opcode = TargetOpcode::G_ADD;
    ExtOpcode = TargetOpcode::G_SEXT;
    break;
  case TargetOpcode::G_SSUBO:
    Opcode = TargetOpcode::G_SUB;
    ExtOpcode = TargetOpcode::G_SEXT;
    break;
  case TargetOpcode::G_UADDO:
    Opcode = TargetOpcode::G_ADD;
    ExtOpcode = TargetOpcode::G_ZEXT;
    break;
  case TargetOpcode::G_USUBO:
    Opcode = TargetOpcode::G_SUB;
    ExtOpcode = TargetOpcode::G_ZEXT;
    break;
  case TargetOpcode::G_SADDE:
    Opcode = TargetOpcode::G_UADDE;
    ExtOpcode = TargetOpcode::G_SEXT;
    CarryIn = MI.getOperand(4).getReg();
    break;
  case TargetOpcode::G_SSUBE:
    Opcode = TargetOpcode::G_USUBE;
    ExtOpcode = TargetOpcode::G_SEXT;
    CarryIn = MI.getOperand(4).getReg();
    break;
  case TargetOpcode::G_UADDE:
    Opcode = TargetOpcode::G_UADDE;
    ExtOpcode = TargetOpcode::G_ZEXT;
    CarryIn = MI.getOperand(4).getReg();
    break;
  case TargetOpcode::G_USUBE:
    Opcode = TargetOpcode::G_USUBE;
    ExtOpcode = TargetOpcode::G_ZEXT;
    CarryIn = MI.getOperand(4).getReg();
    break;
  }

  if (TypeIdx == 1) {
    unsigned BoolExtOp = MIRBuilder.getBoolExtOp(WideTy.isVector(), false);

    Observer.changingInstr(MI);
    if (CarryIn)
      widenScalarSrc(MI, WideTy, 4, BoolExtOp);
    widenScalarDst(MI, WideTy, 1);

    Observer.changedInstr(MI);
    return Legalized;
  }

  auto LHSExt = MIRBuilder.buildInstr(ExtOpcode, {WideTy}, {MI.getOperand(2)});
  auto RHSExt = MIRBuilder.buildInstr(ExtOpcode, {WideTy}, {MI.getOperand(3)});
  // Do the arithmetic in the larger type.
  Register NewOp;
  if (CarryIn) {
    LLT CarryOutTy = MRI.getType(MI.getOperand(1).getReg());
    NewOp = MIRBuilder
                .buildInstr(Opcode, {WideTy, CarryOutTy},
                            {LHSExt, RHSExt, *CarryIn})
                .getReg(0);
  } else {
    NewOp = MIRBuilder.buildInstr(Opcode, {WideTy}, {LHSExt, RHSExt}).getReg(0);
  }
  LLT OrigTy = MRI.getType(MI.getOperand(0).getReg());
  auto TruncOp = MIRBuilder.buildTrunc(OrigTy, NewOp);
  auto ExtOp = MIRBuilder.buildInstr(ExtOpcode, {WideTy}, {TruncOp});
  // There is no overflow if the ExtOp is the same as NewOp.
  MIRBuilder.buildICmp(CmpInst::ICMP_NE, MI.getOperand(1), NewOp, ExtOp);
  // Now trunc the NewOp to the original result.
  MIRBuilder.buildTrunc(MI.getOperand(0), NewOp);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::widenScalarAddSubShlSat(MachineInstr &MI, unsigned TypeIdx,
                                         LLT WideTy) {
  bool IsSigned = MI.getOpcode() == TargetOpcode::G_SADDSAT ||
                  MI.getOpcode() == TargetOpcode::G_SSUBSAT ||
                  MI.getOpcode() == TargetOpcode::G_SSHLSAT;
  bool IsShift = MI.getOpcode() == TargetOpcode::G_SSHLSAT ||
                 MI.getOpcode() == TargetOpcode::G_USHLSAT;
  // We can convert this to:
  //   1. Any extend iN to iM
  //   2. SHL by M-N
  //   3. [US][ADD|SUB|SHL]SAT
  //   4. L/ASHR by M-N
  //
  // It may be more efficient to lower this to a min and a max operation in
  // the higher precision arithmetic if the promoted operation isn't legal,
  // but this decision is up to the target's lowering request.
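  // A sketch for G_SADDSAT widened from s8 to s32 (illustrative names):
  //   %a:_(s32) = G_ANYEXT %x:_(s8)
  //   %b:_(s32) = G_ANYEXT %y:_(s8)
  //   %k:_(s32) = G_CONSTANT i32 24
  //   %la:_(s32) = G_SHL %a, %k
  //   %lb:_(s32) = G_SHL %b, %k
  //   %sat:_(s32) = G_SADDSAT %la, %lb
  //   %shr:_(s32) = G_ASHR %sat, %k
  //   %dst:_(s8) = G_TRUNC %shr
  // Shifting both inputs up by 24 makes the s32 saturation point coincide
  // with the s8 one.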
  Register DstReg = MI.getOperand(0).getReg();

  unsigned NewBits = WideTy.getScalarSizeInBits();
  unsigned SHLAmount = NewBits - MRI.getType(DstReg).getScalarSizeInBits();

  // Shifts must zero-extend the RHS to preserve the unsigned quantity, and
  // must not left shift the RHS to preserve the shift amount.
  auto LHS = MIRBuilder.buildAnyExt(WideTy, MI.getOperand(1));
  auto RHS = IsShift ? MIRBuilder.buildZExt(WideTy, MI.getOperand(2))
                     : MIRBuilder.buildAnyExt(WideTy, MI.getOperand(2));
  auto ShiftK = MIRBuilder.buildConstant(WideTy, SHLAmount);
  auto ShiftL = MIRBuilder.buildShl(WideTy, LHS, ShiftK);
  auto ShiftR = IsShift ? RHS : MIRBuilder.buildShl(WideTy, RHS, ShiftK);

  auto WideInst = MIRBuilder.buildInstr(MI.getOpcode(), {WideTy},
                                        {ShiftL, ShiftR}, MI.getFlags());

  // Use a shift that will preserve the number of sign bits when the trunc is
  // folded away.
  auto Result = IsSigned ? MIRBuilder.buildAShr(WideTy, WideInst, ShiftK)
                         : MIRBuilder.buildLShr(WideTy, WideInst, ShiftK);

  MIRBuilder.buildTrunc(DstReg, Result);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::widenScalarMulo(MachineInstr &MI, unsigned TypeIdx,
                                 LLT WideTy) {
  if (TypeIdx == 1) {
    Observer.changingInstr(MI);
    widenScalarDst(MI, WideTy, 1);
    Observer.changedInstr(MI);
    return Legalized;
  }

  bool IsSigned = MI.getOpcode() == TargetOpcode::G_SMULO;
  Register Result = MI.getOperand(0).getReg();
  Register OriginalOverflow = MI.getOperand(1).getReg();
  Register LHS = MI.getOperand(2).getReg();
  Register RHS = MI.getOperand(3).getReg();
  LLT SrcTy = MRI.getType(LHS);
  LLT OverflowTy = MRI.getType(OriginalOverflow);
  unsigned SrcBitWidth = SrcTy.getScalarSizeInBits();

  // To determine if the result overflowed in the larger type, we extend the
  // input to the larger type, do the multiply (checking if it overflows),
  // then also check the high bits of the result to see if overflow happened
  // there.
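  // A sketch for G_UMULO widened from s8 to s16 (illustrative names):
  //   %a:_(s16) = G_ZEXT %x:_(s8)
  //   %b:_(s16) = G_ZEXT %y:_(s8)
  //   %mul:_(s16), %wideov:_(s1) = G_UMULO %a, %b
  //   %res:_(s8) = G_TRUNC %mul
  //   %low:_(s16) = G_AND %mul, 255    ; zero-extend-in-register of the low 8
  //   %ovf:_(s1) = G_ICMP intpred(ne), %mul, %low
  // Since s16 is exactly twice s8 here, %wideov is always false and only the
  // high-bits check is needed.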
  unsigned ExtOp = IsSigned ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
  auto LeftOperand = MIRBuilder.buildInstr(ExtOp, {WideTy}, {LHS});
  auto RightOperand = MIRBuilder.buildInstr(ExtOp, {WideTy}, {RHS});

  auto Mulo = MIRBuilder.buildInstr(MI.getOpcode(), {WideTy, OverflowTy},
                                    {LeftOperand, RightOperand});
  auto Mul = Mulo->getOperand(0);
  MIRBuilder.buildTrunc(Result, Mul);

  MachineInstrBuilder ExtResult;
  // Overflow occurred if it occurred in the larger type, or if the high part
  // of the result does not zero/sign-extend the low part.  Check this second
  // possibility first.
  if (IsSigned) {
    // For signed, overflow occurred when the high part does not sign-extend
    // the low part.
    ExtResult = MIRBuilder.buildSExtInReg(WideTy, Mul, SrcBitWidth);
  } else {
    // Unsigned overflow occurred when the high part does not zero-extend the
    // low part.
    ExtResult = MIRBuilder.buildZExtInReg(WideTy, Mul, SrcBitWidth);
  }

  // Multiplication cannot overflow if the WideTy is >= 2 * original width,
  // so we don't need to check the overflow result of larger type Mulo.
  if (WideTy.getScalarSizeInBits() < 2 * SrcBitWidth) {
    auto Overflow =
        MIRBuilder.buildICmp(CmpInst::ICMP_NE, OverflowTy, Mul, ExtResult);
    // Finally check if the multiplication in the larger type itself overflowed.
    MIRBuilder.buildOr(OriginalOverflow, Mulo->getOperand(1), Overflow);
  } else {
    MIRBuilder.buildICmp(CmpInst::ICMP_NE, OriginalOverflow, Mul, ExtResult);
  }
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
  switch (MI.getOpcode()) {
  default:
    return UnableToLegalize;
  case TargetOpcode::G_ATOMICRMW_XCHG:
  case TargetOpcode::G_ATOMICRMW_ADD:
  case TargetOpcode::G_ATOMICRMW_SUB:
  case TargetOpcode::G_ATOMICRMW_AND:
  case TargetOpcode::G_ATOMICRMW_OR:
  case TargetOpcode::G_ATOMICRMW_XOR:
  case TargetOpcode::G_ATOMICRMW_MIN:
  case TargetOpcode::G_ATOMICRMW_MAX:
  case TargetOpcode::G_ATOMICRMW_UMIN:
  case TargetOpcode::G_ATOMICRMW_UMAX:
    assert(TypeIdx == 0 && "atomicrmw with second scalar type");
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
    widenScalarDst(MI, WideTy, 0);
    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_ATOMIC_CMPXCHG:
    assert(TypeIdx == 0 && "G_ATOMIC_CMPXCHG with second scalar type");
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
    widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ANYEXT);
    widenScalarDst(MI, WideTy, 0);
    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS:
    if (TypeIdx == 0) {
      Observer.changingInstr(MI);
      widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ANYEXT);
      widenScalarSrc(MI, WideTy, 4, TargetOpcode::G_ANYEXT);
      widenScalarDst(MI, WideTy, 0);
      Observer.changedInstr(MI);
      return Legalized;
    }
    assert(TypeIdx == 1 &&
           "G_ATOMIC_CMPXCHG_WITH_SUCCESS with third scalar type");
    Observer.changingInstr(MI);
    widenScalarDst(MI, WideTy, 1);
    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_EXTRACT:
    return widenScalarExtract(MI, TypeIdx, WideTy);
  case TargetOpcode::G_INSERT:
    return widenScalarInsert(MI, TypeIdx, WideTy);
  case TargetOpcode::G_MERGE_VALUES:
    return widenScalarMergeValues(MI, TypeIdx, WideTy);
  case TargetOpcode::G_UNMERGE_VALUES:
    return widenScalarUnmergeValues(MI, TypeIdx, WideTy);
  case TargetOpcode::G_SADDO:
  case TargetOpcode::G_SSUBO:
  case TargetOpcode::G_UADDO:
  case TargetOpcode::G_USUBO:
  case TargetOpcode::G_SADDE:
  case TargetOpcode::G_SSUBE:
  case TargetOpcode::G_UADDE:
  case TargetOpcode::G_USUBE:
    return widenScalarAddSubOverflow(MI, TypeIdx, WideTy);
  case TargetOpcode::G_UMULO:
  case TargetOpcode::G_SMULO:
    return widenScalarMulo(MI, TypeIdx, WideTy);
  case TargetOpcode::G_SADDSAT:
  case TargetOpcode::G_SSUBSAT:
  case TargetOpcode::G_SSHLSAT:
  case TargetOpcode::G_UADDSAT:
  case TargetOpcode::G_USUBSAT:
  case TargetOpcode::G_USHLSAT:
    return widenScalarAddSubShlSat(MI, TypeIdx, WideTy);
  case TargetOpcode::G_CTTZ:
  case TargetOpcode::G_CTTZ_ZERO_UNDEF:
  case TargetOpcode::G_CTLZ:
  case TargetOpcode::G_CTLZ_ZERO_UNDEF:
  case TargetOpcode::G_CTPOP: {
    if (TypeIdx == 0) {
      Observer.changingInstr(MI);
      widenScalarDst(MI, WideTy, 0);
      Observer.changedInstr(MI);
      return Legalized;
    }

    Register SrcReg = MI.getOperand(1).getReg();

    // First extend the input.
    unsigned ExtOpc = MI.getOpcode() == TargetOpcode::G_CTTZ ||
                              MI.getOpcode() == TargetOpcode::G_CTTZ_ZERO_UNDEF
                          ? TargetOpcode::G_ANYEXT
                          : TargetOpcode::G_ZEXT;
    auto MIBSrc = MIRBuilder.buildInstr(ExtOpc, {WideTy}, {SrcReg});
    LLT CurTy = MRI.getType(SrcReg);
    unsigned NewOpc = MI.getOpcode();
    if (NewOpc == TargetOpcode::G_CTTZ) {
      // The count is the same in the larger type except if the original
      // value was zero.  This can be handled by setting the bit just off
      // the top of the original type.
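      // E.g. for s8 zero/any-extended to s32 (a sketch): OR the source with
      // 0x100 so that a zero input yields cttz == 8, matching G_CTTZ on the
      // original s8.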
      auto TopBit =
          APInt::getOneBitSet(WideTy.getSizeInBits(), CurTy.getSizeInBits());
      MIBSrc = MIRBuilder.buildOr(
        WideTy, MIBSrc, MIRBuilder.buildConstant(WideTy, TopBit));
      // Now we know the operand is non-zero, so use the more relaxed opcode.
      NewOpc = TargetOpcode::G_CTTZ_ZERO_UNDEF;
    }

    // Perform the operation at the larger size.
    auto MIBNewOp = MIRBuilder.buildInstr(NewOpc, {WideTy}, {MIBSrc});
    // This is already the correct result for CTPOP and CTTZs.
    if (MI.getOpcode() == TargetOpcode::G_CTLZ ||
        MI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF) {
      // The correct result is NewOp - (difference in width between WideTy and
      // CurTy).
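      // E.g. G_CTLZ of an s8 value computed in s32 counts 24 extra leading
      // zeros from the zero-extended high bits, so subtract 24.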
      unsigned SizeDiff = WideTy.getSizeInBits() - CurTy.getSizeInBits();
      MIBNewOp = MIRBuilder.buildSub(
          WideTy, MIBNewOp, MIRBuilder.buildConstant(WideTy, SizeDiff));
    }

    MIRBuilder.buildZExtOrTrunc(MI.getOperand(0), MIBNewOp);
    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_BSWAP: {
    Observer.changingInstr(MI);
    Register DstReg = MI.getOperand(0).getReg();

    Register ShrReg = MRI.createGenericVirtualRegister(WideTy);
    Register DstExt = MRI.createGenericVirtualRegister(WideTy);
    Register ShiftAmtReg = MRI.createGenericVirtualRegister(WideTy);
    widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);

    MI.getOperand(0).setReg(DstExt);

    MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());

    LLT Ty = MRI.getType(DstReg);
    unsigned DiffBits = WideTy.getScalarSizeInBits() - Ty.getScalarSizeInBits();
    MIRBuilder.buildConstant(ShiftAmtReg, DiffBits);
    MIRBuilder.buildLShr(ShrReg, DstExt, ShiftAmtReg);

    MIRBuilder.buildTrunc(DstReg, ShrReg);
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_BITREVERSE: {
    Observer.changingInstr(MI);

    Register DstReg = MI.getOperand(0).getReg();
    LLT Ty = MRI.getType(DstReg);
    unsigned DiffBits = WideTy.getScalarSizeInBits() - Ty.getScalarSizeInBits();

    Register DstExt = MRI.createGenericVirtualRegister(WideTy);
    widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
    MI.getOperand(0).setReg(DstExt);
    MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());

    auto ShiftAmt = MIRBuilder.buildConstant(WideTy, DiffBits);
    auto Shift = MIRBuilder.buildLShr(WideTy, DstExt, ShiftAmt);
    MIRBuilder.buildTrunc(DstReg, Shift);
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_FREEZE:
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
    widenScalarDst(MI, WideTy);
    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_ABS:
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_SEXT);
    widenScalarDst(MI, WideTy);
    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_ADD:
  case TargetOpcode::G_AND:
  case TargetOpcode::G_MUL:
  case TargetOpcode::G_OR:
  case TargetOpcode::G_XOR:
  case TargetOpcode::G_SUB:
    // Perform the operation at the larger width (any extension is fine here,
    // high bits don't affect the result) and then truncate the result back to
    // the original type.
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
    widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
    widenScalarDst(MI, WideTy);
    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_SBFX:
  case TargetOpcode::G_UBFX:
    Observer.changingInstr(MI);

    if (TypeIdx == 0) {
      widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
      widenScalarDst(MI, WideTy);
    } else {
      widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
      widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ZEXT);
    }

    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_SHL:
    Observer.changingInstr(MI);

    if (TypeIdx == 0) {
      widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
      widenScalarDst(MI, WideTy);
    } else {
      assert(TypeIdx == 1);
      // The "number of bits to shift" operand must preserve its value as an
      // unsigned integer:
      widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
    }

    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_SDIV:
  case TargetOpcode::G_SREM:
  case TargetOpcode::G_SMIN:
  case TargetOpcode::G_SMAX:
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_SEXT);
    widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
    widenScalarDst(MI, WideTy);
    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_SDIVREM:
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
    widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_SEXT);
    widenScalarDst(MI, WideTy);
    widenScalarDst(MI, WideTy, 1);
    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_ASHR:
  case TargetOpcode::G_LSHR:
    Observer.changingInstr(MI);

    if (TypeIdx == 0) {
      unsigned CvtOp = MI.getOpcode() == TargetOpcode::G_ASHR ?
        TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;

      widenScalarSrc(MI, WideTy, 1, CvtOp);
      widenScalarDst(MI, WideTy);
    } else {
      assert(TypeIdx == 1);
      // The "number of bits to shift" operand must preserve its value as an
      // unsigned integer:
      widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
    }

    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_UDIV:
  case TargetOpcode::G_UREM:
  case TargetOpcode::G_UMIN:
  case TargetOpcode::G_UMAX:
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ZEXT);
    widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
    widenScalarDst(MI, WideTy);
    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_UDIVREM:
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
    widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ZEXT);
    widenScalarDst(MI, WideTy);
    widenScalarDst(MI, WideTy, 1);
    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_SELECT:
    Observer.changingInstr(MI);
    if (TypeIdx == 0) {
      // Perform operation at larger width (any extension is fine here, high
      // bits don't affect the result) and then truncate the result back to the
      // original type.
      widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
      widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_ANYEXT);
      widenScalarDst(MI, WideTy);
    } else {
      bool IsVec = MRI.getType(MI.getOperand(1).getReg()).isVector();
      // Explicit extension is required here since high bits affect the result.
      widenScalarSrc(MI, WideTy, 1, MIRBuilder.getBoolExtOp(IsVec, false));
    }
    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_FPTOSI:
  case TargetOpcode::G_FPTOUI:
    Observer.changingInstr(MI);

    if (TypeIdx == 0)
      widenScalarDst(MI, WideTy);
    else
      widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_FPEXT);

    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_SITOFP:
    Observer.changingInstr(MI);

    if (TypeIdx == 0)
      widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
    else
      widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_SEXT);

    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_UITOFP:
    Observer.changingInstr(MI);

    if (TypeIdx == 0)
      widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
    else
      widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ZEXT);

    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_LOAD:
  case TargetOpcode::G_SEXTLOAD:
  case TargetOpcode::G_ZEXTLOAD:
    Observer.changingInstr(MI);
    widenScalarDst(MI, WideTy);
    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_STORE: {
    if (TypeIdx != 0)
      return UnableToLegalize;

    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
    if (!Ty.isScalar())
      return UnableToLegalize;

    Observer.changingInstr(MI);

    unsigned ExtType = Ty.getScalarSizeInBits() == 1 ?
      TargetOpcode::G_ZEXT : TargetOpcode::G_ANYEXT;
    widenScalarSrc(MI, WideTy, 0, ExtType);

    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_CONSTANT: {
    MachineOperand &SrcMO = MI.getOperand(1);
    LLVMContext &Ctx = MIRBuilder.getMF().getFunction().getContext();
    unsigned ExtOpc = LI.getExtOpcodeForWideningConstant(
        MRI.getType(MI.getOperand(0).getReg()));
    assert((ExtOpc == TargetOpcode::G_ZEXT || ExtOpc == TargetOpcode::G_SEXT ||
            ExtOpc == TargetOpcode::G_ANYEXT) &&
           "Illegal Extend");
    const APInt &SrcVal = SrcMO.getCImm()->getValue();
    const APInt &Val = (ExtOpc == TargetOpcode::G_SEXT)
                           ? SrcVal.sext(WideTy.getSizeInBits())
                           : SrcVal.zext(WideTy.getSizeInBits());
    Observer.changingInstr(MI);
    SrcMO.setCImm(ConstantInt::get(Ctx, Val));

    widenScalarDst(MI, WideTy);
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_FCONSTANT: {
    // To avoid changing the bits of the constant due to extension to a larger
    // type and then using G_FPTRUNC, we simply convert to a G_CONSTANT.
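    // E.g. widening float 1.0 from s32 to s64 (a sketch with illustrative
    // register names):
    //   %wide:_(s64) = G_CONSTANT i64 0x3F800000
    //   %dst:_(s32) = G_TRUNC %wide
    // The original bit pattern is preserved exactly in the low 32 bits.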
    MachineOperand &SrcMO = MI.getOperand(1);
    APInt Val = SrcMO.getFPImm()->getValueAPF().bitcastToAPInt();
    MIRBuilder.setInstrAndDebugLoc(MI);
    auto IntCst = MIRBuilder.buildConstant(MI.getOperand(0).getReg(), Val);
    widenScalarDst(*IntCst, WideTy, 0, TargetOpcode::G_TRUNC);
    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_IMPLICIT_DEF: {
    Observer.changingInstr(MI);
    widenScalarDst(MI, WideTy);
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_BRCOND:
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 0, MIRBuilder.getBoolExtOp(false, false));
    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_FCMP:
    Observer.changingInstr(MI);
    if (TypeIdx == 0)
      widenScalarDst(MI, WideTy);
    else {
      widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_FPEXT);
      widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_FPEXT);
    }
    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_ICMP:
    Observer.changingInstr(MI);
    if (TypeIdx == 0)
      widenScalarDst(MI, WideTy);
    else {
      unsigned ExtOpcode = CmpInst::isSigned(static_cast<CmpInst::Predicate>(
                               MI.getOperand(1).getPredicate()))
                               ? TargetOpcode::G_SEXT
                               : TargetOpcode::G_ZEXT;
      widenScalarSrc(MI, WideTy, 2, ExtOpcode);
      widenScalarSrc(MI, WideTy, 3, ExtOpcode);
    }
    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_PTR_ADD:
    assert(TypeIdx == 1 && "unable to legalize pointer of G_PTR_ADD");
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
    Observer.changedInstr(MI);
    return Legalized;

  case TargetOpcode::G_PHI: {
    assert(TypeIdx == 0 && "Expecting only Idx 0");

    Observer.changingInstr(MI);
    for (unsigned I = 1; I < MI.getNumOperands(); I += 2) {
      MachineBasicBlock &OpMBB = *MI.getOperand(I + 1).getMBB();
      MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminatorForward());
      widenScalarSrc(MI, WideTy, I, TargetOpcode::G_ANYEXT);
    }

    MachineBasicBlock &MBB = *MI.getParent();
    MIRBuilder.setInsertPt(MBB, --MBB.getFirstNonPHI());
    widenScalarDst(MI, WideTy);
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
    if (TypeIdx == 0) {
      Register VecReg = MI.getOperand(1).getReg();
      LLT VecTy = MRI.getType(VecReg);
      Observer.changingInstr(MI);

      widenScalarSrc(
          MI, LLT::vector(VecTy.getElementCount(), WideTy.getSizeInBits()), 1,
          TargetOpcode::G_ANYEXT);

      widenScalarDst(MI, WideTy, 0);
      Observer.changedInstr(MI);
      return Legalized;
    }

    if (TypeIdx != 2)
      return UnableToLegalize;
    Observer.changingInstr(MI);
    // TODO: Probably should be zext
    widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_SEXT);
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_INSERT_VECTOR_ELT: {
    if (TypeIdx == 1) {
      Observer.changingInstr(MI);

      Register VecReg = MI.getOperand(1).getReg();
      LLT VecTy = MRI.getType(VecReg);
      LLT WideVecTy = LLT::vector(VecTy.getElementCount(), WideTy);

      widenScalarSrc(MI, WideVecTy, 1, TargetOpcode::G_ANYEXT);
      widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
      widenScalarDst(MI, WideVecTy, 0);
      Observer.changedInstr(MI);
      return Legalized;
    }

    if (TypeIdx == 2) {
      Observer.changingInstr(MI);
      // TODO: Probably should be zext
      widenScalarSrc(MI, WideTy, 3, TargetOpcode::G_SEXT);
      Observer.changedInstr(MI);
      return Legalized;
    }

    return UnableToLegalize;
  }
  case TargetOpcode::G_FADD:
  case TargetOpcode::G_FMUL:
  case TargetOpcode::G_FSUB:
  case TargetOpcode::G_FMA:
  case TargetOpcode::G_FMAD:
  case TargetOpcode::G_FNEG:
  case TargetOpcode::G_FABS:
  case TargetOpcode::G_FCANONICALIZE:
  case TargetOpcode::G_FMINNUM:
  case TargetOpcode::G_FMAXNUM:
  case TargetOpcode::G_FMINNUM_IEEE:
  case TargetOpcode::G_FMAXNUM_IEEE:
  case TargetOpcode::G_FMINIMUM:
  case TargetOpcode::G_FMAXIMUM:
  case TargetOpcode::G_FDIV:
  case TargetOpcode::G_FREM:
  case TargetOpcode::G_FCEIL:
  case TargetOpcode::G_FFLOOR:
  case TargetOpcode::G_FCOS:
  case TargetOpcode::G_FSIN:
  case TargetOpcode::G_FLOG10:
  case TargetOpcode::G_FLOG:
  case TargetOpcode::G_FLOG2:
  case TargetOpcode::G_FRINT:
  case TargetOpcode::G_FNEARBYINT:
  case TargetOpcode::G_FSQRT:
  case TargetOpcode::G_FEXP:
  case TargetOpcode::G_FEXP2:
  case TargetOpcode::G_FPOW:
  case TargetOpcode::G_INTRINSIC_TRUNC:
  case TargetOpcode::G_INTRINSIC_ROUND:
  case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
    assert(TypeIdx == 0);
    Observer.changingInstr(MI);

    for (unsigned I = 1, E = MI.getNumOperands(); I != E; ++I)
      widenScalarSrc(MI, WideTy, I, TargetOpcode::G_FPEXT);

    widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_FPOWI: {
    if (TypeIdx != 0)
      return UnableToLegalize;
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_FPEXT);
    widenScalarDst(MI, WideTy, 0, TargetOpcode::G_FPTRUNC);
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_INTTOPTR:
    if (TypeIdx != 1)
      return UnableToLegalize;

    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ZEXT);
    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_PTRTOINT:
    if (TypeIdx != 0)
      return UnableToLegalize;

    Observer.changingInstr(MI);
    widenScalarDst(MI, WideTy, 0);
    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_BUILD_VECTOR: {
    Observer.changingInstr(MI);

    const LLT WideEltTy = TypeIdx == 1 ? WideTy : WideTy.getElementType();
    for (int I = 1, E = MI.getNumOperands(); I != E; ++I)
      widenScalarSrc(MI, WideEltTy, I, TargetOpcode::G_ANYEXT);

    // Avoid changing the result vector type if the source element type was
    // requested.
    if (TypeIdx == 1) {
      MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::G_BUILD_VECTOR_TRUNC));
    } else {
      widenScalarDst(MI, WideTy, 0);
    }

    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_SEXT_INREG:
    if (TypeIdx != 0)
      return UnableToLegalize;

    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
    widenScalarDst(MI, WideTy, 0, TargetOpcode::G_TRUNC);
    Observer.changedInstr(MI);
    return Legalized;
  case TargetOpcode::G_PTRMASK: {
    if (TypeIdx != 1)
      return UnableToLegalize;
    Observer.changingInstr(MI);
    widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ZEXT);
    Observer.changedInstr(MI);
    return Legalized;
  }
  }
}

static void getUnmergePieces(SmallVectorImpl<Register> &Pieces,
                             MachineIRBuilder &B, Register Src, LLT Ty) {
  auto Unmerge = B.buildUnmerge(Ty, Src);
  for (int I = 0, E = Unmerge->getNumOperands() - 1; I != E; ++I)
    Pieces.push_back(Unmerge.getReg(I));
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerBitcast(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);

  if (SrcTy.isVector()) {
    LLT SrcEltTy = SrcTy.getElementType();
    SmallVector<Register, 8> SrcRegs;

    if (DstTy.isVector()) {
      int NumDstElt = DstTy.getNumElements();
      int NumSrcElt = SrcTy.getNumElements();

      LLT DstEltTy = DstTy.getElementType();
      LLT DstCastTy = DstEltTy; // Intermediate bitcast result type
      LLT SrcPartTy = SrcEltTy; // Original unmerge result type.

      // If there's an element size mismatch, insert intermediate casts to match
      // the result element type.
      if (NumSrcElt < NumDstElt) { // Source element type is larger.
        // %1:_(<4 x s8>) = G_BITCAST %0:_(<2 x s16>)
        //
        // =>
        //
        // %2:_(s16), %3:_(s16) = G_UNMERGE_VALUES %0
        // %4:_(<2 x s8>) = G_BITCAST %2
        // %5:_(<2 x s8>) = G_BITCAST %3
        // %1:_(<4 x s8>) = G_CONCAT_VECTORS %4, %5
        DstCastTy = LLT::fixed_vector(NumDstElt / NumSrcElt, DstEltTy);
        SrcPartTy = SrcEltTy;
      } else if (NumSrcElt > NumDstElt) { // Source element type is smaller.
        //
        // %1:_(<2 x s16>) = G_BITCAST %0:_(<4 x s8>)
        //
        // =>
        //
        // %2:_(<2 x s8>), %3:_(<2 x s8>) = G_UNMERGE_VALUES %0
        // %4:_(s16) = G_BITCAST %2
        // %5:_(s16) = G_BITCAST %3
        // %1:_(<2 x s16>) = G_BUILD_VECTOR %4, %5
        SrcPartTy = LLT::fixed_vector(NumSrcElt / NumDstElt, SrcEltTy);
        DstCastTy = DstEltTy;
      }

      getUnmergePieces(SrcRegs, MIRBuilder, Src, SrcPartTy);
      for (Register &SrcReg : SrcRegs)
        SrcReg = MIRBuilder.buildBitcast(DstCastTy, SrcReg).getReg(0);
    } else
      getUnmergePieces(SrcRegs, MIRBuilder, Src, SrcEltTy);

    MIRBuilder.buildMergeLikeInstr(Dst, SrcRegs);
    MI.eraseFromParent();
    return Legalized;
  }

  if (DstTy.isVector()) {
    SmallVector<Register, 8> SrcRegs;
    getUnmergePieces(SrcRegs, MIRBuilder, Src, DstTy.getElementType());
    MIRBuilder.buildMergeLikeInstr(Dst, SrcRegs);
    MI.eraseFromParent();
    return Legalized;
  }

  return UnableToLegalize;
}

/// Figure out the bit offset into a register when coercing a vector index for
/// the wide element type. This is only for the case when promoting a vector
/// to one with larger elements.
///
/// %offset_idx = G_AND %idx, ~(-1 << Log2(DstEltSize / SrcEltSize))
/// %offset_bits = G_SHL %offset_idx, Log2(SrcEltSize)
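///
/// For example (a sketch), locating an s8 element within an s32 lane
/// (DstEltSize = 32, SrcEltSize = 8, so Log2(32 / 8) = 2 and Log2(8) = 3):
///   %offset_idx = G_AND %idx, 3         ; which s8 within the s32 lane
///   %offset_bits = G_SHL %offset_idx, 3 ; byte index * 8 = bit offset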
getBitcastWiderVectorElementOffset(MachineIRBuilder & B,Register Idx,unsigned NewEltSize,unsigned OldEltSize)2708 static Register getBitcastWiderVectorElementOffset(MachineIRBuilder &B,
2709                                                    Register Idx,
2710                                                    unsigned NewEltSize,
2711                                                    unsigned OldEltSize) {
2712   const unsigned Log2EltRatio = Log2_32(NewEltSize / OldEltSize);
2713   LLT IdxTy = B.getMRI()->getType(Idx);
2714 
2715   // Now figure out the amount we need to shift to get the target bits.
2716   auto OffsetMask = B.buildConstant(
2717       IdxTy, ~(APInt::getAllOnes(IdxTy.getSizeInBits()) << Log2EltRatio));
2718   auto OffsetIdx = B.buildAnd(IdxTy, Idx, OffsetMask);
2719   return B.buildShl(IdxTy, OffsetIdx,
2720                     B.buildConstant(IdxTy, Log2_32(OldEltSize))).getReg(0);
2721 }

/// Perform a G_EXTRACT_VECTOR_ELT in a different sized vector element. If this
/// is casting to a vector with a smaller element size, perform multiple element
/// extracts and merge the results. If this is coercing to a vector with larger
/// elements, index the bitcasted vector and extract the target element with bit
/// operations. This is intended to force the indexing in the native register
/// size for architectures that can dynamically index the register file.
LegalizerHelper::LegalizeResult
LegalizerHelper::bitcastExtractVectorElt(MachineInstr &MI, unsigned TypeIdx,
                                         LLT CastTy) {
  if (TypeIdx != 1)
    return UnableToLegalize;

  Register Dst = MI.getOperand(0).getReg();
  Register SrcVec = MI.getOperand(1).getReg();
  Register Idx = MI.getOperand(2).getReg();
  LLT SrcVecTy = MRI.getType(SrcVec);
  LLT IdxTy = MRI.getType(Idx);

  LLT SrcEltTy = SrcVecTy.getElementType();
  unsigned NewNumElts = CastTy.isVector() ? CastTy.getNumElements() : 1;
  unsigned OldNumElts = SrcVecTy.getNumElements();

  LLT NewEltTy = CastTy.isVector() ? CastTy.getElementType() : CastTy;
  Register CastVec = MIRBuilder.buildBitcast(CastTy, SrcVec).getReg(0);

  const unsigned NewEltSize = NewEltTy.getSizeInBits();
  const unsigned OldEltSize = SrcEltTy.getSizeInBits();
  if (NewNumElts > OldNumElts) {
    // Decreasing the vector element size
    //
    // e.g. i64 = extract_vector_elt x:v2i64, y:i32
    //  =>
    //  v4i32:castx = bitcast x:v2i64
    //
    // i64 = bitcast
    //   (v2i32 build_vector (i32 (extract_vector_elt castx, (2 * y))),
    //                       (i32 (extract_vector_elt castx, (2 * y + 1))))
    //
    if (NewNumElts % OldNumElts != 0)
      return UnableToLegalize;

    // Type of the intermediate result vector.
    const unsigned NewEltsPerOldElt = NewNumElts / OldNumElts;
    LLT MidTy =
        LLT::scalarOrVector(ElementCount::getFixed(NewEltsPerOldElt), NewEltTy);

    auto NewEltsPerOldEltK = MIRBuilder.buildConstant(IdxTy, NewEltsPerOldElt);

    SmallVector<Register, 8> NewOps(NewEltsPerOldElt);
    auto NewBaseIdx = MIRBuilder.buildMul(IdxTy, Idx, NewEltsPerOldEltK);

    for (unsigned I = 0; I < NewEltsPerOldElt; ++I) {
      auto IdxOffset = MIRBuilder.buildConstant(IdxTy, I);
      auto TmpIdx = MIRBuilder.buildAdd(IdxTy, NewBaseIdx, IdxOffset);
      auto Elt = MIRBuilder.buildExtractVectorElement(NewEltTy, CastVec, TmpIdx);
      NewOps[I] = Elt.getReg(0);
    }

    auto NewVec = MIRBuilder.buildBuildVector(MidTy, NewOps);
    MIRBuilder.buildBitcast(Dst, NewVec);
    MI.eraseFromParent();
    return Legalized;
  }

  if (NewNumElts < OldNumElts) {
    if (NewEltSize % OldEltSize != 0)
      return UnableToLegalize;

    // This only depends on powers of 2 because we use bit tricks to figure out
    // the bit offset we need to shift to get the target element. A general
    // expansion could emit division/multiply.
    if (!isPowerOf2_32(NewEltSize / OldEltSize))
      return UnableToLegalize;

    // Increasing the vector element size.
    // %elt:_(small_elt) = G_EXTRACT_VECTOR_ELT %vec:_(<N x small_elt>), %idx
    //
    //   =>
    //
    // %cast = G_BITCAST %vec
    // %scaled_idx = G_LSHR %idx, Log2(DstEltSize / SrcEltSize)
    // %wide_elt  = G_EXTRACT_VECTOR_ELT %cast, %scaled_idx
    // %offset_idx = G_AND %idx, ~(-1 << Log2(DstEltSize / SrcEltSize))
    // %offset_bits = G_SHL %offset_idx, Log2(SrcEltSize)
    // %elt_bits = G_LSHR %wide_elt, %offset_bits
    // %elt = G_TRUNC %elt_bits

    const unsigned Log2EltRatio = Log2_32(NewEltSize / OldEltSize);
    auto Log2Ratio = MIRBuilder.buildConstant(IdxTy, Log2EltRatio);

    // Divide to get the index in the wider element type.
    auto ScaledIdx = MIRBuilder.buildLShr(IdxTy, Idx, Log2Ratio);

    Register WideElt = CastVec;
    if (CastTy.isVector()) {
      WideElt = MIRBuilder.buildExtractVectorElement(NewEltTy, CastVec,
                                                     ScaledIdx).getReg(0);
    }

    // Compute the bit offset into the register of the target element.
    Register OffsetBits = getBitcastWiderVectorElementOffset(
      MIRBuilder, Idx, NewEltSize, OldEltSize);

    // Shift the wide element to get the target element.
    auto ExtractedBits = MIRBuilder.buildLShr(NewEltTy, WideElt, OffsetBits);
    MIRBuilder.buildTrunc(Dst, ExtractedBits);
    MI.eraseFromParent();
    return Legalized;
  }

  return UnableToLegalize;
}

/// Emit code to insert \p InsertReg into \p TargetReg at \p OffsetBits, while
/// preserving the other bits of \p TargetReg.
///
/// (ZExt(InsertReg) << Offset) | (TargetReg & ~(Mask << Offset)),
/// where Mask has the low InsertReg.size() bits set.
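///
/// For example (illustrative): TargetReg = 0xAABBCCDD (s32),
/// InsertReg = 0xEE (s8) and OffsetBits = 8 yield 0xAABBEEDD.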
static Register buildBitFieldInsert(MachineIRBuilder &B,
                                    Register TargetReg, Register InsertReg,
                                    Register OffsetBits) {
  LLT TargetTy = B.getMRI()->getType(TargetReg);
  LLT InsertTy = B.getMRI()->getType(InsertReg);
  auto ZextVal = B.buildZExt(TargetTy, InsertReg);
  auto ShiftedInsertVal = B.buildShl(TargetTy, ZextVal, OffsetBits);

  // Produce a bitmask of the value to insert
  auto EltMask = B.buildConstant(
    TargetTy, APInt::getLowBitsSet(TargetTy.getSizeInBits(),
                                   InsertTy.getSizeInBits()));
  // Shift it into position
  auto ShiftedMask = B.buildShl(TargetTy, EltMask, OffsetBits);
  auto InvShiftedMask = B.buildNot(TargetTy, ShiftedMask);

  // Clear out the bits in the wide element
  auto MaskedOldElt = B.buildAnd(TargetTy, TargetReg, InvShiftedMask);

  // The high bits of the inserted value are already zero (from the zext), so
  // or it into the masked wide element.
  return B.buildOr(TargetTy, MaskedOldElt, ShiftedInsertVal).getReg(0);
}

/// Perform a G_INSERT_VECTOR_ELT in a different sized vector element. If this
/// is increasing the element size, perform the indexing in the target element
/// type, and use bit operations to insert at the element position. This is
/// intended for architectures that can dynamically index the register file and
/// want to force indexing in the native register size.
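///
/// A rough sketch (illustrative), inserting %val:_(s8) into %vec:_(<8 x s8>)
/// with CastTy = <2 x s32>:
///   %cast:_(<2 x s32>) = G_BITCAST %vec
///   %scaled_idx:_(s32) = G_LSHR %idx, 2
///   %wide:_(s32) = G_EXTRACT_VECTOR_ELT %cast, %scaled_idx
///   %ins:_(s32) = (%wide with %val inserted via shifts/masks; see
///                  buildBitFieldInsert)
///   %new:_(<2 x s32>) = G_INSERT_VECTOR_ELT %cast, %ins, %scaled_idx
///   %res:_(<8 x s8>) = G_BITCAST %new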
LegalizerHelper::LegalizeResult
LegalizerHelper::bitcastInsertVectorElt(MachineInstr &MI, unsigned TypeIdx,
                                        LLT CastTy) {
  if (TypeIdx != 0)
    return UnableToLegalize;

  Register Dst = MI.getOperand(0).getReg();
  Register SrcVec = MI.getOperand(1).getReg();
  Register Val = MI.getOperand(2).getReg();
  Register Idx = MI.getOperand(3).getReg();

  LLT VecTy = MRI.getType(Dst);
  LLT IdxTy = MRI.getType(Idx);

  LLT VecEltTy = VecTy.getElementType();
  LLT NewEltTy = CastTy.isVector() ? CastTy.getElementType() : CastTy;
  const unsigned NewEltSize = NewEltTy.getSizeInBits();
  const unsigned OldEltSize = VecEltTy.getSizeInBits();

  unsigned NewNumElts = CastTy.isVector() ? CastTy.getNumElements() : 1;
  unsigned OldNumElts = VecTy.getNumElements();

  Register CastVec = MIRBuilder.buildBitcast(CastTy, SrcVec).getReg(0);
  if (NewNumElts < OldNumElts) {
    if (NewEltSize % OldEltSize != 0)
      return UnableToLegalize;

    // This only depends on powers of 2 because we use bit tricks to figure out
    // the bit offset we need to shift to get the target element. A general
    // expansion could emit division/multiply.
    if (!isPowerOf2_32(NewEltSize / OldEltSize))
      return UnableToLegalize;

    const unsigned Log2EltRatio = Log2_32(NewEltSize / OldEltSize);
    auto Log2Ratio = MIRBuilder.buildConstant(IdxTy, Log2EltRatio);

    // Divide to get the index in the wider element type.
    auto ScaledIdx = MIRBuilder.buildLShr(IdxTy, Idx, Log2Ratio);

    Register ExtractedElt = CastVec;
    if (CastTy.isVector()) {
      ExtractedElt = MIRBuilder.buildExtractVectorElement(NewEltTy, CastVec,
                                                          ScaledIdx).getReg(0);
    }

    // Compute the bit offset into the register of the target element.
    Register OffsetBits = getBitcastWiderVectorElementOffset(
      MIRBuilder, Idx, NewEltSize, OldEltSize);

    Register InsertedElt = buildBitFieldInsert(MIRBuilder, ExtractedElt,
                                               Val, OffsetBits);
    if (CastTy.isVector()) {
      InsertedElt = MIRBuilder.buildInsertVectorElement(
        CastTy, CastVec, InsertedElt, ScaledIdx).getReg(0);
    }

    MIRBuilder.buildBitcast(Dst, InsertedElt);
    MI.eraseFromParent();
    return Legalized;
  }

  return UnableToLegalize;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerLoad(GAnyLoad &LoadMI) {
  // Lower to a memory-width G_LOAD and a G_SEXT/G_ZEXT/G_ANYEXT
  Register DstReg = LoadMI.getDstReg();
  Register PtrReg = LoadMI.getPointerReg();
  LLT DstTy = MRI.getType(DstReg);
  MachineMemOperand &MMO = LoadMI.getMMO();
  LLT MemTy = MMO.getMemoryType();
  MachineFunction &MF = MIRBuilder.getMF();

  unsigned MemSizeInBits = MemTy.getSizeInBits();
  unsigned MemStoreSizeInBits = 8 * MemTy.getSizeInBytes();

  if (MemSizeInBits != MemStoreSizeInBits) {
    if (MemTy.isVector())
      return UnableToLegalize;

    // Promote to a byte-sized load if not loading an integral number of
    // bytes.  For example, promote EXTLOAD:i20 -> EXTLOAD:i24.
    LLT WideMemTy = LLT::scalar(MemStoreSizeInBits);
    MachineMemOperand *NewMMO =
        MF.getMachineMemOperand(&MMO, MMO.getPointerInfo(), WideMemTy);

    Register LoadReg = DstReg;
    LLT LoadTy = DstTy;

    // If this wasn't already an extending load, we need to widen the result
    // register to avoid creating a load with a narrower result than the source.
    if (MemStoreSizeInBits > DstTy.getSizeInBits()) {
      LoadTy = WideMemTy;
      LoadReg = MRI.createGenericVirtualRegister(WideMemTy);
    }

    if (isa<GSExtLoad>(LoadMI)) {
      auto NewLoad = MIRBuilder.buildLoad(LoadTy, PtrReg, *NewMMO);
      MIRBuilder.buildSExtInReg(LoadReg, NewLoad, MemSizeInBits);
    } else if (isa<GZExtLoad>(LoadMI) || WideMemTy == LoadTy) {
      auto NewLoad = MIRBuilder.buildLoad(LoadTy, PtrReg, *NewMMO);
      // The extra bits are guaranteed to be zero, since we stored them that
      // way.  A zext load from Wide thus automatically gives zext from MemVT.
      MIRBuilder.buildAssertZExt(LoadReg, NewLoad, MemSizeInBits);
    } else {
      MIRBuilder.buildLoad(LoadReg, PtrReg, *NewMMO);
    }

    if (DstTy != LoadTy)
      MIRBuilder.buildTrunc(DstReg, LoadReg);

    LoadMI.eraseFromParent();
    return Legalized;
  }

  // Big endian lowering not implemented.
  if (MIRBuilder.getDataLayout().isBigEndian())
    return UnableToLegalize;

  // This load needs splitting into power of 2 sized loads.
  //
  // Our strategy here is to generate extending loads for the two split pieces,
  // with the results widened to the next power-of-2 type, and then combine the
  // two results with a shift and an or, before truncating back down to the
  // non-pow-2 type.
  // E.g. v1 = i24 load =>
  // v2 = i32 zextload (2 byte)
  // v3 = i32 load (1 byte)
  // v4 = i32 shl v3, 16
  // v5 = i32 or v4, v2
  // v1 = i24 trunc v5
  // By doing this we generate the correct truncate, which should get
  // combined away as an artifact with a matching extend.

  uint64_t LargeSplitSize, SmallSplitSize;

  if (!isPowerOf2_32(MemSizeInBits)) {
    // This load needs splitting into power of 2 sized loads.
    LargeSplitSize = PowerOf2Floor(MemSizeInBits);
    SmallSplitSize = MemSizeInBits - LargeSplitSize;
  } else {
    // This is already a power of 2, but we still need to split this in half.
    //
    // Assume we're being asked to decompose an unaligned load.
    // TODO: If this requires multiple splits, handle them all at once.
    auto &Ctx = MF.getFunction().getContext();
    if (TLI.allowsMemoryAccess(Ctx, MIRBuilder.getDataLayout(), MemTy, MMO))
      return UnableToLegalize;

    SmallSplitSize = LargeSplitSize = MemSizeInBits / 2;
  }

  if (MemTy.isVector()) {
    // TODO: Handle vector extloads
    if (MemTy != DstTy)
      return UnableToLegalize;

    // TODO: We can do better than scalarizing the vector and at least split it
    // in half.
    return reduceLoadStoreWidth(LoadMI, 0, DstTy.getElementType());
  }

  MachineMemOperand *LargeMMO =
      MF.getMachineMemOperand(&MMO, 0, LargeSplitSize / 8);
  MachineMemOperand *SmallMMO =
      MF.getMachineMemOperand(&MMO, LargeSplitSize / 8, SmallSplitSize / 8);

  LLT PtrTy = MRI.getType(PtrReg);
  unsigned AnyExtSize = PowerOf2Ceil(DstTy.getSizeInBits());
  LLT AnyExtTy = LLT::scalar(AnyExtSize);
  auto LargeLoad = MIRBuilder.buildLoadInstr(TargetOpcode::G_ZEXTLOAD, AnyExtTy,
                                             PtrReg, *LargeMMO);

  auto OffsetCst = MIRBuilder.buildConstant(LLT::scalar(PtrTy.getSizeInBits()),
                                            LargeSplitSize / 8);
  Register PtrAddReg = MRI.createGenericVirtualRegister(PtrTy);
  auto SmallPtr = MIRBuilder.buildPtrAdd(PtrAddReg, PtrReg, OffsetCst);
  auto SmallLoad = MIRBuilder.buildLoadInstr(LoadMI.getOpcode(), AnyExtTy,
                                             SmallPtr, *SmallMMO);

  auto ShiftAmt = MIRBuilder.buildConstant(AnyExtTy, LargeSplitSize);
  auto Shift = MIRBuilder.buildShl(AnyExtTy, SmallLoad, ShiftAmt);

  if (AnyExtTy == DstTy)
    MIRBuilder.buildOr(DstReg, Shift, LargeLoad);
  else if (AnyExtTy.getSizeInBits() != DstTy.getSizeInBits()) {
    auto Or = MIRBuilder.buildOr(AnyExtTy, Shift, LargeLoad);
    MIRBuilder.buildTrunc(DstReg, {Or});
  } else {
    assert(DstTy.isPointer() && "expected pointer");
    auto Or = MIRBuilder.buildOr(AnyExtTy, Shift, LargeLoad);

    // FIXME: We currently consider this to be illegal for non-integral address
    // spaces, but we still need a way to reinterpret the bits.
    MIRBuilder.buildIntToPtr(DstReg, Or);
  }

  LoadMI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerStore(GStore &StoreMI) {
  // Lower a non-power of 2 store into multiple pow-2 stores.
  // E.g. split an i24 store into an i16 store + i8 store.
  // We do this by first extending the stored value to the next largest power
  // of 2 type, and then using truncating stores to store the components.
  // As with G_LOAD, this generates an extend that can be artifact-combined
  // away instead of leaving behind extracts.
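  //
  // For example (illustrative, little-endian layout), an i24 store
  //   G_STORE %val(s24), %ptr :: (store (s24))
  // becomes roughly
  //   %ext:_(s32) = G_ANYEXT %val(s24)
  //   %hi:_(s32) = G_LSHR %ext, 16
  //   G_STORE %ext(s32), %ptr :: (store (s16))
  //   %ptr2:_(p0) = G_PTR_ADD %ptr, 2
  //   G_STORE %hi(s32), %ptr2 :: (store (s8))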
  Register SrcReg = StoreMI.getValueReg();
  Register PtrReg = StoreMI.getPointerReg();
  LLT SrcTy = MRI.getType(SrcReg);
  MachineFunction &MF = MIRBuilder.getMF();
  MachineMemOperand &MMO = **StoreMI.memoperands_begin();
  LLT MemTy = MMO.getMemoryType();

  unsigned StoreWidth = MemTy.getSizeInBits();
  unsigned StoreSizeInBits = 8 * MemTy.getSizeInBytes();

  if (StoreWidth != StoreSizeInBits) {
    if (SrcTy.isVector())
      return UnableToLegalize;

    // Promote to a byte-sized store with upper bits zero if not
    // storing an integral number of bytes.  For example, promote
    // TRUNCSTORE:i1 X -> TRUNCSTORE:i8 (and X, 1)
    LLT WideTy = LLT::scalar(StoreSizeInBits);

    if (StoreSizeInBits > SrcTy.getSizeInBits()) {
      // Avoid creating a store with a narrower source than result.
      SrcReg = MIRBuilder.buildAnyExt(WideTy, SrcReg).getReg(0);
      SrcTy = WideTy;
    }

    auto ZextInReg = MIRBuilder.buildZExtInReg(SrcTy, SrcReg, StoreWidth);

    MachineMemOperand *NewMMO =
        MF.getMachineMemOperand(&MMO, MMO.getPointerInfo(), WideTy);
    MIRBuilder.buildStore(ZextInReg, PtrReg, *NewMMO);
    StoreMI.eraseFromParent();
    return Legalized;
  }

  if (MemTy.isVector()) {
    // TODO: Handle vector trunc stores
    if (MemTy != SrcTy)
      return UnableToLegalize;

    // TODO: We can do better than scalarizing the vector and at least split it
    // in half.
    return reduceLoadStoreWidth(StoreMI, 0, SrcTy.getElementType());
  }

  unsigned MemSizeInBits = MemTy.getSizeInBits();
  uint64_t LargeSplitSize, SmallSplitSize;

  if (!isPowerOf2_32(MemSizeInBits)) {
    LargeSplitSize = PowerOf2Floor(MemTy.getSizeInBits());
    SmallSplitSize = MemTy.getSizeInBits() - LargeSplitSize;
  } else {
    auto &Ctx = MF.getFunction().getContext();
    if (TLI.allowsMemoryAccess(Ctx, MIRBuilder.getDataLayout(), MemTy, MMO))
      return UnableToLegalize; // Don't know what we're being asked to do.

    SmallSplitSize = LargeSplitSize = MemSizeInBits / 2;
  }

  // Extend to the next pow-2. If this store was itself the result of lowering,
  // e.g. an s56 store being broken into s32 + s24, we might have a stored type
  // that's wider than the stored size.
  unsigned AnyExtSize = PowerOf2Ceil(MemTy.getSizeInBits());
  const LLT NewSrcTy = LLT::scalar(AnyExtSize);

  if (SrcTy.isPointer()) {
    const LLT IntPtrTy = LLT::scalar(SrcTy.getSizeInBits());
    SrcReg = MIRBuilder.buildPtrToInt(IntPtrTy, SrcReg).getReg(0);
  }

  auto ExtVal = MIRBuilder.buildAnyExtOrTrunc(NewSrcTy, SrcReg);

  // Obtain the smaller value by shifting away the larger value.
  auto ShiftAmt = MIRBuilder.buildConstant(NewSrcTy, LargeSplitSize);
  auto SmallVal = MIRBuilder.buildLShr(NewSrcTy, ExtVal, ShiftAmt);

  // Generate the PtrAdd and truncating stores.
  LLT PtrTy = MRI.getType(PtrReg);
  auto OffsetCst = MIRBuilder.buildConstant(
    LLT::scalar(PtrTy.getSizeInBits()), LargeSplitSize / 8);
  auto SmallPtr =
    MIRBuilder.buildPtrAdd(PtrTy, PtrReg, OffsetCst);

  MachineMemOperand *LargeMMO =
    MF.getMachineMemOperand(&MMO, 0, LargeSplitSize / 8);
  MachineMemOperand *SmallMMO =
    MF.getMachineMemOperand(&MMO, LargeSplitSize / 8, SmallSplitSize / 8);
  MIRBuilder.buildStore(ExtVal, PtrReg, *LargeMMO);
  MIRBuilder.buildStore(SmallVal, SmallPtr, *SmallMMO);
  StoreMI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::bitcast(MachineInstr &MI, unsigned TypeIdx, LLT CastTy) {
  switch (MI.getOpcode()) {
  case TargetOpcode::G_LOAD: {
    if (TypeIdx != 0)
      return UnableToLegalize;
    MachineMemOperand &MMO = **MI.memoperands_begin();

    // Not sure how to interpret a bitcast of an extending load.
    if (MMO.getMemoryType().getSizeInBits() != CastTy.getSizeInBits())
      return UnableToLegalize;

    Observer.changingInstr(MI);
    bitcastDst(MI, CastTy, 0);
    MMO.setType(CastTy);
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_STORE: {
    if (TypeIdx != 0)
      return UnableToLegalize;

    MachineMemOperand &MMO = **MI.memoperands_begin();

    // Not sure how to interpret a bitcast of a truncating store.
    if (MMO.getMemoryType().getSizeInBits() != CastTy.getSizeInBits())
      return UnableToLegalize;

    Observer.changingInstr(MI);
    bitcastSrc(MI, CastTy, 0);
    MMO.setType(CastTy);
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_SELECT: {
    if (TypeIdx != 0)
      return UnableToLegalize;

    if (MRI.getType(MI.getOperand(1).getReg()).isVector()) {
      LLVM_DEBUG(
          dbgs() << "bitcast action not implemented for vector select\n");
      return UnableToLegalize;
    }

    Observer.changingInstr(MI);
    bitcastSrc(MI, CastTy, 2);
    bitcastSrc(MI, CastTy, 3);
    bitcastDst(MI, CastTy, 0);
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_AND:
  case TargetOpcode::G_OR:
  case TargetOpcode::G_XOR: {
    Observer.changingInstr(MI);
    bitcastSrc(MI, CastTy, 1);
    bitcastSrc(MI, CastTy, 2);
    bitcastDst(MI, CastTy, 0);
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_EXTRACT_VECTOR_ELT:
    return bitcastExtractVectorElt(MI, TypeIdx, CastTy);
  case TargetOpcode::G_INSERT_VECTOR_ELT:
    return bitcastInsertVectorElt(MI, TypeIdx, CastTy);
  default:
    return UnableToLegalize;
  }
}

// Legalize an instruction by changing the opcode in place.
void LegalizerHelper::changeOpcode(MachineInstr &MI, unsigned NewOpcode) {
  Observer.changingInstr(MI);
  MI.setDesc(MIRBuilder.getTII().get(NewOpcode));
  Observer.changedInstr(MI);
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
  using namespace TargetOpcode;

  switch (MI.getOpcode()) {
  default:
    return UnableToLegalize;
  case TargetOpcode::G_BITCAST:
    return lowerBitcast(MI);
  case TargetOpcode::G_SREM:
  case TargetOpcode::G_UREM: {
    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
    auto Quot =
        MIRBuilder.buildInstr(MI.getOpcode() == G_SREM ? G_SDIV : G_UDIV, {Ty},
                              {MI.getOperand(1), MI.getOperand(2)});

    auto Prod = MIRBuilder.buildMul(Ty, Quot, MI.getOperand(2));
    MIRBuilder.buildSub(MI.getOperand(0), MI.getOperand(1), Prod);
    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_SADDO:
  case TargetOpcode::G_SSUBO:
    return lowerSADDO_SSUBO(MI);
  case TargetOpcode::G_UMULH:
  case TargetOpcode::G_SMULH:
    return lowerSMULH_UMULH(MI);
  case TargetOpcode::G_SMULO:
  case TargetOpcode::G_UMULO: {
    // Generate G_UMULH/G_SMULH to check for overflow and a normal G_MUL for the
    // result.
    Register Res = MI.getOperand(0).getReg();
    Register Overflow = MI.getOperand(1).getReg();
    Register LHS = MI.getOperand(2).getReg();
    Register RHS = MI.getOperand(3).getReg();
    LLT Ty = MRI.getType(Res);

    unsigned Opcode = MI.getOpcode() == TargetOpcode::G_SMULO
                          ? TargetOpcode::G_SMULH
                          : TargetOpcode::G_UMULH;

    Observer.changingInstr(MI);
    const auto &TII = MIRBuilder.getTII();
    MI.setDesc(TII.get(TargetOpcode::G_MUL));
    MI.removeOperand(1);
    Observer.changedInstr(MI);

    auto HiPart = MIRBuilder.buildInstr(Opcode, {Ty}, {LHS, RHS});
    auto Zero = MIRBuilder.buildConstant(Ty, 0);

    // Move insert point forward so we can use the Res register if needed.
    MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt());

    // For *signed* multiply, overflow is detected by checking:
    // (hi != (lo >> bitwidth-1))
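    //
    // For example (illustrative), for s8: 100 * 2 = 200 wraps to lo = -56;
    // lo >> 7 is -1 while hi is 0, so hi != (lo >> 7) flags the overflow.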
    if (Opcode == TargetOpcode::G_SMULH) {
      auto ShiftAmt = MIRBuilder.buildConstant(Ty, Ty.getSizeInBits() - 1);
      auto Shifted = MIRBuilder.buildAShr(Ty, Res, ShiftAmt);
      MIRBuilder.buildICmp(CmpInst::ICMP_NE, Overflow, HiPart, Shifted);
    } else {
      MIRBuilder.buildICmp(CmpInst::ICMP_NE, Overflow, HiPart, Zero);
    }
    return Legalized;
  }
  case TargetOpcode::G_FNEG: {
    Register Res = MI.getOperand(0).getReg();
    LLT Ty = MRI.getType(Res);

    // TODO: Handle vector types once we are able to
    // represent them.
    if (Ty.isVector())
      return UnableToLegalize;
    auto SignMask =
        MIRBuilder.buildConstant(Ty, APInt::getSignMask(Ty.getSizeInBits()));
    Register SubByReg = MI.getOperand(1).getReg();
    MIRBuilder.buildXor(Res, SubByReg, SignMask);
    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_FSUB:
  case TargetOpcode::G_STRICT_FSUB: {
    Register Res = MI.getOperand(0).getReg();
    LLT Ty = MRI.getType(Res);

    // Lower (G_FSUB LHS, RHS) to (G_FADD LHS, (G_FNEG RHS)).
    // First, check if G_FNEG is marked as Lower. If so, we may
    // end up with an infinite loop as G_FSUB is used to legalize G_FNEG.
    if (LI.getAction({G_FNEG, {Ty}}).Action == Lower)
      return UnableToLegalize;
    Register LHS = MI.getOperand(1).getReg();
    Register RHS = MI.getOperand(2).getReg();
    auto Neg = MIRBuilder.buildFNeg(Ty, RHS);

    if (MI.getOpcode() == TargetOpcode::G_STRICT_FSUB)
      MIRBuilder.buildStrictFAdd(Res, LHS, Neg, MI.getFlags());
    else
      MIRBuilder.buildFAdd(Res, LHS, Neg, MI.getFlags());

    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_FMAD:
    return lowerFMad(MI);
  case TargetOpcode::G_FFLOOR:
    return lowerFFloor(MI);
  case TargetOpcode::G_INTRINSIC_ROUND:
    return lowerIntrinsicRound(MI);
  case TargetOpcode::G_INTRINSIC_ROUNDEVEN: {
    // Since round even is the assumed rounding mode for unconstrained FP
    // operations, rint and roundeven are the same operation.
    changeOpcode(MI, TargetOpcode::G_FRINT);
    return Legalized;
  }
  case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS: {
    Register OldValRes = MI.getOperand(0).getReg();
    Register SuccessRes = MI.getOperand(1).getReg();
    Register Addr = MI.getOperand(2).getReg();
    Register CmpVal = MI.getOperand(3).getReg();
    Register NewVal = MI.getOperand(4).getReg();
    MIRBuilder.buildAtomicCmpXchg(OldValRes, Addr, CmpVal, NewVal,
                                  **MI.memoperands_begin());
    MIRBuilder.buildICmp(CmpInst::ICMP_EQ, SuccessRes, OldValRes, CmpVal);
    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_LOAD:
  case TargetOpcode::G_SEXTLOAD:
  case TargetOpcode::G_ZEXTLOAD:
    return lowerLoad(cast<GAnyLoad>(MI));
  case TargetOpcode::G_STORE:
    return lowerStore(cast<GStore>(MI));
  case TargetOpcode::G_CTLZ_ZERO_UNDEF:
  case TargetOpcode::G_CTTZ_ZERO_UNDEF:
  case TargetOpcode::G_CTLZ:
  case TargetOpcode::G_CTTZ:
  case TargetOpcode::G_CTPOP:
    return lowerBitCount(MI);
  case G_UADDO: {
    Register Res = MI.getOperand(0).getReg();
    Register CarryOut = MI.getOperand(1).getReg();
    Register LHS = MI.getOperand(2).getReg();
    Register RHS = MI.getOperand(3).getReg();

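    // Carry out is set iff the unsigned add wrapped, i.e. Res <u RHS.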
    MIRBuilder.buildAdd(Res, LHS, RHS);
    MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CarryOut, Res, RHS);

    MI.eraseFromParent();
    return Legalized;
  }
  case G_UADDE: {
    Register Res = MI.getOperand(0).getReg();
    Register CarryOut = MI.getOperand(1).getReg();
    Register LHS = MI.getOperand(2).getReg();
    Register RHS = MI.getOperand(3).getReg();
    Register CarryIn = MI.getOperand(4).getReg();
    LLT Ty = MRI.getType(Res);

    auto TmpRes = MIRBuilder.buildAdd(Ty, LHS, RHS);
    auto ZExtCarryIn = MIRBuilder.buildZExt(Ty, CarryIn);
    MIRBuilder.buildAdd(Res, TmpRes, ZExtCarryIn);
    MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CarryOut, Res, LHS);

    MI.eraseFromParent();
    return Legalized;
  }
  case G_USUBO: {
    Register Res = MI.getOperand(0).getReg();
    Register BorrowOut = MI.getOperand(1).getReg();
    Register LHS = MI.getOperand(2).getReg();
    Register RHS = MI.getOperand(3).getReg();

    MIRBuilder.buildSub(Res, LHS, RHS);
    MIRBuilder.buildICmp(CmpInst::ICMP_ULT, BorrowOut, LHS, RHS);

    MI.eraseFromParent();
    return Legalized;
  }
  case G_USUBE: {
    Register Res = MI.getOperand(0).getReg();
    Register BorrowOut = MI.getOperand(1).getReg();
    Register LHS = MI.getOperand(2).getReg();
    Register RHS = MI.getOperand(3).getReg();
    Register BorrowIn = MI.getOperand(4).getReg();
    const LLT CondTy = MRI.getType(BorrowOut);
    const LLT Ty = MRI.getType(Res);

    auto TmpRes = MIRBuilder.buildSub(Ty, LHS, RHS);
    auto ZExtBorrowIn = MIRBuilder.buildZExt(Ty, BorrowIn);
    MIRBuilder.buildSub(Res, TmpRes, ZExtBorrowIn);

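    // Borrow out is set if LHS <u RHS, or if LHS == RHS and we borrowed in.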
    auto LHS_EQ_RHS = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, CondTy, LHS, RHS);
    auto LHS_ULT_RHS = MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CondTy, LHS, RHS);
    MIRBuilder.buildSelect(BorrowOut, LHS_EQ_RHS, BorrowIn, LHS_ULT_RHS);

    MI.eraseFromParent();
    return Legalized;
  }
  case G_UITOFP:
    return lowerUITOFP(MI);
  case G_SITOFP:
    return lowerSITOFP(MI);
  case G_FPTOUI:
    return lowerFPTOUI(MI);
  case G_FPTOSI:
    return lowerFPTOSI(MI);
  case G_FPTRUNC:
    return lowerFPTRUNC(MI);
  case G_FPOWI:
    return lowerFPOWI(MI);
  case G_SMIN:
  case G_SMAX:
  case G_UMIN:
  case G_UMAX:
    return lowerMinMax(MI);
  case G_FCOPYSIGN:
    return lowerFCopySign(MI);
  case G_FMINNUM:
  case G_FMAXNUM:
    return lowerFMinNumMaxNum(MI);
  case G_MERGE_VALUES:
    return lowerMergeValues(MI);
  case G_UNMERGE_VALUES:
    return lowerUnmergeValues(MI);
  case TargetOpcode::G_SEXT_INREG: {
    assert(MI.getOperand(2).isImm() && "Expected immediate");
    int64_t SizeInBits = MI.getOperand(2).getImm();

    Register DstReg = MI.getOperand(0).getReg();
    Register SrcReg = MI.getOperand(1).getReg();
    LLT DstTy = MRI.getType(DstReg);
    Register TmpRes = MRI.createGenericVirtualRegister(DstTy);

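    // Sign-extend from bit SizeInBits by shifting left so the sign bit lands
    // in the MSB, then shifting arithmetically back down.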
    auto MIBSz = MIRBuilder.buildConstant(
        DstTy, DstTy.getScalarSizeInBits() - SizeInBits);
    MIRBuilder.buildShl(TmpRes, SrcReg, MIBSz->getOperand(0));
    MIRBuilder.buildAShr(DstReg, TmpRes, MIBSz->getOperand(0));
    MI.eraseFromParent();
    return Legalized;
  }
  case G_EXTRACT_VECTOR_ELT:
  case G_INSERT_VECTOR_ELT:
    return lowerExtractInsertVectorElt(MI);
  case G_SHUFFLE_VECTOR:
    return lowerShuffleVector(MI);
  case G_DYN_STACKALLOC:
    return lowerDynStackAlloc(MI);
  case G_EXTRACT:
    return lowerExtract(MI);
  case G_INSERT:
    return lowerInsert(MI);
  case G_BSWAP:
    return lowerBswap(MI);
  case G_BITREVERSE:
    return lowerBitreverse(MI);
  case G_READ_REGISTER:
  case G_WRITE_REGISTER:
    return lowerReadWriteRegister(MI);
  case G_UADDSAT:
  case G_USUBSAT: {
    // Try to make a reasonable guess about which lowering strategy to use. The
    // target can override this with custom lowering and by calling the
    // implementation functions directly.
    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
    if (LI.isLegalOrCustom({G_UMIN, Ty}))
      return lowerAddSubSatToMinMax(MI);
    return lowerAddSubSatToAddoSubo(MI);
  }
  case G_SADDSAT:
  case G_SSUBSAT: {
    LLT Ty = MRI.getType(MI.getOperand(0).getReg());

    // FIXME: It would probably make more sense to see if G_SADDO is preferred,
    // since it's a shorter expansion. However, we would need to figure out the
    // preferred boolean type for the carry out for the query.
    if (LI.isLegalOrCustom({G_SMIN, Ty}) && LI.isLegalOrCustom({G_SMAX, Ty}))
      return lowerAddSubSatToMinMax(MI);
    return lowerAddSubSatToAddoSubo(MI);
  }
  case G_SSHLSAT:
  case G_USHLSAT:
    return lowerShlSat(MI);
  case G_ABS:
    return lowerAbsToAddXor(MI);
  case G_SELECT:
    return lowerSelect(MI);
  case G_IS_FPCLASS:
    return lowerISFPCLASS(MI);
  case G_SDIVREM:
  case G_UDIVREM:
    return lowerDIVREM(MI);
  case G_FSHL:
  case G_FSHR:
    return lowerFunnelShift(MI);
  case G_ROTL:
  case G_ROTR:
    return lowerRotate(MI);
  case G_MEMSET:
  case G_MEMCPY:
  case G_MEMMOVE:
    return lowerMemCpyFamily(MI);
  case G_MEMCPY_INLINE:
    return lowerMemcpyInline(MI);
  GISEL_VECREDUCE_CASES_NONSEQ
    return lowerVectorReduction(MI);
  }
}

Align LegalizerHelper::getStackTemporaryAlignment(LLT Ty,
                                                  Align MinAlign) const {
  // FIXME: We're missing a way to go back from LLT to llvm::Type to query the
  // datalayout for the preferred alignment. Also there should be a target hook
  // for this to allow targets to reduce the alignment and ignore the
  // datalayout. e.g. AMDGPU should always use a 4-byte alignment, regardless of
  // the type.
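  //
  // For example (illustrative): an s24 temporary yields
  // max(Align(PowerOf2Ceil(3)), MinAlign) = max(Align(4), MinAlign).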
  return std::max(Align(PowerOf2Ceil(Ty.getSizeInBytes())), MinAlign);
}

MachineInstrBuilder
LegalizerHelper::createStackTemporary(TypeSize Bytes, Align Alignment,
                                      MachinePointerInfo &PtrInfo) {
  MachineFunction &MF = MIRBuilder.getMF();
  const DataLayout &DL = MIRBuilder.getDataLayout();
  int FrameIdx = MF.getFrameInfo().CreateStackObject(Bytes, Alignment, false);

  unsigned AddrSpace = DL.getAllocaAddrSpace();
  LLT FramePtrTy = LLT::pointer(AddrSpace, DL.getPointerSizeInBits(AddrSpace));

  PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIdx);
  return MIRBuilder.buildFrameIndex(FramePtrTy, FrameIdx);
}

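/// Clamp a (possibly dynamic) vector index so it is in bounds for \p VecTy.
/// Constant indices are returned unchanged; otherwise a power-of-2 element
/// count is clamped with a mask and any other count with an unsigned min.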
static Register clampDynamicVectorIndex(MachineIRBuilder &B, Register IdxReg,
                                        LLT VecTy) {
  int64_t IdxVal;
  if (mi_match(IdxReg, *B.getMRI(), m_ICst(IdxVal)))
    return IdxReg;

  LLT IdxTy = B.getMRI()->getType(IdxReg);
  unsigned NElts = VecTy.getNumElements();
  if (isPowerOf2_32(NElts)) {
    APInt Imm = APInt::getLowBitsSet(IdxTy.getSizeInBits(), Log2_32(NElts));
    return B.buildAnd(IdxTy, IdxReg, B.buildConstant(IdxTy, Imm)).getReg(0);
  }

  return B.buildUMin(IdxTy, IdxReg, B.buildConstant(IdxTy, NElts - 1))
      .getReg(0);
}

Register LegalizerHelper::getVectorElementPointer(Register VecPtr, LLT VecTy,
                                                  Register Index) {
  LLT EltTy = VecTy.getElementType();

  // Calculate the element offset and add it to the pointer.
  unsigned EltSize = EltTy.getSizeInBits() / 8; // FIXME: should be ABI size.
  assert(EltSize * 8 == EltTy.getSizeInBits() &&
         "Converting bits to bytes lost precision");

  Index = clampDynamicVectorIndex(MIRBuilder, Index, VecTy);

  LLT IdxTy = MRI.getType(Index);
  auto Mul = MIRBuilder.buildMul(IdxTy, Index,
                                 MIRBuilder.buildConstant(IdxTy, EltSize));

  LLT PtrTy = MRI.getType(VecPtr);
  return MIRBuilder.buildPtrAdd(PtrTy, VecPtr, Mul).getReg(0);
}

#ifndef NDEBUG
/// Check that all vector operands have the same number of elements. Other
/// operands should be listed in \p NonVecOpIndices.
static bool hasSameNumEltsOnAllVectorOperands(
    GenericMachineInstr &MI, MachineRegisterInfo &MRI,
    std::initializer_list<unsigned> NonVecOpIndices) {
  if (MI.getNumMemOperands() != 0)
    return false;

  LLT VecTy = MRI.getType(MI.getReg(0));
  if (!VecTy.isVector())
    return false;
  unsigned NumElts = VecTy.getNumElements();

  for (unsigned OpIdx = 1; OpIdx < MI.getNumOperands(); ++OpIdx) {
    MachineOperand &Op = MI.getOperand(OpIdx);
    if (!Op.isReg()) {
      if (!is_contained(NonVecOpIndices, OpIdx))
        return false;
      continue;
    }

    LLT Ty = MRI.getType(Op.getReg());
    if (!Ty.isVector()) {
      if (!is_contained(NonVecOpIndices, OpIdx))
        return false;
      continue;
    }

    if (Ty.getNumElements() != NumElts)
      return false;
  }

  return true;
}
#endif

/// Fill \p DstOps with DstOps that, combined, cover the same number of
/// elements as \p Ty. Each DstOp is either a scalar (when \p NumElts = 1) or a
/// vector with \p NumElts elements. When Ty.getNumElements() is not a multiple
/// of \p NumElts, the last DstOp (the leftover) has fewer than \p NumElts
/// elements.
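///
/// For example (illustrative): Ty = <7 x s32> with NumElts = 2 produces three
/// <2 x s32> DstOps and one s32 leftover.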
static void makeDstOps(SmallVectorImpl<DstOp> &DstOps, LLT Ty,
                       unsigned NumElts) {
  LLT LeftoverTy;
  assert(Ty.isVector() && "Expected vector type");
  LLT EltTy = Ty.getElementType();
  LLT NarrowTy = (NumElts == 1) ? EltTy : LLT::fixed_vector(NumElts, EltTy);
  int NumParts, NumLeftover;
  std::tie(NumParts, NumLeftover) =
      getNarrowTypeBreakDown(Ty, NarrowTy, LeftoverTy);

  assert(NumParts > 0 && "Error in getNarrowTypeBreakDown");
  for (int i = 0; i < NumParts; ++i) {
    DstOps.push_back(NarrowTy);
  }

  if (LeftoverTy.isValid()) {
    assert(NumLeftover == 1 && "expected exactly one leftover");
    DstOps.push_back(LeftoverTy);
  }
}

/// Operand \p Op is used on \p N sub-instructions. Fill \p Ops with \p N SrcOps
/// made from \p Op depending on operand type.
static void broadcastSrcOp(SmallVectorImpl<SrcOp> &Ops, unsigned N,
                           MachineOperand &Op) {
  for (unsigned i = 0; i < N; ++i) {
    if (Op.isReg())
      Ops.push_back(Op.getReg());
    else if (Op.isImm())
      Ops.push_back(Op.getImm());
    else if (Op.isPredicate())
      Ops.push_back(static_cast<CmpInst::Predicate>(Op.getPredicate()));
    else
      llvm_unreachable("Unsupported type");
  }
}

// Handle splitting vector operations which need to have the same number of
// elements in each type index, but each type index may have a different element
// type.
//
// e.g.  <4 x s64> = G_SHL <4 x s64>, <4 x s32> ->
//       <2 x s64> = G_SHL <2 x s64>, <2 x s32>
//       <2 x s64> = G_SHL <2 x s64>, <2 x s32>
//
// Also handles some irregular breakdown cases,
// e.g.  <3 x s64> = G_SHL <3 x s64>, <3 x s32> ->
//       <2 x s64> = G_SHL <2 x s64>, <2 x s32>
//             s64 = G_SHL s64, s32
LegalizerHelper::LegalizeResult
LegalizerHelper::fewerElementsVectorMultiEltType(
    GenericMachineInstr &MI, unsigned NumElts,
    std::initializer_list<unsigned> NonVecOpIndices) {
  assert(hasSameNumEltsOnAllVectorOperands(MI, MRI, NonVecOpIndices) &&
         "Non-compatible opcode or not specified non-vector operands");
  unsigned OrigNumElts = MRI.getType(MI.getReg(0)).getNumElements();

  unsigned NumInputs = MI.getNumOperands() - MI.getNumDefs();
  unsigned NumDefs = MI.getNumDefs();

  // Create DstOps (sub-vectors with NumElts elts + Leftover) for each output.
  // Build instructions with DstOps so that an instruction found by CSE can be
  // used directly; when building with a vreg destination, CSE instead copies
  // the found instruction into the given vreg.
  SmallVector<SmallVector<DstOp, 8>, 2> OutputOpsPieces(NumDefs);
  // Output registers will be taken from created instructions.
  SmallVector<SmallVector<Register, 8>, 2> OutputRegs(NumDefs);
  for (unsigned i = 0; i < NumDefs; ++i) {
    makeDstOps(OutputOpsPieces[i], MRI.getType(MI.getReg(i)), NumElts);
  }

  // Split vector input operands into sub-vectors with NumElts elts + Leftover.
  // Operands listed in NonVecOpIndices will be used as is without splitting;
  // examples: compare predicate in icmp and fcmp (op 1), vector select with i1
  // scalar condition (op 1), immediate in sext_inreg (op 2).
  SmallVector<SmallVector<SrcOp, 8>, 3> InputOpsPieces(NumInputs);
  for (unsigned UseIdx = NumDefs, UseNo = 0; UseIdx < MI.getNumOperands();
       ++UseIdx, ++UseNo) {
    if (is_contained(NonVecOpIndices, UseIdx)) {
      broadcastSrcOp(InputOpsPieces[UseNo], OutputOpsPieces[0].size(),
                     MI.getOperand(UseIdx));
    } else {
      SmallVector<Register, 8> SplitPieces;
      extractVectorParts(MI.getReg(UseIdx), NumElts, SplitPieces);
      for (auto Reg : SplitPieces)
        InputOpsPieces[UseNo].push_back(Reg);
    }
  }

  unsigned NumLeftovers = OrigNumElts % NumElts ? 1 : 0;

  // Take i-th piece of each input operand split and build sub-vector/scalar
  // instruction. Set i-th DstOp(s) from OutputOpsPieces as destination(s).
  for (unsigned i = 0; i < OrigNumElts / NumElts + NumLeftovers; ++i) {
    SmallVector<DstOp, 2> Defs;
    for (unsigned DstNo = 0; DstNo < NumDefs; ++DstNo)
      Defs.push_back(OutputOpsPieces[DstNo][i]);

    SmallVector<SrcOp, 3> Uses;
    for (unsigned InputNo = 0; InputNo < NumInputs; ++InputNo)
      Uses.push_back(InputOpsPieces[InputNo][i]);

    auto I = MIRBuilder.buildInstr(MI.getOpcode(), Defs, Uses, MI.getFlags());
    for (unsigned DstNo = 0; DstNo < NumDefs; ++DstNo)
      OutputRegs[DstNo].push_back(I.getReg(DstNo));
  }

  // Merge small outputs into MI's output for each def operand.
  if (NumLeftovers) {
    for (unsigned i = 0; i < NumDefs; ++i)
      mergeMixedSubvectors(MI.getReg(i), OutputRegs[i]);
  } else {
    for (unsigned i = 0; i < NumDefs; ++i)
      MIRBuilder.buildMergeLikeInstr(MI.getReg(i), OutputRegs[i]);
  }

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::fewerElementsVectorPhi(GenericMachineInstr &MI,
                                        unsigned NumElts) {
  unsigned OrigNumElts = MRI.getType(MI.getReg(0)).getNumElements();

  unsigned NumInputs = MI.getNumOperands() - MI.getNumDefs();
  unsigned NumDefs = MI.getNumDefs();

  SmallVector<DstOp, 8> OutputOpsPieces;
  SmallVector<Register, 8> OutputRegs;
  makeDstOps(OutputOpsPieces, MRI.getType(MI.getReg(0)), NumElts);

  // The instructions that split each input register are inserted in the basic
  // block where that register is defined (the basic block is the next PHI
  // operand).
  SmallVector<SmallVector<Register, 8>, 3> InputOpsPieces(NumInputs / 2);
  for (unsigned UseIdx = NumDefs, UseNo = 0; UseIdx < MI.getNumOperands();
       UseIdx += 2, ++UseNo) {
    MachineBasicBlock &OpMBB = *MI.getOperand(UseIdx + 1).getMBB();
    MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminatorForward());
    extractVectorParts(MI.getReg(UseIdx), NumElts, InputOpsPieces[UseNo]);
  }

  // Build PHIs with fewer elements.
  unsigned NumLeftovers = OrigNumElts % NumElts ? 1 : 0;
  MIRBuilder.setInsertPt(*MI.getParent(), MI);
  for (unsigned i = 0; i < OrigNumElts / NumElts + NumLeftovers; ++i) {
    auto Phi = MIRBuilder.buildInstr(TargetOpcode::G_PHI);
    Phi.addDef(
        MRI.createGenericVirtualRegister(OutputOpsPieces[i].getLLTTy(MRI)));
    OutputRegs.push_back(Phi.getReg(0));

    for (unsigned j = 0; j < NumInputs / 2; ++j) {
      Phi.addUse(InputOpsPieces[j][i]);
      Phi.add(MI.getOperand(1 + j * 2 + 1));
    }
  }

  // Merge small outputs into MI's def.
  if (NumLeftovers) {
    mergeMixedSubvectors(MI.getReg(0), OutputRegs);
  } else {
    MIRBuilder.buildMergeLikeInstr(MI.getReg(0), OutputRegs);
  }

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::fewerElementsVectorUnmergeValues(MachineInstr &MI,
                                                  unsigned TypeIdx,
                                                  LLT NarrowTy) {
  const int NumDst = MI.getNumOperands() - 1;
  const Register SrcReg = MI.getOperand(NumDst).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  LLT SrcTy = MRI.getType(SrcReg);

  if (TypeIdx != 1 || NarrowTy == DstTy)
    return UnableToLegalize;

  // Requires compatible types. Otherwise SrcReg should have been defined by a
  // merge-like instruction that would get artifact-combined. Most likely the
  // instruction that defines SrcReg has to perform more/fewer-elements
  // legalization compatible with NarrowTy.
  assert(SrcTy.isVector() && NarrowTy.isVector() && "Expected vector types");
  assert((SrcTy.getScalarType() == NarrowTy.getScalarType()) && "bad type");

  if ((SrcTy.getSizeInBits() % NarrowTy.getSizeInBits() != 0) ||
      (NarrowTy.getSizeInBits() % DstTy.getSizeInBits() != 0))
    return UnableToLegalize;

  // This is most likely DstTy (smaller than register size) packed in SrcTy
  // (larger than register size) and since the unmerge was not combined it will
  // be lowered to bit-sequence extracts from a register. Unpack SrcTy into
  // NarrowTy (register size) pieces first, then unpack each NarrowTy piece
  // into DstTy.

  // %1:_(DstTy), %2, %3, %4 = G_UNMERGE_VALUES %0:_(SrcTy)
  //
  // %5:_(NarrowTy), %6 = G_UNMERGE_VALUES %0:_(SrcTy) - reg sequence
  // %1:_(DstTy), %2 = G_UNMERGE_VALUES %5:_(NarrowTy) - sequence of bits in reg
  // %3:_(DstTy), %4 = G_UNMERGE_VALUES %6:_(NarrowTy)
  auto Unmerge = MIRBuilder.buildUnmerge(NarrowTy, SrcReg);
  const int NumUnmerge = Unmerge->getNumOperands() - 1;
  const int PartsPerUnmerge = NumDst / NumUnmerge;

  for (int I = 0; I != NumUnmerge; ++I) {
    auto MIB = MIRBuilder.buildInstr(TargetOpcode::G_UNMERGE_VALUES);

    for (int J = 0; J != PartsPerUnmerge; ++J)
      MIB.addDef(MI.getOperand(I * PartsPerUnmerge + J).getReg());
    MIB.addUse(Unmerge.getReg(I));
  }

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::fewerElementsVectorMerge(MachineInstr &MI, unsigned TypeIdx,
                                          LLT NarrowTy) {
  Register DstReg = MI.getOperand(0).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
  // Requires compatible types. Otherwise the user of DstReg did not perform
  // the unmerge that should have been artifact-combined. Most likely the
  // instruction that uses DstReg has to do more/fewer-elements legalization
  // compatible with NarrowTy.
  assert(DstTy.isVector() && NarrowTy.isVector() && "Expected vector types");
  assert((DstTy.getScalarType() == NarrowTy.getScalarType()) && "bad type");
  if (NarrowTy == SrcTy)
    return UnableToLegalize;

  // This attempts to lower part of an LCMTy merge/unmerge sequence. It is
  // intended for old MIR tests. Since the change to more/fewer-elements
  // legalization, it should no longer be possible to generate MIR like this
  // when starting from llvm-ir, because the LCMTy approach was replaced with
  // merge/unmerge to vector elements.
  if (TypeIdx == 1) {
    assert(SrcTy.isVector() && "Expected vector types");
    assert((SrcTy.getScalarType() == NarrowTy.getScalarType()) && "bad type");
    if ((DstTy.getSizeInBits() % NarrowTy.getSizeInBits() != 0) ||
        (NarrowTy.getNumElements() >= SrcTy.getNumElements()))
      return UnableToLegalize;
    // %2:_(DstTy) = G_CONCAT_VECTORS %0:_(SrcTy), %1:_(SrcTy)
    //
    // %3:_(EltTy), %4, %5 = G_UNMERGE_VALUES %0:_(SrcTy)
    // %6:_(EltTy), %7, %8 = G_UNMERGE_VALUES %1:_(SrcTy)
    // %9:_(NarrowTy) = G_BUILD_VECTOR %3:_(EltTy), %4
    // %10:_(NarrowTy) = G_BUILD_VECTOR %5:_(EltTy), %6
    // %11:_(NarrowTy) = G_BUILD_VECTOR %7:_(EltTy), %8
    // %2:_(DstTy) = G_CONCAT_VECTORS %9:_(NarrowTy), %10, %11

    SmallVector<Register, 8> Elts;
    LLT EltTy = MRI.getType(MI.getOperand(1).getReg()).getScalarType();
    for (unsigned i = 1; i < MI.getNumOperands(); ++i) {
      auto Unmerge = MIRBuilder.buildUnmerge(EltTy, MI.getOperand(i).getReg());
      for (unsigned j = 0; j < Unmerge->getNumDefs(); ++j)
        Elts.push_back(Unmerge.getReg(j));
    }

    SmallVector<Register, 8> NarrowTyElts;
    unsigned NumNarrowTyElts = NarrowTy.getNumElements();
    unsigned NumNarrowTyPieces = DstTy.getNumElements() / NumNarrowTyElts;
    for (unsigned i = 0, Offset = 0; i < NumNarrowTyPieces;
         ++i, Offset += NumNarrowTyElts) {
      ArrayRef<Register> Pieces(&Elts[Offset], NumNarrowTyElts);
      NarrowTyElts.push_back(
          MIRBuilder.buildMergeLikeInstr(NarrowTy, Pieces).getReg(0));
    }

    MIRBuilder.buildMergeLikeInstr(DstReg, NarrowTyElts);
    MI.eraseFromParent();
    return Legalized;
  }

  assert(TypeIdx == 0 && "Bad type index");
  if ((NarrowTy.getSizeInBits() % SrcTy.getSizeInBits() != 0) ||
      (DstTy.getSizeInBits() % NarrowTy.getSizeInBits() != 0))
    return UnableToLegalize;

  // This is most likely SrcTy (smaller than register size) packed in DstTy
  // (larger than register size) and since the merge was not combined it will
  // be lowered to bit-sequence packing into a register. Merge SrcTy into
  // NarrowTy (register size) pieces first, then merge each NarrowTy piece
  // into DstTy.

  // %0:_(DstTy) = G_MERGE_VALUES %1:_(SrcTy), %2, %3, %4
  //
  // %5:_(NarrowTy) = G_MERGE_VALUES %1:_(SrcTy), %2 - sequence of bits in reg
  // %6:_(NarrowTy) = G_MERGE_VALUES %3:_(SrcTy), %4
  // %0:_(DstTy)  = G_MERGE_VALUES %5:_(NarrowTy), %6 - reg sequence
  SmallVector<Register, 8> NarrowTyElts;
  unsigned NumParts = DstTy.getNumElements() / NarrowTy.getNumElements();
  unsigned NumSrcElts = SrcTy.isVector() ? SrcTy.getNumElements() : 1;
  unsigned NumElts = NarrowTy.getNumElements() / NumSrcElts;
  for (unsigned i = 0; i < NumParts; ++i) {
    SmallVector<Register, 8> Sources;
    for (unsigned j = 0; j < NumElts; ++j)
      Sources.push_back(MI.getOperand(1 + i * NumElts + j).getReg());
    NarrowTyElts.push_back(
        MIRBuilder.buildMergeLikeInstr(NarrowTy, Sources).getReg(0));
  }

  MIRBuilder.buildMergeLikeInstr(DstReg, NarrowTyElts);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::fewerElementsVectorExtractInsertVectorElt(MachineInstr &MI,
                                                           unsigned TypeIdx,
                                                           LLT NarrowVecTy) {
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcVec = MI.getOperand(1).getReg();
  Register InsertVal;
  bool IsInsert = MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT;

  assert((IsInsert ? TypeIdx == 0 : TypeIdx == 1) && "not a vector type index");
  if (IsInsert)
    InsertVal = MI.getOperand(2).getReg();

  Register Idx = MI.getOperand(MI.getNumOperands() - 1).getReg();

  // TODO: Handle total scalarization case.
  if (!NarrowVecTy.isVector())
    return UnableToLegalize;

  LLT VecTy = MRI.getType(SrcVec);

  // If the index is a constant, we can really break this down as you would
  // expect, and index into the target size pieces.
  int64_t IdxVal;
  auto MaybeCst = getIConstantVRegValWithLookThrough(Idx, MRI);
  if (MaybeCst) {
    IdxVal = MaybeCst->Value.getSExtValue();
    // Avoid out of bounds indexing the pieces.
    if (IdxVal >= VecTy.getNumElements()) {
      MIRBuilder.buildUndef(DstReg);
      MI.eraseFromParent();
      return Legalized;
    }

    SmallVector<Register, 8> VecParts;
    LLT GCDTy = extractGCDType(VecParts, VecTy, NarrowVecTy, SrcVec);

    // Build a sequence of NarrowTy pieces in VecParts for this operand.
    LLT LCMTy = buildLCMMergePieces(VecTy, NarrowVecTy, GCDTy, VecParts,
                                    TargetOpcode::G_ANYEXT);

    unsigned NewNumElts = NarrowVecTy.getNumElements();

    LLT IdxTy = MRI.getType(Idx);
    int64_t PartIdx = IdxVal / NewNumElts;
    auto NewIdx =
        MIRBuilder.buildConstant(IdxTy, IdxVal - NewNumElts * PartIdx);

    if (IsInsert) {
      LLT PartTy = MRI.getType(VecParts[PartIdx]);

      // Use the adjusted index to insert into one of the subvectors.
      auto InsertPart = MIRBuilder.buildInsertVectorElement(
          PartTy, VecParts[PartIdx], InsertVal, NewIdx);
      VecParts[PartIdx] = InsertPart.getReg(0);

      // Recombine the inserted subvector with the others to reform the result
      // vector.
      buildWidenedRemergeToDst(DstReg, LCMTy, VecParts);
    } else {
      MIRBuilder.buildExtractVectorElement(DstReg, VecParts[PartIdx], NewIdx);
    }

    MI.eraseFromParent();
    return Legalized;
  }

  // With a variable index, we can't perform the operation in a smaller type, so
  // we're forced to expand this.
  //
  // TODO: We could emit a chain of compare/select to figure out which piece to
  // index.
  return lowerExtractInsertVectorElt(MI);
}

4032 LegalizerHelper::LegalizeResult
4033 LegalizerHelper::reduceLoadStoreWidth(GLoadStore &LdStMI, unsigned TypeIdx,
4034                                       LLT NarrowTy) {
4035   // FIXME: Don't know how to handle secondary types yet.
4036   if (TypeIdx != 0)
4037     return UnableToLegalize;
4038 
4039   // This implementation doesn't work for atomics. Give up instead of doing
4040   // something invalid.
4041   if (LdStMI.isAtomic())
4042     return UnableToLegalize;
4043 
4044   bool IsLoad = isa<GLoad>(LdStMI);
4045   Register ValReg = LdStMI.getReg(0);
4046   Register AddrReg = LdStMI.getPointerReg();
4047   LLT ValTy = MRI.getType(ValReg);
4048 
4049   // FIXME: Do we need a distinct NarrowMemory legalize action?
4050   if (ValTy.getSizeInBits() != 8 * LdStMI.getMemSize()) {
4051     LLVM_DEBUG(dbgs() << "Can't narrow extload/truncstore\n");
4052     return UnableToLegalize;
4053   }
4054 
4055   int NumParts = -1;
4056   int NumLeftover = -1;
4057   LLT LeftoverTy;
4058   SmallVector<Register, 8> NarrowRegs, NarrowLeftoverRegs;
4059   if (IsLoad) {
4060     std::tie(NumParts, NumLeftover) = getNarrowTypeBreakDown(ValTy, NarrowTy, LeftoverTy);
4061   } else {
4062     if (extractParts(ValReg, ValTy, NarrowTy, LeftoverTy, NarrowRegs,
4063                      NarrowLeftoverRegs)) {
4064       NumParts = NarrowRegs.size();
4065       NumLeftover = NarrowLeftoverRegs.size();
4066     }
4067   }
4068 
4069   if (NumParts == -1)
4070     return UnableToLegalize;
4071 
4072   LLT PtrTy = MRI.getType(AddrReg);
4073   const LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
4074 
4075   unsigned TotalSize = ValTy.getSizeInBits();
4076 
4077   // Split the load/store into PartTy sized pieces starting at Offset. If this
4078   // is a load, return the new registers in ValRegs. For a store, each element
4079   // of ValRegs should be PartTy. Returns the next offset that needs to be
4080   // handled.
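  //
  // E.g. (a sketch; the exact offsets depend on endianness and NarrowTy) a
  // little-endian s96 load narrowed with NarrowTy = s32 becomes three s32
  // loads at byte offsets 0, 4 and 8 from the base pointer, whose results
  // are recombined into the original 96-bit value.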
4081   bool isBigEndian = MIRBuilder.getDataLayout().isBigEndian();
4082   auto MMO = LdStMI.getMMO();
4083   auto splitTypePieces = [=](LLT PartTy, SmallVectorImpl<Register> &ValRegs,
4084                              unsigned NumParts, unsigned Offset) -> unsigned {
4085     MachineFunction &MF = MIRBuilder.getMF();
4086     unsigned PartSize = PartTy.getSizeInBits();
4087     for (unsigned Idx = 0, E = NumParts; Idx != E && Offset < TotalSize;
4088          ++Idx) {
4089       unsigned ByteOffset = Offset / 8;
4090       Register NewAddrReg;
4091 
4092       MIRBuilder.materializePtrAdd(NewAddrReg, AddrReg, OffsetTy, ByteOffset);
4093 
4094       MachineMemOperand *NewMMO =
4095           MF.getMachineMemOperand(&MMO, ByteOffset, PartTy);
4096 
4097       if (IsLoad) {
4098         Register Dst = MRI.createGenericVirtualRegister(PartTy);
4099         ValRegs.push_back(Dst);
4100         MIRBuilder.buildLoad(Dst, NewAddrReg, *NewMMO);
4101       } else {
4102         MIRBuilder.buildStore(ValRegs[Idx], NewAddrReg, *NewMMO);
4103       }
4104       Offset = isBigEndian ? Offset - PartSize : Offset + PartSize;
4105     }
4106 
4107     return Offset;
4108   };
4109 
4110   unsigned Offset = isBigEndian ? TotalSize - NarrowTy.getSizeInBits() : 0;
4111   unsigned HandledOffset =
4112       splitTypePieces(NarrowTy, NarrowRegs, NumParts, Offset);
4113 
4114   // Handle the rest of the register if this isn't an even type breakdown.
4115   if (LeftoverTy.isValid())
4116     splitTypePieces(LeftoverTy, NarrowLeftoverRegs, NumLeftover, HandledOffset);
4117 
4118   if (IsLoad) {
4119     insertParts(ValReg, ValTy, NarrowTy, NarrowRegs,
4120                 LeftoverTy, NarrowLeftoverRegs);
4121   }
4122 
4123   LdStMI.eraseFromParent();
4124   return Legalized;
4125 }
4126 
4127 LegalizerHelper::LegalizeResult
4128 LegalizerHelper::fewerElementsVector(MachineInstr &MI, unsigned TypeIdx,
4129                                      LLT NarrowTy) {
4130   using namespace TargetOpcode;
4131   GenericMachineInstr &GMI = cast<GenericMachineInstr>(MI);
4132   unsigned NumElts = NarrowTy.isVector() ? NarrowTy.getNumElements() : 1;
4133 
4134   switch (MI.getOpcode()) {
4135   case G_IMPLICIT_DEF:
4136   case G_TRUNC:
4137   case G_AND:
4138   case G_OR:
4139   case G_XOR:
4140   case G_ADD:
4141   case G_SUB:
4142   case G_MUL:
4143   case G_PTR_ADD:
4144   case G_SMULH:
4145   case G_UMULH:
4146   case G_FADD:
4147   case G_FMUL:
4148   case G_FSUB:
4149   case G_FNEG:
4150   case G_FABS:
4151   case G_FCANONICALIZE:
4152   case G_FDIV:
4153   case G_FREM:
4154   case G_FMA:
4155   case G_FMAD:
4156   case G_FPOW:
4157   case G_FEXP:
4158   case G_FEXP2:
4159   case G_FLOG:
4160   case G_FLOG2:
4161   case G_FLOG10:
4162   case G_FNEARBYINT:
4163   case G_FCEIL:
4164   case G_FFLOOR:
4165   case G_FRINT:
4166   case G_INTRINSIC_ROUND:
4167   case G_INTRINSIC_ROUNDEVEN:
4168   case G_INTRINSIC_TRUNC:
4169   case G_FCOS:
4170   case G_FSIN:
4171   case G_FSQRT:
4172   case G_BSWAP:
4173   case G_BITREVERSE:
4174   case G_SDIV:
4175   case G_UDIV:
4176   case G_SREM:
4177   case G_UREM:
4178   case G_SDIVREM:
4179   case G_UDIVREM:
4180   case G_SMIN:
4181   case G_SMAX:
4182   case G_UMIN:
4183   case G_UMAX:
4184   case G_ABS:
4185   case G_FMINNUM:
4186   case G_FMAXNUM:
4187   case G_FMINNUM_IEEE:
4188   case G_FMAXNUM_IEEE:
4189   case G_FMINIMUM:
4190   case G_FMAXIMUM:
4191   case G_FSHL:
4192   case G_FSHR:
4193   case G_ROTL:
4194   case G_ROTR:
4195   case G_FREEZE:
4196   case G_SADDSAT:
4197   case G_SSUBSAT:
4198   case G_UADDSAT:
4199   case G_USUBSAT:
4200   case G_UMULO:
4201   case G_SMULO:
4202   case G_SHL:
4203   case G_LSHR:
4204   case G_ASHR:
4205   case G_SSHLSAT:
4206   case G_USHLSAT:
4207   case G_CTLZ:
4208   case G_CTLZ_ZERO_UNDEF:
4209   case G_CTTZ:
4210   case G_CTTZ_ZERO_UNDEF:
4211   case G_CTPOP:
4212   case G_FCOPYSIGN:
4213   case G_ZEXT:
4214   case G_SEXT:
4215   case G_ANYEXT:
4216   case G_FPEXT:
4217   case G_FPTRUNC:
4218   case G_SITOFP:
4219   case G_UITOFP:
4220   case G_FPTOSI:
4221   case G_FPTOUI:
4222   case G_INTTOPTR:
4223   case G_PTRTOINT:
4224   case G_ADDRSPACE_CAST:
4225   case G_UADDO:
4226   case G_USUBO:
4227   case G_UADDE:
4228   case G_USUBE:
4229   case G_SADDO:
4230   case G_SSUBO:
4231   case G_SADDE:
4232   case G_SSUBE:
4233   case G_STRICT_FADD:
4234   case G_STRICT_FSUB:
4235   case G_STRICT_FMUL:
4236   case G_STRICT_FMA:
4237     return fewerElementsVectorMultiEltType(GMI, NumElts);
4238   case G_ICMP:
4239   case G_FCMP:
4240     return fewerElementsVectorMultiEltType(GMI, NumElts, {1 /*cmp predicate*/});
4241   case G_IS_FPCLASS:
4242     return fewerElementsVectorMultiEltType(GMI, NumElts, {2, 3 /*mask,fpsem*/});
4243   case G_SELECT:
4244     if (MRI.getType(MI.getOperand(1).getReg()).isVector())
4245       return fewerElementsVectorMultiEltType(GMI, NumElts);
4246     return fewerElementsVectorMultiEltType(GMI, NumElts, {1 /*scalar cond*/});
4247   case G_PHI:
4248     return fewerElementsVectorPhi(GMI, NumElts);
4249   case G_UNMERGE_VALUES:
4250     return fewerElementsVectorUnmergeValues(MI, TypeIdx, NarrowTy);
4251   case G_BUILD_VECTOR:
4252     assert(TypeIdx == 0 && "not a vector type index");
4253     return fewerElementsVectorMerge(MI, TypeIdx, NarrowTy);
4254   case G_CONCAT_VECTORS:
4255     if (TypeIdx != 1) // TODO: This probably does work as expected already.
4256       return UnableToLegalize;
4257     return fewerElementsVectorMerge(MI, TypeIdx, NarrowTy);
4258   case G_EXTRACT_VECTOR_ELT:
4259   case G_INSERT_VECTOR_ELT:
4260     return fewerElementsVectorExtractInsertVectorElt(MI, TypeIdx, NarrowTy);
4261   case G_LOAD:
4262   case G_STORE:
4263     return reduceLoadStoreWidth(cast<GLoadStore>(MI), TypeIdx, NarrowTy);
4264   case G_SEXT_INREG:
4265     return fewerElementsVectorMultiEltType(GMI, NumElts, {2 /*imm*/});
4266   GISEL_VECREDUCE_CASES_NONSEQ
4267     return fewerElementsVectorReductions(MI, TypeIdx, NarrowTy);
4268   case G_SHUFFLE_VECTOR:
4269     return fewerElementsVectorShuffle(MI, TypeIdx, NarrowTy);
4270   default:
4271     return UnableToLegalize;
4272   }
4273 }
4274 
4275 LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorShuffle(
4276     MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
4277   assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
4278   if (TypeIdx != 0)
4279     return UnableToLegalize;
4280 
4281   Register DstReg = MI.getOperand(0).getReg();
4282   Register Src1Reg = MI.getOperand(1).getReg();
4283   Register Src2Reg = MI.getOperand(2).getReg();
4284   ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
4285   LLT DstTy = MRI.getType(DstReg);
4286   LLT Src1Ty = MRI.getType(Src1Reg);
4287   LLT Src2Ty = MRI.getType(Src2Reg);
4288   // The shuffle should be canonicalized by now.
4289   if (DstTy != Src1Ty)
4290     return UnableToLegalize;
4291   if (DstTy != Src2Ty)
4292     return UnableToLegalize;
4293 
4294   if (!isPowerOf2_32(DstTy.getNumElements()))
4295     return UnableToLegalize;
4296 
4297   // We only support splitting a shuffle into 2, so adjust NarrowTy accordingly.
4298   // Further legalization attempts will be needed to split further.
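  //
  // E.g. (illustrative) a G_SHUFFLE_VECTOR producing <8 x s16> is split into
  // two <4 x s16> halves; each half becomes a smaller shuffle of at most two
  // of the four split inputs, a G_BUILD_VECTOR of extracted elements, or
  // undef, and the halves are rejoined with G_CONCAT_VECTORS.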
4299   NarrowTy =
4300       DstTy.changeElementCount(DstTy.getElementCount().divideCoefficientBy(2));
4301   unsigned NewElts = NarrowTy.getNumElements();
4302 
4303   SmallVector<Register> SplitSrc1Regs, SplitSrc2Regs;
4304   extractParts(Src1Reg, NarrowTy, 2, SplitSrc1Regs);
4305   extractParts(Src2Reg, NarrowTy, 2, SplitSrc2Regs);
4306   Register Inputs[4] = {SplitSrc1Regs[0], SplitSrc1Regs[1], SplitSrc2Regs[0],
4307                         SplitSrc2Regs[1]};
4308 
4309   Register Hi, Lo;
4310 
4311   // If Lo or Hi uses elements from at most two of the four input vectors, then
4312   // express it as a vector shuffle of those two inputs.  Otherwise extract the
4313   // input elements by hand and construct the Lo/Hi output using a BUILD_VECTOR.
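  //
  // For instance, with NewElts = 4 a half-mask of [0, 5, 1, 4] only touches
  // Inputs[0] and Inputs[1] and becomes a shuffle of those two operands,
  // while a half-mask of [0, 5, 12, 3] touches three inputs and falls back
  // to a G_BUILD_VECTOR of individually extracted elements.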
4314   SmallVector<int, 16> Ops;
4315   for (unsigned High = 0; High < 2; ++High) {
4316     Register &Output = High ? Hi : Lo;
4317 
4318     // Build a shuffle mask for the output, discovering on the fly which
4319     // input vectors to use as shuffle operands (recorded in InputUsed).
4320     // If building a suitable shuffle vector proves too hard, then bail
4321     // out with useBuildVector set.
4322     unsigned InputUsed[2] = {-1U, -1U}; // Not yet discovered.
4323     unsigned FirstMaskIdx = High * NewElts;
4324     bool UseBuildVector = false;
4325     for (unsigned MaskOffset = 0; MaskOffset < NewElts; ++MaskOffset) {
4326       // The mask element.  This indexes into the input.
4327       int Idx = Mask[FirstMaskIdx + MaskOffset];
4328 
4329       // The input vector this mask element indexes into.
4330       unsigned Input = (unsigned)Idx / NewElts;
4331 
4332       if (Input >= std::size(Inputs)) {
4333         // The mask element does not index into any input vector.
4334         Ops.push_back(-1);
4335         continue;
4336       }
4337 
4338       // Turn the index into an offset from the start of the input vector.
4339       Idx -= Input * NewElts;
4340 
4341       // Find or create a shuffle vector operand to hold this input.
4342       unsigned OpNo;
4343       for (OpNo = 0; OpNo < std::size(InputUsed); ++OpNo) {
4344         if (InputUsed[OpNo] == Input) {
4345           // This input vector is already an operand.
4346           break;
4347         } else if (InputUsed[OpNo] == -1U) {
4348           // Create a new operand for this input vector.
4349           InputUsed[OpNo] = Input;
4350           break;
4351         }
4352       }
4353 
4354       if (OpNo >= std::size(InputUsed)) {
4355         // More than two input vectors used!  Give up on trying to create a
4356         // shuffle vector.  Insert all elements into a BUILD_VECTOR instead.
4357         UseBuildVector = true;
4358         break;
4359       }
4360 
4361       // Add the mask index for the new shuffle vector.
4362       Ops.push_back(Idx + OpNo * NewElts);
4363     }
4364 
4365     if (UseBuildVector) {
4366       LLT EltTy = NarrowTy.getElementType();
4367       SmallVector<Register, 16> SVOps;
4368 
4369       // Extract the input elements by hand.
4370       for (unsigned MaskOffset = 0; MaskOffset < NewElts; ++MaskOffset) {
4371         // The mask element.  This indexes into the input.
4372         int Idx = Mask[FirstMaskIdx + MaskOffset];
4373 
4374         // The input vector this mask element indexes into.
4375         unsigned Input = (unsigned)Idx / NewElts;
4376 
4377         if (Input >= std::size(Inputs)) {
4378           // The mask element is "undef" or indexes off the end of the input.
4379           SVOps.push_back(MIRBuilder.buildUndef(EltTy).getReg(0));
4380           continue;
4381         }
4382 
4383         // Turn the index into an offset from the start of the input vector.
4384         Idx -= Input * NewElts;
4385 
4386         // Extract the vector element by hand.
4387         SVOps.push_back(MIRBuilder
4388                             .buildExtractVectorElement(
4389                                 EltTy, Inputs[Input],
4390                                 MIRBuilder.buildConstant(LLT::scalar(32), Idx))
4391                             .getReg(0));
4392       }
4393 
4394       // Construct the Lo/Hi output using a G_BUILD_VECTOR.
4395       Output = MIRBuilder.buildBuildVector(NarrowTy, SVOps).getReg(0);
4396     } else if (InputUsed[0] == -1U) {
4397       // No input vectors were used! The result is undefined.
4398       Output = MIRBuilder.buildUndef(NarrowTy).getReg(0);
4399     } else {
4400       Register Op0 = Inputs[InputUsed[0]];
4401       // If only one input was used, use an undefined vector for the other.
4402       Register Op1 = InputUsed[1] == -1U
4403                          ? MIRBuilder.buildUndef(NarrowTy).getReg(0)
4404                          : Inputs[InputUsed[1]];
4405       // At least one input vector was used. Create a new shuffle vector.
4406       Output = MIRBuilder.buildShuffleVector(NarrowTy, Op0, Op1, Ops).getReg(0);
4407     }
4408 
4409     Ops.clear();
4410   }
4411 
4412   MIRBuilder.buildConcatVectors(DstReg, {Lo, Hi});
4413   MI.eraseFromParent();
4414   return Legalized;
4415 }
4416 
4417 static unsigned getScalarOpcForReduction(unsigned Opc) {
4418   unsigned ScalarOpc;
4419   switch (Opc) {
4420   case TargetOpcode::G_VECREDUCE_FADD:
4421     ScalarOpc = TargetOpcode::G_FADD;
4422     break;
4423   case TargetOpcode::G_VECREDUCE_FMUL:
4424     ScalarOpc = TargetOpcode::G_FMUL;
4425     break;
4426   case TargetOpcode::G_VECREDUCE_FMAX:
4427     ScalarOpc = TargetOpcode::G_FMAXNUM;
4428     break;
4429   case TargetOpcode::G_VECREDUCE_FMIN:
4430     ScalarOpc = TargetOpcode::G_FMINNUM;
4431     break;
4432   case TargetOpcode::G_VECREDUCE_ADD:
4433     ScalarOpc = TargetOpcode::G_ADD;
4434     break;
4435   case TargetOpcode::G_VECREDUCE_MUL:
4436     ScalarOpc = TargetOpcode::G_MUL;
4437     break;
4438   case TargetOpcode::G_VECREDUCE_AND:
4439     ScalarOpc = TargetOpcode::G_AND;
4440     break;
4441   case TargetOpcode::G_VECREDUCE_OR:
4442     ScalarOpc = TargetOpcode::G_OR;
4443     break;
4444   case TargetOpcode::G_VECREDUCE_XOR:
4445     ScalarOpc = TargetOpcode::G_XOR;
4446     break;
4447   case TargetOpcode::G_VECREDUCE_SMAX:
4448     ScalarOpc = TargetOpcode::G_SMAX;
4449     break;
4450   case TargetOpcode::G_VECREDUCE_SMIN:
4451     ScalarOpc = TargetOpcode::G_SMIN;
4452     break;
4453   case TargetOpcode::G_VECREDUCE_UMAX:
4454     ScalarOpc = TargetOpcode::G_UMAX;
4455     break;
4456   case TargetOpcode::G_VECREDUCE_UMIN:
4457     ScalarOpc = TargetOpcode::G_UMIN;
4458     break;
4459   default:
4460     llvm_unreachable("Unhandled reduction");
4461   }
4462   return ScalarOpc;
4463 }
4464 
4465 LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
4466     MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
4467   unsigned Opc = MI.getOpcode();
4468   assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
4469          Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
4470          "Sequential reductions not expected");
4471 
4472   if (TypeIdx != 1)
4473     return UnableToLegalize;
4474 
4475   // The semantics of the normal non-sequential reductions allow us to freely
4476   // re-associate the operation.
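  // E.g. a G_VECREDUCE_ADD of <4 x s32> may be rewritten as the G_ADD of two
  // G_VECREDUCE_ADDs over the <2 x s32> halves; the non-sequential FP
  // reductions permit this reassociation by definition.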
4477   Register SrcReg = MI.getOperand(1).getReg();
4478   LLT SrcTy = MRI.getType(SrcReg);
4479   Register DstReg = MI.getOperand(0).getReg();
4480   LLT DstTy = MRI.getType(DstReg);
4481 
4482   if (NarrowTy.isVector() &&
4483       (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0))
4484     return UnableToLegalize;
4485 
4486   unsigned ScalarOpc = getScalarOpcForReduction(Opc);
4487   SmallVector<Register> SplitSrcs;
4488   // If NarrowTy is a scalar then we're being asked to scalarize.
4489   const unsigned NumParts =
4490       NarrowTy.isVector() ? SrcTy.getNumElements() / NarrowTy.getNumElements()
4491                           : SrcTy.getNumElements();
4492 
4493   extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
4494   if (NarrowTy.isScalar()) {
4495     if (DstTy != NarrowTy)
4496       return UnableToLegalize; // FIXME: handle implicit extensions.
4497 
4498     if (isPowerOf2_32(NumParts)) {
4499       // Generate a tree of scalar operations to reduce the critical path.
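      // E.g. four scalar parts are combined in a depth-two tree,
      //   t0 = op(s0, s1); t1 = op(s2, s3); res = op(t0, t1)
      // rather than the depth-three chain op(op(op(s0, s1), s2), s3).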
4500       SmallVector<Register> PartialResults;
4501       unsigned NumPartsLeft = NumParts;
4502       while (NumPartsLeft > 1) {
4503         for (unsigned Idx = 0; Idx < NumPartsLeft - 1; Idx += 2) {
4504           PartialResults.emplace_back(
4505               MIRBuilder
4506                   .buildInstr(ScalarOpc, {NarrowTy},
4507                               {SplitSrcs[Idx], SplitSrcs[Idx + 1]})
4508                   .getReg(0));
4509         }
4510         SplitSrcs = PartialResults;
4511         PartialResults.clear();
4512         NumPartsLeft = SplitSrcs.size();
4513       }
4514       assert(SplitSrcs.size() == 1);
4515       MIRBuilder.buildCopy(DstReg, SplitSrcs[0]);
4516       MI.eraseFromParent();
4517       return Legalized;
4518     }
4519     // If we can't generate a tree, then just do sequential operations.
4520     Register Acc = SplitSrcs[0];
4521     for (unsigned Idx = 1; Idx < NumParts; ++Idx)
4522       Acc = MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {Acc, SplitSrcs[Idx]})
4523                 .getReg(0);
4524     MIRBuilder.buildCopy(DstReg, Acc);
4525     MI.eraseFromParent();
4526     return Legalized;
4527   }
4528   SmallVector<Register> PartialReductions;
4529   for (unsigned Part = 0; Part < NumParts; ++Part) {
4530     PartialReductions.push_back(
4531         MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0));
4532   }
4533 
4535   // If the types involved are powers of 2, we can generate intermediate vector
4536   // ops, before generating a final reduction operation.
4537   if (isPowerOf2_32(SrcTy.getNumElements()) &&
4538       isPowerOf2_32(NarrowTy.getNumElements())) {
4539     return tryNarrowPow2Reduction(MI, SrcReg, SrcTy, NarrowTy, ScalarOpc);
4540   }
4541 
4542   Register Acc = PartialReductions[0];
4543   for (unsigned Part = 1; Part < NumParts; ++Part) {
4544     if (Part == NumParts - 1) {
4545       MIRBuilder.buildInstr(ScalarOpc, {DstReg},
4546                             {Acc, PartialReductions[Part]});
4547     } else {
4548       Acc = MIRBuilder
4549                 .buildInstr(ScalarOpc, {DstTy}, {Acc, PartialReductions[Part]})
4550                 .getReg(0);
4551     }
4552   }
4553   MI.eraseFromParent();
4554   return Legalized;
4555 }
4556 
4557 LegalizerHelper::LegalizeResult
4558 LegalizerHelper::tryNarrowPow2Reduction(MachineInstr &MI, Register SrcReg,
4559                                         LLT SrcTy, LLT NarrowTy,
4560                                         unsigned ScalarOpc) {
4561   SmallVector<Register> SplitSrcs;
4562   // Split the sources into NarrowTy-sized pieces.
4563   extractParts(SrcReg, NarrowTy,
4564                SrcTy.getNumElements() / NarrowTy.getNumElements(), SplitSrcs);
4565   // We're going to do a tree reduction using vector operations until we have
4566   // one NarrowTy-sized value left.
4567   while (SplitSrcs.size() > 1) {
4568     SmallVector<Register> PartialRdxs;
4569     for (unsigned Idx = 0; Idx < SplitSrcs.size()-1; Idx += 2) {
4570       Register LHS = SplitSrcs[Idx];
4571       Register RHS = SplitSrcs[Idx + 1];
4572       // Create the intermediate vector op.
4573       Register Res =
4574           MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {LHS, RHS}).getReg(0);
4575       PartialRdxs.push_back(Res);
4576     }
4577     SplitSrcs = std::move(PartialRdxs);
4578   }
4579   // Finally, generate the requested NarrowTy-based reduction.
4580   Observer.changingInstr(MI);
4581   MI.getOperand(1).setReg(SplitSrcs[0]);
4582   Observer.changedInstr(MI);
4583   return Legalized;
4584 }
4585 
4586 LegalizerHelper::LegalizeResult
4587 LegalizerHelper::narrowScalarShiftByConstant(MachineInstr &MI, const APInt &Amt,
4588                                              const LLT HalfTy, const LLT AmtTy) {
4589 
4590   Register InL = MRI.createGenericVirtualRegister(HalfTy);
4591   Register InH = MRI.createGenericVirtualRegister(HalfTy);
4592   MIRBuilder.buildUnmerge({InL, InH}, MI.getOperand(1));
4593 
4594   if (Amt.isZero()) {
4595     MIRBuilder.buildMergeLikeInstr(MI.getOperand(0), {InL, InH});
4596     MI.eraseFromParent();
4597     return Legalized;
4598   }
4599 
4600   LLT NVT = HalfTy;
4601   unsigned NVTBits = HalfTy.getSizeInBits();
4602   unsigned VTBits = 2 * NVTBits;
4603 
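  // E.g. (illustrative) narrowing %d:_(s64) = G_SHL %x, 40 with HalfTy = s32
  // hits the Amt > NVTBits case below:
  //   Lo = 0
  //   Hi = G_SHL InL, 8   // 40 - 32 = 8; the old high half is shifted out.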
4604   SrcOp Lo(Register(0)), Hi(Register(0));
4605   if (MI.getOpcode() == TargetOpcode::G_SHL) {
4606     if (Amt.ugt(VTBits)) {
4607       Lo = Hi = MIRBuilder.buildConstant(NVT, 0);
4608     } else if (Amt.ugt(NVTBits)) {
4609       Lo = MIRBuilder.buildConstant(NVT, 0);
4610       Hi = MIRBuilder.buildShl(NVT, InL,
4611                                MIRBuilder.buildConstant(AmtTy, Amt - NVTBits));
4612     } else if (Amt == NVTBits) {
4613       Lo = MIRBuilder.buildConstant(NVT, 0);
4614       Hi = InL;
4615     } else {
4616       Lo = MIRBuilder.buildShl(NVT, InL, MIRBuilder.buildConstant(AmtTy, Amt));
4617       auto OrLHS =
4618           MIRBuilder.buildShl(NVT, InH, MIRBuilder.buildConstant(AmtTy, Amt));
4619       auto OrRHS = MIRBuilder.buildLShr(
4620           NVT, InL, MIRBuilder.buildConstant(AmtTy, -Amt + NVTBits));
4621       Hi = MIRBuilder.buildOr(NVT, OrLHS, OrRHS);
4622     }
4623   } else if (MI.getOpcode() == TargetOpcode::G_LSHR) {
4624     if (Amt.ugt(VTBits)) {
4625       Lo = Hi = MIRBuilder.buildConstant(NVT, 0);
4626     } else if (Amt.ugt(NVTBits)) {
4627       Lo = MIRBuilder.buildLShr(NVT, InH,
4628                                 MIRBuilder.buildConstant(AmtTy, Amt - NVTBits));
4629       Hi = MIRBuilder.buildConstant(NVT, 0);
4630     } else if (Amt == NVTBits) {
4631       Lo = InH;
4632       Hi = MIRBuilder.buildConstant(NVT, 0);
4633     } else {
4634       auto ShiftAmtConst = MIRBuilder.buildConstant(AmtTy, Amt);
4635 
4636       auto OrLHS = MIRBuilder.buildLShr(NVT, InL, ShiftAmtConst);
4637       auto OrRHS = MIRBuilder.buildShl(
4638           NVT, InH, MIRBuilder.buildConstant(AmtTy, -Amt + NVTBits));
4639 
4640       Lo = MIRBuilder.buildOr(NVT, OrLHS, OrRHS);
4641       Hi = MIRBuilder.buildLShr(NVT, InH, ShiftAmtConst);
4642     }
4643   } else {
4644     if (Amt.ugt(VTBits)) {
4645       Hi = Lo = MIRBuilder.buildAShr(
4646           NVT, InH, MIRBuilder.buildConstant(AmtTy, NVTBits - 1));
4647     } else if (Amt.ugt(NVTBits)) {
4648       Lo = MIRBuilder.buildAShr(NVT, InH,
4649                                 MIRBuilder.buildConstant(AmtTy, Amt - NVTBits));
4650       Hi = MIRBuilder.buildAShr(NVT, InH,
4651                                 MIRBuilder.buildConstant(AmtTy, NVTBits - 1));
4652     } else if (Amt == NVTBits) {
4653       Lo = InH;
4654       Hi = MIRBuilder.buildAShr(NVT, InH,
4655                                 MIRBuilder.buildConstant(AmtTy, NVTBits - 1));
4656     } else {
4657       auto ShiftAmtConst = MIRBuilder.buildConstant(AmtTy, Amt);
4658 
4659       auto OrLHS = MIRBuilder.buildLShr(NVT, InL, ShiftAmtConst);
4660       auto OrRHS = MIRBuilder.buildShl(
4661           NVT, InH, MIRBuilder.buildConstant(AmtTy, -Amt + NVTBits));
4662 
4663       Lo = MIRBuilder.buildOr(NVT, OrLHS, OrRHS);
4664       Hi = MIRBuilder.buildAShr(NVT, InH, ShiftAmtConst);
4665     }
4666   }
4667 
4668   MIRBuilder.buildMergeLikeInstr(MI.getOperand(0), {Lo, Hi});
4669   MI.eraseFromParent();
4670 
4671   return Legalized;
4672 }
4673 
4674 // TODO: Optimize if constant shift amount.
4675 LegalizerHelper::LegalizeResult
4676 LegalizerHelper::narrowScalarShift(MachineInstr &MI, unsigned TypeIdx,
4677                                    LLT RequestedTy) {
4678   if (TypeIdx == 1) {
4679     Observer.changingInstr(MI);
4680     narrowScalarSrc(MI, RequestedTy, 2);
4681     Observer.changedInstr(MI);
4682     return Legalized;
4683   }
4684 
4685   Register DstReg = MI.getOperand(0).getReg();
4686   LLT DstTy = MRI.getType(DstReg);
4687   if (DstTy.isVector())
4688     return UnableToLegalize;
4689 
4690   Register Amt = MI.getOperand(2).getReg();
4691   LLT ShiftAmtTy = MRI.getType(Amt);
4692   const unsigned DstEltSize = DstTy.getScalarSizeInBits();
4693   if (DstEltSize % 2 != 0)
4694     return UnableToLegalize;
4695 
4696   // Ignore the input type. We can only go to exactly half the size of the
4697   // input. If that isn't small enough, the resulting pieces will be further
4698   // legalized.
4699   const unsigned NewBitSize = DstEltSize / 2;
4700   const LLT HalfTy = LLT::scalar(NewBitSize);
4701   const LLT CondTy = LLT::scalar(1);
4702 
4703   if (auto VRegAndVal = getIConstantVRegValWithLookThrough(Amt, MRI)) {
4704     return narrowScalarShiftByConstant(MI, VRegAndVal->Value, HalfTy,
4705                                        ShiftAmtTy);
4706   }
4707 
4708   // TODO: Expand with known bits.
4709 
4710   // Handle the fully general expansion by an unknown amount.
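  //
  // A sketch of the expansion for G_SHL (the G_LSHR/G_ASHR cases below are
  // symmetric):
  //   Lo = IsShort ? (InL << Amt) : 0
  //   Hi = IsZero  ? InH
  //                : IsShort ? ((InH << Amt) | (InL >> (NewBits - Amt)))
  //                          : (InL << (Amt - NewBits))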
4711   auto NewBits = MIRBuilder.buildConstant(ShiftAmtTy, NewBitSize);
4712 
4713   Register InL = MRI.createGenericVirtualRegister(HalfTy);
4714   Register InH = MRI.createGenericVirtualRegister(HalfTy);
4715   MIRBuilder.buildUnmerge({InL, InH}, MI.getOperand(1));
4716 
4717   auto AmtExcess = MIRBuilder.buildSub(ShiftAmtTy, Amt, NewBits);
4718   auto AmtLack = MIRBuilder.buildSub(ShiftAmtTy, NewBits, Amt);
4719 
4720   auto Zero = MIRBuilder.buildConstant(ShiftAmtTy, 0);
4721   auto IsShort = MIRBuilder.buildICmp(ICmpInst::ICMP_ULT, CondTy, Amt, NewBits);
4722   auto IsZero = MIRBuilder.buildICmp(ICmpInst::ICMP_EQ, CondTy, Amt, Zero);
4723 
4724   Register ResultRegs[2];
4725   switch (MI.getOpcode()) {
4726   case TargetOpcode::G_SHL: {
4727     // Short: ShAmt < NewBitSize
4728     auto LoS = MIRBuilder.buildShl(HalfTy, InL, Amt);
4729 
4730     auto LoOr = MIRBuilder.buildLShr(HalfTy, InL, AmtLack);
4731     auto HiOr = MIRBuilder.buildShl(HalfTy, InH, Amt);
4732     auto HiS = MIRBuilder.buildOr(HalfTy, LoOr, HiOr);
4733 
4734     // Long: ShAmt >= NewBitSize
4735     auto LoL = MIRBuilder.buildConstant(HalfTy, 0);         // Lo part is zero.
4736     auto HiL = MIRBuilder.buildShl(HalfTy, InL, AmtExcess); // Hi from Lo part.
4737 
4738     auto Lo = MIRBuilder.buildSelect(HalfTy, IsShort, LoS, LoL);
4739     auto Hi = MIRBuilder.buildSelect(
4740         HalfTy, IsZero, InH, MIRBuilder.buildSelect(HalfTy, IsShort, HiS, HiL));
4741 
4742     ResultRegs[0] = Lo.getReg(0);
4743     ResultRegs[1] = Hi.getReg(0);
4744     break;
4745   }
4746   case TargetOpcode::G_LSHR:
4747   case TargetOpcode::G_ASHR: {
4748     // Short: ShAmt < NewBitSize
4749     auto HiS = MIRBuilder.buildInstr(MI.getOpcode(), {HalfTy}, {InH, Amt});
4750 
4751     auto LoOr = MIRBuilder.buildLShr(HalfTy, InL, Amt);
4752     auto HiOr = MIRBuilder.buildShl(HalfTy, InH, AmtLack);
4753     auto LoS = MIRBuilder.buildOr(HalfTy, LoOr, HiOr);
4754 
4755     // Long: ShAmt >= NewBitSize
4756     MachineInstrBuilder HiL;
4757     if (MI.getOpcode() == TargetOpcode::G_LSHR) {
4758       HiL = MIRBuilder.buildConstant(HalfTy, 0);            // Hi part is zero.
4759     } else {
4760       auto ShiftAmt = MIRBuilder.buildConstant(ShiftAmtTy, NewBitSize - 1);
4761       HiL = MIRBuilder.buildAShr(HalfTy, InH, ShiftAmt);    // Sign of Hi part.
4762     }
4763     auto LoL = MIRBuilder.buildInstr(MI.getOpcode(), {HalfTy},
4764                                      {InH, AmtExcess});     // Lo from Hi part.
4765 
4766     auto Lo = MIRBuilder.buildSelect(
4767         HalfTy, IsZero, InL, MIRBuilder.buildSelect(HalfTy, IsShort, LoS, LoL));
4768 
4769     auto Hi = MIRBuilder.buildSelect(HalfTy, IsShort, HiS, HiL);
4770 
4771     ResultRegs[0] = Lo.getReg(0);
4772     ResultRegs[1] = Hi.getReg(0);
4773     break;
4774   }
4775   default:
4776     llvm_unreachable("not a shift");
4777   }
4778 
4779   MIRBuilder.buildMergeLikeInstr(DstReg, ResultRegs);
4780   MI.eraseFromParent();
4781   return Legalized;
4782 }
4783 
4784 LegalizerHelper::LegalizeResult
4785 LegalizerHelper::moreElementsVectorPhi(MachineInstr &MI, unsigned TypeIdx,
4786                                        LLT MoreTy) {
4787   assert(TypeIdx == 0 && "Expecting only Idx 0");
4788 
4789   Observer.changingInstr(MI);
4790   for (unsigned I = 1, E = MI.getNumOperands(); I != E; I += 2) {
4791     MachineBasicBlock &OpMBB = *MI.getOperand(I + 1).getMBB();
4792     MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminator());
4793     moreElementsVectorSrc(MI, MoreTy, I);
4794   }
4795 
4796   MachineBasicBlock &MBB = *MI.getParent();
4797   MIRBuilder.setInsertPt(MBB, --MBB.getFirstNonPHI());
4798   moreElementsVectorDst(MI, MoreTy, 0);
4799   Observer.changedInstr(MI);
4800   return Legalized;
4801 }
4802 
4803 LegalizerHelper::LegalizeResult
4804 LegalizerHelper::moreElementsVector(MachineInstr &MI, unsigned TypeIdx,
4805                                     LLT MoreTy) {
4806   unsigned Opc = MI.getOpcode();
4807   switch (Opc) {
4808   case TargetOpcode::G_IMPLICIT_DEF:
4809   case TargetOpcode::G_LOAD: {
4810     if (TypeIdx != 0)
4811       return UnableToLegalize;
4812     Observer.changingInstr(MI);
4813     moreElementsVectorDst(MI, MoreTy, 0);
4814     Observer.changedInstr(MI);
4815     return Legalized;
4816   }
4817   case TargetOpcode::G_STORE:
4818     if (TypeIdx != 0)
4819       return UnableToLegalize;
4820     Observer.changingInstr(MI);
4821     moreElementsVectorSrc(MI, MoreTy, 0);
4822     Observer.changedInstr(MI);
4823     return Legalized;
4824   case TargetOpcode::G_AND:
4825   case TargetOpcode::G_OR:
4826   case TargetOpcode::G_XOR:
4827   case TargetOpcode::G_ADD:
4828   case TargetOpcode::G_SUB:
4829   case TargetOpcode::G_MUL:
4830   case TargetOpcode::G_FADD:
4831   case TargetOpcode::G_FMUL:
4832   case TargetOpcode::G_UADDSAT:
4833   case TargetOpcode::G_USUBSAT:
4834   case TargetOpcode::G_SADDSAT:
4835   case TargetOpcode::G_SSUBSAT:
4836   case TargetOpcode::G_SMIN:
4837   case TargetOpcode::G_SMAX:
4838   case TargetOpcode::G_UMIN:
4839   case TargetOpcode::G_UMAX:
4840   case TargetOpcode::G_FMINNUM:
4841   case TargetOpcode::G_FMAXNUM:
4842   case TargetOpcode::G_FMINNUM_IEEE:
4843   case TargetOpcode::G_FMAXNUM_IEEE:
4844   case TargetOpcode::G_FMINIMUM:
4845   case TargetOpcode::G_FMAXIMUM:
4846   case TargetOpcode::G_STRICT_FADD:
4847   case TargetOpcode::G_STRICT_FSUB:
4848   case TargetOpcode::G_STRICT_FMUL: {
4849     Observer.changingInstr(MI);
4850     moreElementsVectorSrc(MI, MoreTy, 1);
4851     moreElementsVectorSrc(MI, MoreTy, 2);
4852     moreElementsVectorDst(MI, MoreTy, 0);
4853     Observer.changedInstr(MI);
4854     return Legalized;
4855   }
4856   case TargetOpcode::G_FMA:
4857   case TargetOpcode::G_STRICT_FMA:
4858   case TargetOpcode::G_FSHR:
4859   case TargetOpcode::G_FSHL: {
4860     Observer.changingInstr(MI);
4861     moreElementsVectorSrc(MI, MoreTy, 1);
4862     moreElementsVectorSrc(MI, MoreTy, 2);
4863     moreElementsVectorSrc(MI, MoreTy, 3);
4864     moreElementsVectorDst(MI, MoreTy, 0);
4865     Observer.changedInstr(MI);
4866     return Legalized;
4867   }
4868   case TargetOpcode::G_EXTRACT:
4869     if (TypeIdx != 1)
4870       return UnableToLegalize;
4871     Observer.changingInstr(MI);
4872     moreElementsVectorSrc(MI, MoreTy, 1);
4873     Observer.changedInstr(MI);
4874     return Legalized;
4875   case TargetOpcode::G_INSERT:
4876   case TargetOpcode::G_FREEZE:
4877   case TargetOpcode::G_FNEG:
4878   case TargetOpcode::G_FABS:
4879   case TargetOpcode::G_BSWAP:
4880   case TargetOpcode::G_FCANONICALIZE:
4881   case TargetOpcode::G_SEXT_INREG:
4882     if (TypeIdx != 0)
4883       return UnableToLegalize;
4884     Observer.changingInstr(MI);
4885     moreElementsVectorSrc(MI, MoreTy, 1);
4886     moreElementsVectorDst(MI, MoreTy, 0);
4887     Observer.changedInstr(MI);
4888     return Legalized;
4889   case TargetOpcode::G_SELECT: {
4890     Register DstReg = MI.getOperand(0).getReg();
4891     Register CondReg = MI.getOperand(1).getReg();
4892     LLT DstTy = MRI.getType(DstReg);
4893     LLT CondTy = MRI.getType(CondReg);
4894     if (TypeIdx == 1) {
4895       if (!CondTy.isScalar() ||
4896           DstTy.getElementCount() != MoreTy.getElementCount())
4897         return UnableToLegalize;
4898 
4899       // This is turning a scalar select of vectors into a vector
4900       // select. Broadcast the select condition.
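      // E.g. (a sketch) a G_SELECT with an s1 condition and <4 x s32> value
      // operands has its condition splatted to a <4 x s1> vector so that a
      // per-lane vector select can be used instead.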
4901       auto ShufSplat = MIRBuilder.buildShuffleSplat(MoreTy, CondReg);
4902       Observer.changingInstr(MI);
4903       MI.getOperand(1).setReg(ShufSplat.getReg(0));
4904       Observer.changedInstr(MI);
4905       return Legalized;
4906     }
4907 
4908     if (CondTy.isVector())
4909       return UnableToLegalize;
4910 
4911     Observer.changingInstr(MI);
4912     moreElementsVectorSrc(MI, MoreTy, 2);
4913     moreElementsVectorSrc(MI, MoreTy, 3);
4914     moreElementsVectorDst(MI, MoreTy, 0);
4915     Observer.changedInstr(MI);
4916     return Legalized;
4917   }
4918   case TargetOpcode::G_UNMERGE_VALUES:
4919     return UnableToLegalize;
4920   case TargetOpcode::G_PHI:
4921     return moreElementsVectorPhi(MI, TypeIdx, MoreTy);
4922   case TargetOpcode::G_SHUFFLE_VECTOR:
4923     return moreElementsVectorShuffle(MI, TypeIdx, MoreTy);
4924   case TargetOpcode::G_BUILD_VECTOR: {
4925     SmallVector<SrcOp, 8> Elts;
4926     for (auto Op : MI.uses()) {
4927       Elts.push_back(Op.getReg());
4928     }
4929 
4930     for (unsigned i = Elts.size(); i < MoreTy.getNumElements(); ++i) {
4931       Elts.push_back(MIRBuilder.buildUndef(MoreTy.getScalarType()));
4932     }
4933 
4934     MIRBuilder.buildDeleteTrailingVectorElements(
4935         MI.getOperand(0).getReg(), MIRBuilder.buildInstr(Opc, {MoreTy}, Elts));
4936     MI.eraseFromParent();
4937     return Legalized;
4938   }
4939   case TargetOpcode::G_TRUNC: {
4940     Observer.changingInstr(MI);
4941     moreElementsVectorSrc(MI, MoreTy, 1);
4942     moreElementsVectorDst(MI, MoreTy, 0);
4943     Observer.changedInstr(MI);
4944     return Legalized;
4945   }
4946   default:
4947     return UnableToLegalize;
4948   }
4949 }
4950 
4951 /// Expand source vectors to the size of the destination vector.
4952 static LegalizerHelper::LegalizeResult
4953 equalizeVectorShuffleLengths(MachineInstr &MI, MachineIRBuilder &MIRBuilder) {
4954   MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
4955 
4956   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
4957   LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
4958   ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
4959   unsigned MaskNumElts = Mask.size();
4960   unsigned SrcNumElts = SrcTy.getNumElements();
4961   Register DstReg = MI.getOperand(0).getReg();
4962   LLT DestEltTy = DstTy.getElementType();
4963 
4964   // TODO: Normalize the shuffle vector since mask and vector length don't
4965   // match.
4966   if (MaskNumElts <= SrcNumElts) {
4967     return LegalizerHelper::LegalizeResult::UnableToLegalize;
4968   }
4969 
4970   unsigned PaddedMaskNumElts = alignTo(MaskNumElts, SrcNumElts);
4971   unsigned NumConcat = PaddedMaskNumElts / SrcNumElts;
4972   LLT PaddedTy = LLT::fixed_vector(PaddedMaskNumElts, DestEltTy);
4973 
4974   // Create new source vectors by concatenating the initial
4975   // source vectors with undefined vectors of the same size.
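  // E.g. with <4 x s32> sources and an 8-element mask, each source is
  // concatenated with one undef <4 x s32> to form a padded <8 x s32> input.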
4976   auto Undef = MIRBuilder.buildUndef(SrcTy);
4977   SmallVector<Register, 8> MOps1(NumConcat, Undef.getReg(0));
4978   SmallVector<Register, 8> MOps2(NumConcat, Undef.getReg(0));
4979   MOps1[0] = MI.getOperand(1).getReg();
4980   MOps2[0] = MI.getOperand(2).getReg();
4981 
4982   auto Src1 = MIRBuilder.buildConcatVectors(PaddedTy, MOps1);
4983   auto Src2 = MIRBuilder.buildConcatVectors(PaddedTy, MOps2);
4984 
4985   // Readjust mask for new input vector length.
4986   SmallVector<int, 8> MappedOps(PaddedMaskNumElts, -1);
4987   for (unsigned I = 0; I != MaskNumElts; ++I) {
4988     int Idx = Mask[I];
4989     if (Idx >= static_cast<int>(SrcNumElts))
4990       Idx += PaddedMaskNumElts - SrcNumElts;
4991     MappedOps[I] = Idx;
4992   }
4993 
4994   // If we got more elements than required, extract a subvector.
4995   if (MaskNumElts != PaddedMaskNumElts) {
4996     auto Shuffle =
4997         MIRBuilder.buildShuffleVector(PaddedTy, Src1, Src2, MappedOps);
4998 
4999     SmallVector<Register, 16> Elts(MaskNumElts);
5000     for (unsigned I = 0; I < MaskNumElts; ++I) {
5001       Elts[I] =
5002           MIRBuilder.buildExtractVectorElementConstant(DestEltTy, Shuffle, I)
5003               .getReg(0);
5004     }
5005     MIRBuilder.buildBuildVector(DstReg, Elts);
5006   } else {
5007     MIRBuilder.buildShuffleVector(DstReg, Src1, Src2, MappedOps);
5008   }
5009 
5010   MI.eraseFromParent();
5011   return LegalizerHelper::LegalizeResult::Legalized;
5012 }
5013 
5014 LegalizerHelper::LegalizeResult
5015 LegalizerHelper::moreElementsVectorShuffle(MachineInstr &MI,
5016                                            unsigned int TypeIdx, LLT MoreTy) {
5017   Register DstReg = MI.getOperand(0).getReg();
5018   Register Src1Reg = MI.getOperand(1).getReg();
5019   Register Src2Reg = MI.getOperand(2).getReg();
5020   ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
5021   LLT DstTy = MRI.getType(DstReg);
5022   LLT Src1Ty = MRI.getType(Src1Reg);
5023   LLT Src2Ty = MRI.getType(Src2Reg);
5024   unsigned NumElts = DstTy.getNumElements();
5025   unsigned WidenNumElts = MoreTy.getNumElements();
5026 
5027   if (DstTy.isVector() && Src1Ty.isVector() &&
5028       DstTy.getNumElements() > Src1Ty.getNumElements()) {
5029     return equalizeVectorShuffleLengths(MI, MIRBuilder);
5030   }
5031 
5032   if (TypeIdx != 0)
5033     return UnableToLegalize;
5034 
5035   // Expect a canonicalized shuffle.
5036   if (DstTy != Src1Ty || DstTy != Src2Ty)
5037     return UnableToLegalize;
5038 
5039   moreElementsVectorSrc(MI, MoreTy, 1);
5040   moreElementsVectorSrc(MI, MoreTy, 2);
5041 
5042   // Adjust mask based on new input vector length.
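  // E.g. widening from <4 x sN> to <8 x sN>, indices into the second source
  // move from [4, 8) to [8, 16), so mask index 5 becomes 5 - 4 + 8 = 9; the
  // extra output lanes are undef (-1).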
5043   SmallVector<int, 16> NewMask;
5044   for (unsigned I = 0; I != NumElts; ++I) {
5045     int Idx = Mask[I];
5046     if (Idx < static_cast<int>(NumElts))
5047       NewMask.push_back(Idx);
5048     else
5049       NewMask.push_back(Idx - NumElts + WidenNumElts);
5050   }
5051   for (unsigned I = NumElts; I != WidenNumElts; ++I)
5052     NewMask.push_back(-1);
5053   moreElementsVectorDst(MI, MoreTy, 0);
5054   MIRBuilder.setInstrAndDebugLoc(MI);
5055   MIRBuilder.buildShuffleVector(MI.getOperand(0).getReg(),
5056                                 MI.getOperand(1).getReg(),
5057                                 MI.getOperand(2).getReg(), NewMask);
5058   MI.eraseFromParent();
5059   return Legalized;
5060 }
5061 
5062 void LegalizerHelper::multiplyRegisters(SmallVectorImpl<Register> &DstRegs,
5063                                         ArrayRef<Register> Src1Regs,
5064                                         ArrayRef<Register> Src2Regs,
5065                                         LLT NarrowTy) {
5066   MachineIRBuilder &B = MIRBuilder;
5067   unsigned SrcParts = Src1Regs.size();
5068   unsigned DstParts = DstRegs.size();
5069 
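  // This is schoolbook long multiplication on NarrowTy-sized digits: part k
  // of the result sums the low halves of Src1[i] * Src2[j] with i + j == k,
  // the high halves with i + j == k - 1, and the carries from part k - 1.
  // E.g. for two parts:
  //   Dst[0] = mul(S1[0], S2[0])
  //   Dst[1] = mul(S1[1], S2[0]) + mul(S1[0], S2[1]) + umulh(S1[0], S2[0])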
5070   unsigned DstIdx = 0; // Low bits of the result.
5071   Register FactorSum =
5072       B.buildMul(NarrowTy, Src1Regs[DstIdx], Src2Regs[DstIdx]).getReg(0);
5073   DstRegs[DstIdx] = FactorSum;
5074 
5075   unsigned CarrySumPrevDstIdx;
5076   SmallVector<Register, 4> Factors;
5077 
5078   for (DstIdx = 1; DstIdx < DstParts; DstIdx++) {
5079     // Collect low parts of muls for DstIdx.
5080     for (unsigned i = DstIdx + 1 < SrcParts ? 0 : DstIdx - SrcParts + 1;
5081          i <= std::min(DstIdx, SrcParts - 1); ++i) {
5082       MachineInstrBuilder Mul =
5083           B.buildMul(NarrowTy, Src1Regs[DstIdx - i], Src2Regs[i]);
5084       Factors.push_back(Mul.getReg(0));
5085     }
5086     // Collect high parts of muls from previous DstIdx.
5087     for (unsigned i = DstIdx < SrcParts ? 0 : DstIdx - SrcParts;
5088          i <= std::min(DstIdx - 1, SrcParts - 1); ++i) {
5089       MachineInstrBuilder Umulh =
5090           B.buildUMulH(NarrowTy, Src1Regs[DstIdx - 1 - i], Src2Regs[i]);
5091       Factors.push_back(Umulh.getReg(0));
5092     }
5093     // Add CarrySum from additions calculated for previous DstIdx.
5094     if (DstIdx != 1) {
5095       Factors.push_back(CarrySumPrevDstIdx);
5096     }
5097 
5098     Register CarrySum;
5099     // Add all factors and accumulate all carries into CarrySum.
5100     if (DstIdx != DstParts - 1) {
5101       MachineInstrBuilder Uaddo =
5102           B.buildUAddo(NarrowTy, LLT::scalar(1), Factors[0], Factors[1]);
5103       FactorSum = Uaddo.getReg(0);
5104       CarrySum = B.buildZExt(NarrowTy, Uaddo.getReg(1)).getReg(0);
5105       for (unsigned i = 2; i < Factors.size(); ++i) {
5106         MachineInstrBuilder Uaddo =
5107             B.buildUAddo(NarrowTy, LLT::scalar(1), FactorSum, Factors[i]);
5108         FactorSum = Uaddo.getReg(0);
5109         MachineInstrBuilder Carry = B.buildZExt(NarrowTy, Uaddo.getReg(1));
5110         CarrySum = B.buildAdd(NarrowTy, CarrySum, Carry).getReg(0);
5111       }
5112     } else {
5113       // The value for the next index is not calculated, so neither is CarrySum.
5114       FactorSum = B.buildAdd(NarrowTy, Factors[0], Factors[1]).getReg(0);
5115       for (unsigned i = 2; i < Factors.size(); ++i)
5116         FactorSum = B.buildAdd(NarrowTy, FactorSum, Factors[i]).getReg(0);
5117     }
5118 
5119     CarrySumPrevDstIdx = CarrySum;
5120     DstRegs[DstIdx] = FactorSum;
5121     Factors.clear();
5122   }
5123 }
5124 
5125 LegalizerHelper::LegalizeResult
5126 LegalizerHelper::narrowScalarAddSub(MachineInstr &MI, unsigned TypeIdx,
5127                                     LLT NarrowTy) {
5128   if (TypeIdx != 0)
5129     return UnableToLegalize;
5130 
5131   Register DstReg = MI.getOperand(0).getReg();
5132   LLT DstType = MRI.getType(DstReg);
5133   // FIXME: add support for vector types
5134   if (DstType.isVector())
5135     return UnableToLegalize;
5136 
5137   unsigned Opcode = MI.getOpcode();
5138   unsigned OpO, OpE, OpF;
5139   switch (Opcode) {
5140   case TargetOpcode::G_SADDO:
5141   case TargetOpcode::G_SADDE:
5142   case TargetOpcode::G_UADDO:
5143   case TargetOpcode::G_UADDE:
5144   case TargetOpcode::G_ADD:
5145     OpO = TargetOpcode::G_UADDO;
5146     OpE = TargetOpcode::G_UADDE;
5147     OpF = TargetOpcode::G_UADDE;
5148     if (Opcode == TargetOpcode::G_SADDO || Opcode == TargetOpcode::G_SADDE)
5149       OpF = TargetOpcode::G_SADDE;
5150     break;
5151   case TargetOpcode::G_SSUBO:
5152   case TargetOpcode::G_SSUBE:
5153   case TargetOpcode::G_USUBO:
5154   case TargetOpcode::G_USUBE:
5155   case TargetOpcode::G_SUB:
5156     OpO = TargetOpcode::G_USUBO;
5157     OpE = TargetOpcode::G_USUBE;
5158     OpF = TargetOpcode::G_USUBE;
5159     if (Opcode == TargetOpcode::G_SSUBO || Opcode == TargetOpcode::G_SSUBE)
5160       OpF = TargetOpcode::G_SSUBE;
5161     break;
5162   default:
5163     llvm_unreachable("Unexpected add/sub opcode!");
5164   }
5165 
5166   // 1 for a plain add/sub, 2 if this is an operation with a carry-out.
5167   unsigned NumDefs = MI.getNumExplicitDefs();
5168   Register Src1 = MI.getOperand(NumDefs).getReg();
5169   Register Src2 = MI.getOperand(NumDefs + 1).getReg();
5170   Register CarryDst, CarryIn;
5171   if (NumDefs == 2)
5172     CarryDst = MI.getOperand(1).getReg();
5173   if (MI.getNumOperands() == NumDefs + 3)
5174     CarryIn = MI.getOperand(NumDefs + 2).getReg();
5175 
5176   LLT RegTy = MRI.getType(MI.getOperand(0).getReg());
5177   LLT LeftoverTy, DummyTy;
5178   SmallVector<Register, 2> Src1Regs, Src2Regs, Src1Left, Src2Left, DstRegs;
5179   extractParts(Src1, RegTy, NarrowTy, LeftoverTy, Src1Regs, Src1Left);
5180   extractParts(Src2, RegTy, NarrowTy, DummyTy, Src2Regs, Src2Left);
5181 
5182   int NarrowParts = Src1Regs.size();
5183   for (int I = 0, E = Src1Left.size(); I != E; ++I) {
5184     Src1Regs.push_back(Src1Left[I]);
5185     Src2Regs.push_back(Src2Left[I]);
5186   }
5187   DstRegs.reserve(Src1Regs.size());
5188 
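  // Chain the narrow operations together, threading the carry-out of each
  // part into the carry-in of the next. E.g. (a sketch) for a plain s64
  // G_ADD with NarrowTy = s32:
  //   %lo:_(s32), %c:_(s1) = G_UADDO %a_lo, %b_lo
  //   %hi:_(s32), %d:_(s1) = G_UADDE %a_hi, %b_hi, %c
  //   %res:_(s64) = G_MERGE_VALUES %lo, %hi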
5189   for (int i = 0, e = Src1Regs.size(); i != e; ++i) {
5190     Register DstReg =
5191         MRI.createGenericVirtualRegister(MRI.getType(Src1Regs[i]));
5192     Register CarryOut = MRI.createGenericVirtualRegister(LLT::scalar(1));
5193     // Forward the final carry-out to the destination register
5194     if (i == e - 1 && CarryDst)
5195       CarryOut = CarryDst;
5196 
5197     if (!CarryIn) {
5198       MIRBuilder.buildInstr(OpO, {DstReg, CarryOut},
5199                             {Src1Regs[i], Src2Regs[i]});
5200     } else if (i == e - 1) {
5201       MIRBuilder.buildInstr(OpF, {DstReg, CarryOut},
5202                             {Src1Regs[i], Src2Regs[i], CarryIn});
5203     } else {
5204       MIRBuilder.buildInstr(OpE, {DstReg, CarryOut},
5205                             {Src1Regs[i], Src2Regs[i], CarryIn});
5206     }
5207 
5208     DstRegs.push_back(DstReg);
5209     CarryIn = CarryOut;
5210   }
5211   insertParts(MI.getOperand(0).getReg(), RegTy, NarrowTy,
5212               ArrayRef(DstRegs).take_front(NarrowParts), LeftoverTy,
5213               ArrayRef(DstRegs).drop_front(NarrowParts));
5214 
5215   MI.eraseFromParent();
5216   return Legalized;
5217 }
5218 
5219 LegalizerHelper::LegalizeResult
5220 LegalizerHelper::narrowScalarMul(MachineInstr &MI, LLT NarrowTy) {
5221   Register DstReg = MI.getOperand(0).getReg();
5222   Register Src1 = MI.getOperand(1).getReg();
5223   Register Src2 = MI.getOperand(2).getReg();
5224 
5225   LLT Ty = MRI.getType(DstReg);
5226   if (Ty.isVector())
5227     return UnableToLegalize;
5228 
5229   unsigned Size = Ty.getSizeInBits();
5230   unsigned NarrowSize = NarrowTy.getSizeInBits();
5231   if (Size % NarrowSize != 0)
5232     return UnableToLegalize;
5233 
5234   unsigned NumParts = Size / NarrowSize;
5235   bool IsMulHigh = MI.getOpcode() == TargetOpcode::G_UMULH;
5236   unsigned DstTmpParts = NumParts * (IsMulHigh ? 2 : 1);
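  // E.g. a s64 G_UMULH with NarrowTy = s32 computes all four s32 parts of
  // the double-width product and keeps only the upper two.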
5237 
5238   SmallVector<Register, 2> Src1Parts, Src2Parts;
5239   SmallVector<Register, 2> DstTmpRegs(DstTmpParts);
5240   extractParts(Src1, NarrowTy, NumParts, Src1Parts);
5241   extractParts(Src2, NarrowTy, NumParts, Src2Parts);
5242   multiplyRegisters(DstTmpRegs, Src1Parts, Src2Parts, NarrowTy);
5243 
5244   // Take only the high half of the registers if this is a high mul.
5245   ArrayRef<Register> DstRegs(&DstTmpRegs[DstTmpParts - NumParts], NumParts);
5246   MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
5247   MI.eraseFromParent();
5248   return Legalized;
5249 }
5250 
5251 LegalizerHelper::LegalizeResult
5252 LegalizerHelper::narrowScalarFPTOI(MachineInstr &MI, unsigned TypeIdx,
5253                                    LLT NarrowTy) {
5254   if (TypeIdx != 0)
5255     return UnableToLegalize;
5256 
5257   bool IsSigned = MI.getOpcode() == TargetOpcode::G_FPTOSI;
5258 
5259   Register Src = MI.getOperand(1).getReg();
5260   LLT SrcTy = MRI.getType(Src);
5261 
5262   // If all finite floats fit into the narrowed integer type, we can just swap
5263   // out the result type. This is practically only useful for conversions from
5264   // half to at least 16-bits, so just handle the one case.
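  // E.g. %r:_(s64) = G_FPTOSI %x:_(s16) can become
  //   %t:_(s32) = G_FPTOSI %x:_(s16)
  //   %r:_(s64) = G_SEXT %t:_(s32)
  // since every finite half value fits in 17 signed bits.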
5265   if (SrcTy.getScalarType() != LLT::scalar(16) ||
5266       NarrowTy.getScalarSizeInBits() < (IsSigned ? 17u : 16u))
5267     return UnableToLegalize;
5268 
5269   Observer.changingInstr(MI);
5270   narrowScalarDst(MI, NarrowTy, 0,
5271                   IsSigned ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT);
5272   Observer.changedInstr(MI);
5273   return Legalized;
5274 }
5275 
5276 LegalizerHelper::LegalizeResult
5277 LegalizerHelper::narrowScalarExtract(MachineInstr &MI, unsigned TypeIdx,
5278                                      LLT NarrowTy) {
5279   if (TypeIdx != 1)
5280     return UnableToLegalize;
5281 
5282   uint64_t NarrowSize = NarrowTy.getSizeInBits();
5283 
5284   int64_t SizeOp1 = MRI.getType(MI.getOperand(1).getReg()).getSizeInBits();
5285   // FIXME: add support for when SizeOp1 isn't an exact multiple of
5286   // NarrowSize.
5287   if (SizeOp1 % NarrowSize != 0)
5288     return UnableToLegalize;
5289   int NumParts = SizeOp1 / NarrowSize;
5290 
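  // E.g. (illustrative) extracting s16 at bit offset 24 from an s64 split
  // into s32 pieces takes bits [24, 32) of piece 0 and bits [0, 8) of piece
  // 1, then merges the two s8 segments into the result.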
5291   SmallVector<Register, 2> SrcRegs, DstRegs;
5292   SmallVector<uint64_t, 2> Indexes;
5293   extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs);
5294 
5295   Register OpReg = MI.getOperand(0).getReg();
5296   uint64_t OpStart = MI.getOperand(2).getImm();
5297   uint64_t OpSize = MRI.getType(OpReg).getSizeInBits();
5298   for (int i = 0; i < NumParts; ++i) {
5299     unsigned SrcStart = i * NarrowSize;
5300 
5301     if (SrcStart + NarrowSize <= OpStart || SrcStart >= OpStart + OpSize) {
5302       // No part of the extract uses this subregister, ignore it.
5303       continue;
5304     } else if (SrcStart == OpStart && NarrowTy == MRI.getType(OpReg)) {
5305       // The entire subregister is extracted, forward the value.
5306       DstRegs.push_back(SrcRegs[i]);
5307       continue;
5308     }
5309 
5310     // Compute the offset into this source piece and the number of bits it
5311     // contributes to the extracted value.
5312     int64_t ExtractOffset;
5313     uint64_t SegSize;
5314     if (OpStart < SrcStart) {
5315       ExtractOffset = 0;
5316       SegSize = std::min(NarrowSize, OpStart + OpSize - SrcStart);
5317     } else {
5318       ExtractOffset = OpStart - SrcStart;
5319       SegSize = std::min(SrcStart + NarrowSize - OpStart, OpSize);
5320     }
5321 
5322     Register SegReg = SrcRegs[i];
5323     if (ExtractOffset != 0 || SegSize != NarrowSize) {
5324       // A genuine extract is needed.
5325       SegReg = MRI.createGenericVirtualRegister(LLT::scalar(SegSize));
5326       MIRBuilder.buildExtract(SegReg, SrcRegs[i], ExtractOffset);
5327     }
5328 
5329     DstRegs.push_back(SegReg);
5330   }
5331 
5332   Register DstReg = MI.getOperand(0).getReg();
5333   if (MRI.getType(DstReg).isVector())
5334     MIRBuilder.buildBuildVector(DstReg, DstRegs);
5335   else if (DstRegs.size() > 1)
5336     MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
5337   else
5338     MIRBuilder.buildCopy(DstReg, DstRegs[0]);
5339   MI.eraseFromParent();
5340   return Legalized;
5341 }
5342 
5343 LegalizerHelper::LegalizeResult
5344 LegalizerHelper::narrowScalarInsert(MachineInstr &MI, unsigned TypeIdx,
5345                                     LLT NarrowTy) {
5346   // FIXME: Don't know how to handle secondary types yet.
5347   if (TypeIdx != 0)
5348     return UnableToLegalize;
5349 
5350   SmallVector<Register, 2> SrcRegs, LeftoverRegs, DstRegs;
5351   SmallVector<uint64_t, 2> Indexes;
5352   LLT RegTy = MRI.getType(MI.getOperand(0).getReg());
5353   LLT LeftoverTy;
5354   extractParts(MI.getOperand(1).getReg(), RegTy, NarrowTy, LeftoverTy, SrcRegs,
5355                LeftoverRegs);
5356 
5357   for (Register Reg : LeftoverRegs)
5358     SrcRegs.push_back(Reg);
5359 
5360   uint64_t NarrowSize = NarrowTy.getSizeInBits();
5361   Register OpReg = MI.getOperand(2).getReg();
5362   uint64_t OpStart = MI.getOperand(3).getImm();
5363   uint64_t OpSize = MRI.getType(OpReg).getSizeInBits();
5364   for (int I = 0, E = SrcRegs.size(); I != E; ++I) {
5365     unsigned DstStart = I * NarrowSize;
5366 
5367     if (DstStart == OpStart && NarrowTy == MRI.getType(OpReg)) {
5368       // The entire subregister is defined by this insert, forward the new
5369       // value.
5370       DstRegs.push_back(OpReg);
5371       continue;
5372     }
5373 
5374     Register SrcReg = SrcRegs[I];
5375     if (MRI.getType(SrcRegs[I]) == LeftoverTy) {
5376       // The leftover reg is smaller than NarrowTy, so we need to extend it.
5377       SrcReg = MRI.createGenericVirtualRegister(NarrowTy);
5378       MIRBuilder.buildAnyExt(SrcReg, SrcRegs[I]);
5379     }
5380 
5381     if (DstStart + NarrowSize <= OpStart || DstStart >= OpStart + OpSize) {
5382       // No part of the insert affects this subregister, forward the original.
5383       DstRegs.push_back(SrcReg);
5384       continue;
5385     }
5386 
5387     // Compute where the inserted segment lands within this destination piece
5388     // and how many bits of it are covered.
5389     int64_t ExtractOffset, InsertOffset;
5390     uint64_t SegSize;
5391     if (OpStart < DstStart) {
5392       InsertOffset = 0;
5393       ExtractOffset = DstStart - OpStart;
5394       SegSize = std::min(NarrowSize, OpStart + OpSize - DstStart);
5395     } else {
5396       InsertOffset = OpStart - DstStart;
5397       ExtractOffset = 0;
5398       SegSize =
5399         std::min(NarrowSize - InsertOffset, OpStart + OpSize - DstStart);
5400     }
5401 
5402     Register SegReg = OpReg;
5403     if (ExtractOffset != 0 || SegSize != OpSize) {
5404       // A genuine extract is needed.
5405       SegReg = MRI.createGenericVirtualRegister(LLT::scalar(SegSize));
5406       MIRBuilder.buildExtract(SegReg, OpReg, ExtractOffset);
5407     }
5408 
5409     Register DstReg = MRI.createGenericVirtualRegister(NarrowTy);
5410     MIRBuilder.buildInsert(DstReg, SrcReg, SegReg, InsertOffset);
5411     DstRegs.push_back(DstReg);
5412   }
5413 
5414   uint64_t WideSize = DstRegs.size() * NarrowSize;
5415   Register DstReg = MI.getOperand(0).getReg();
5416   if (WideSize > RegTy.getSizeInBits()) {
5417     Register MergeReg = MRI.createGenericVirtualRegister(LLT::scalar(WideSize));
5418     MIRBuilder.buildMergeLikeInstr(MergeReg, DstRegs);
5419     MIRBuilder.buildTrunc(DstReg, MergeReg);
5420   } else
5421     MIRBuilder.buildMergeLikeInstr(DstReg, DstRegs);
5422 
5423   MI.eraseFromParent();
5424   return Legalized;
5425 }
5426 
5427 LegalizerHelper::LegalizeResult
5428 LegalizerHelper::narrowScalarBasic(MachineInstr &MI, unsigned TypeIdx,
5429                                    LLT NarrowTy) {
5430   Register DstReg = MI.getOperand(0).getReg();
5431   LLT DstTy = MRI.getType(DstReg);
5432 
5433   assert(MI.getNumOperands() == 3 && TypeIdx == 0);
5434 
5435   SmallVector<Register, 4> DstRegs, DstLeftoverRegs;
5436   SmallVector<Register, 4> Src0Regs, Src0LeftoverRegs;
5437   SmallVector<Register, 4> Src1Regs, Src1LeftoverRegs;
5438   LLT LeftoverTy;
5439   if (!extractParts(MI.getOperand(1).getReg(), DstTy, NarrowTy, LeftoverTy,
5440                     Src0Regs, Src0LeftoverRegs))
5441     return UnableToLegalize;
5442 
5443   LLT Unused;
5444   if (!extractParts(MI.getOperand(2).getReg(), DstTy, NarrowTy, Unused,
5445                     Src1Regs, Src1LeftoverRegs))
5446     llvm_unreachable("inconsistent extractParts result");
5447 
5448   for (unsigned I = 0, E = Src1Regs.size(); I != E; ++I) {
5449     auto Inst = MIRBuilder.buildInstr(MI.getOpcode(), {NarrowTy},
5450                                         {Src0Regs[I], Src1Regs[I]});
5451     DstRegs.push_back(Inst.getReg(0));
5452   }
5453 
5454   for (unsigned I = 0, E = Src1LeftoverRegs.size(); I != E; ++I) {
5455     auto Inst = MIRBuilder.buildInstr(
5456       MI.getOpcode(),
5457       {LeftoverTy}, {Src0LeftoverRegs[I], Src1LeftoverRegs[I]});
5458     DstLeftoverRegs.push_back(Inst.getReg(0));
5459   }
5460 
5461   insertParts(DstReg, DstTy, NarrowTy, DstRegs,
5462               LeftoverTy, DstLeftoverRegs);
5463 
5464   MI.eraseFromParent();
5465   return Legalized;
5466 }
5467 
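// Narrow the wide result of an extension: split the source into pieces of a
// common (GCD) type, pad out to a type covering the destination using MI's
// own extension opcode, and remerge the pieces into the destination.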
LegalizerHelper::LegalizeResult
LegalizerHelper::narrowScalarExt(MachineInstr &MI, unsigned TypeIdx,
                                 LLT NarrowTy) {
  if (TypeIdx != 0)
    return UnableToLegalize;

  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();

  LLT DstTy = MRI.getType(DstReg);
  if (DstTy.isVector())
    return UnableToLegalize;

  SmallVector<Register, 8> Parts;
  LLT GCDTy = extractGCDType(Parts, DstTy, NarrowTy, SrcReg);
  LLT LCMTy = buildLCMMergePieces(DstTy, NarrowTy, GCDTy, Parts, MI.getOpcode());
  buildWidenedRemergeToDst(DstReg, LCMTy, Parts);

  MI.eraseFromParent();
  return Legalized;
}

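// Narrow a G_SELECT with a scalar condition by splitting both value operands
// into NarrowTy-sized pieces (plus leftovers), selecting piecewise on the
// shared condition, and recombining the chosen pieces into the destination.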
LegalizerHelper::LegalizeResult
LegalizerHelper::narrowScalarSelect(MachineInstr &MI, unsigned TypeIdx,
                                    LLT NarrowTy) {
  if (TypeIdx != 0)
    return UnableToLegalize;

  Register CondReg = MI.getOperand(1).getReg();
  LLT CondTy = MRI.getType(CondReg);
  if (CondTy.isVector()) // TODO: Handle vselect
    return UnableToLegalize;

  Register DstReg = MI.getOperand(0).getReg();
  LLT DstTy = MRI.getType(DstReg);

  SmallVector<Register, 4> DstRegs, DstLeftoverRegs;
  SmallVector<Register, 4> Src1Regs, Src1LeftoverRegs;
  SmallVector<Register, 4> Src2Regs, Src2LeftoverRegs;
  LLT LeftoverTy;
  if (!extractParts(MI.getOperand(2).getReg(), DstTy, NarrowTy, LeftoverTy,
                    Src1Regs, Src1LeftoverRegs))
    return UnableToLegalize;

  LLT Unused;
  if (!extractParts(MI.getOperand(3).getReg(), DstTy, NarrowTy, Unused,
                    Src2Regs, Src2LeftoverRegs))
    llvm_unreachable("inconsistent extractParts result");

  for (unsigned I = 0, E = Src1Regs.size(); I != E; ++I) {
    auto Select = MIRBuilder.buildSelect(NarrowTy,
                                         CondReg, Src1Regs[I], Src2Regs[I]);
    DstRegs.push_back(Select.getReg(0));
  }

  for (unsigned I = 0, E = Src1LeftoverRegs.size(); I != E; ++I) {
    auto Select = MIRBuilder.buildSelect(
      LeftoverTy, CondReg, Src1LeftoverRegs[I], Src2LeftoverRegs[I]);
    DstLeftoverRegs.push_back(Select.getReg(0));
  }

  insertParts(DstReg, DstTy, NarrowTy, DstRegs,
              LeftoverTy, DstLeftoverRegs);

  MI.eraseFromParent();
  return Legalized;
}

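// Narrow G_CTLZ/G_CTLZ_ZERO_UNDEF on a source that is exactly twice
// NarrowTy's size by counting in halves: if the high half is zero, the
// answer is NarrowSize plus the count of the low half; otherwise it is the
// count of the high half.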
LegalizerHelper::LegalizeResult
LegalizerHelper::narrowScalarCTLZ(MachineInstr &MI, unsigned TypeIdx,
                                  LLT NarrowTy) {
  if (TypeIdx != 1)
    return UnableToLegalize;

  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT SrcTy = MRI.getType(SrcReg);
  unsigned NarrowSize = NarrowTy.getSizeInBits();

  if (SrcTy.isScalar() && SrcTy.getSizeInBits() == 2 * NarrowSize) {
    const bool IsUndef = MI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF;

    MachineIRBuilder &B = MIRBuilder;
    auto UnmergeSrc = B.buildUnmerge(NarrowTy, SrcReg);
    // ctlz(Hi:Lo) -> Hi == 0 ? (NarrowSize + ctlz(Lo)) : ctlz(Hi)
    auto C_0 = B.buildConstant(NarrowTy, 0);
    auto HiIsZero = B.buildICmp(CmpInst::ICMP_EQ, LLT::scalar(1),
                                UnmergeSrc.getReg(1), C_0);
    auto LoCTLZ = IsUndef ?
      B.buildCTLZ_ZERO_UNDEF(DstTy, UnmergeSrc.getReg(0)) :
      B.buildCTLZ(DstTy, UnmergeSrc.getReg(0));
    auto C_NarrowSize = B.buildConstant(DstTy, NarrowSize);
    auto HiIsZeroCTLZ = B.buildAdd(DstTy, LoCTLZ, C_NarrowSize);
    auto HiCTLZ = B.buildCTLZ_ZERO_UNDEF(DstTy, UnmergeSrc.getReg(1));
    B.buildSelect(DstReg, HiIsZero, HiIsZeroCTLZ, HiCTLZ);

    MI.eraseFromParent();
    return Legalized;
  }

  return UnableToLegalize;
}

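// Narrow G_CTTZ/G_CTTZ_ZERO_UNDEF the same way, from the other end: if the
// low half is zero, the answer is the count of the high half plus
// NarrowSize; otherwise it is the count of the low half.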
LegalizerHelper::LegalizeResult
LegalizerHelper::narrowScalarCTTZ(MachineInstr &MI, unsigned TypeIdx,
                                  LLT NarrowTy) {
  if (TypeIdx != 1)
    return UnableToLegalize;

  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT SrcTy = MRI.getType(SrcReg);
  unsigned NarrowSize = NarrowTy.getSizeInBits();

  if (SrcTy.isScalar() && SrcTy.getSizeInBits() == 2 * NarrowSize) {
    const bool IsUndef = MI.getOpcode() == TargetOpcode::G_CTTZ_ZERO_UNDEF;

    MachineIRBuilder &B = MIRBuilder;
    auto UnmergeSrc = B.buildUnmerge(NarrowTy, SrcReg);
    // cttz(Hi:Lo) -> Lo == 0 ? (cttz(Hi) + NarrowSize) : cttz(Lo)
    auto C_0 = B.buildConstant(NarrowTy, 0);
    auto LoIsZero = B.buildICmp(CmpInst::ICMP_EQ, LLT::scalar(1),
                                UnmergeSrc.getReg(0), C_0);
    auto HiCTTZ = IsUndef ?
      B.buildCTTZ_ZERO_UNDEF(DstTy, UnmergeSrc.getReg(1)) :
      B.buildCTTZ(DstTy, UnmergeSrc.getReg(1));
    auto C_NarrowSize = B.buildConstant(DstTy, NarrowSize);
    auto LoIsZeroCTTZ = B.buildAdd(DstTy, HiCTTZ, C_NarrowSize);
    auto LoCTTZ = B.buildCTTZ_ZERO_UNDEF(DstTy, UnmergeSrc.getReg(0));
    B.buildSelect(DstReg, LoIsZero, LoIsZeroCTTZ, LoCTTZ);

    MI.eraseFromParent();
    return Legalized;
  }

  return UnableToLegalize;
}

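// Narrow G_CTPOP by summing the population counts of the two halves.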
LegalizerHelper::LegalizeResult
LegalizerHelper::narrowScalarCTPOP(MachineInstr &MI, unsigned TypeIdx,
                                   LLT NarrowTy) {
  if (TypeIdx != 1)
    return UnableToLegalize;

  Register DstReg = MI.getOperand(0).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
  unsigned NarrowSize = NarrowTy.getSizeInBits();

  if (SrcTy.isScalar() && SrcTy.getSizeInBits() == 2 * NarrowSize) {
    auto UnmergeSrc = MIRBuilder.buildUnmerge(NarrowTy, MI.getOperand(1));

    auto LoCTPOP = MIRBuilder.buildCTPOP(DstTy, UnmergeSrc.getReg(0));
    auto HiCTPOP = MIRBuilder.buildCTPOP(DstTy, UnmergeSrc.getReg(1));
    MIRBuilder.buildAdd(DstReg, HiCTPOP, LoCTPOP);

    MI.eraseFromParent();
    return Legalized;
  }

  return UnableToLegalize;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerBitCount(MachineInstr &MI) {
  unsigned Opc = MI.getOpcode();
  const auto &TII = MIRBuilder.getTII();
  auto isSupported = [this](const LegalityQuery &Q) {
    auto QAction = LI.getAction(Q).Action;
    return QAction == Legal || QAction == Libcall || QAction == Custom;
  };
  switch (Opc) {
  default:
    return UnableToLegalize;
  case TargetOpcode::G_CTLZ_ZERO_UNDEF: {
    // This trivially expands to CTLZ.
    Observer.changingInstr(MI);
    MI.setDesc(TII.get(TargetOpcode::G_CTLZ));
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_CTLZ: {
    Register DstReg = MI.getOperand(0).getReg();
    Register SrcReg = MI.getOperand(1).getReg();
    LLT DstTy = MRI.getType(DstReg);
    LLT SrcTy = MRI.getType(SrcReg);
    unsigned Len = SrcTy.getSizeInBits();

    if (isSupported({TargetOpcode::G_CTLZ_ZERO_UNDEF, {DstTy, SrcTy}})) {
      // If CTLZ_ZERO_UNDEF is supported, emit that and a select for zero.
      auto CtlzZU = MIRBuilder.buildCTLZ_ZERO_UNDEF(DstTy, SrcReg);
      auto ZeroSrc = MIRBuilder.buildConstant(SrcTy, 0);
      auto ICmp = MIRBuilder.buildICmp(
          CmpInst::ICMP_EQ, SrcTy.changeElementSize(1), SrcReg, ZeroSrc);
      auto LenConst = MIRBuilder.buildConstant(DstTy, Len);
      MIRBuilder.buildSelect(DstReg, ICmp, LenConst, CtlzZU);
      MI.eraseFromParent();
      return Legalized;
    }
    // For now, we do this:
    // NewLen = NextPowerOf2(Len);
    // x = x | (x >> 1);
    // x = x | (x >> 2);
    // ...
    // x = x | (x >> 16);
    // x = x | (x >> 32); // for 64-bit input
    // ... up to a shift of NewLen/2
    // return Len - popcount(x);
    //
    // Ref: "Hacker's Delight" by Henry Warren
    Register Op = SrcReg;
    unsigned NewLen = PowerOf2Ceil(Len);
    for (unsigned i = 0; (1U << i) <= (NewLen / 2); ++i) {
      auto MIBShiftAmt = MIRBuilder.buildConstant(SrcTy, 1ULL << i);
      auto MIBOp = MIRBuilder.buildOr(
          SrcTy, Op, MIRBuilder.buildLShr(SrcTy, Op, MIBShiftAmt));
      Op = MIBOp.getReg(0);
    }
    auto MIBPop = MIRBuilder.buildCTPOP(DstTy, Op);
    MIRBuilder.buildSub(MI.getOperand(0), MIRBuilder.buildConstant(DstTy, Len),
                        MIBPop);
    MI.eraseFromParent();
    return Legalized;
  }
  case TargetOpcode::G_CTTZ_ZERO_UNDEF: {
    // This trivially expands to CTTZ.
    Observer.changingInstr(MI);
    MI.setDesc(TII.get(TargetOpcode::G_CTTZ));
    Observer.changedInstr(MI);
    return Legalized;
  }
  case TargetOpcode::G_CTTZ: {
    Register DstReg = MI.getOperand(0).getReg();
    Register SrcReg = MI.getOperand(1).getReg();
    LLT DstTy = MRI.getType(DstReg);
    LLT SrcTy = MRI.getType(SrcReg);

    unsigned Len = SrcTy.getSizeInBits();
    if (isSupported({TargetOpcode::G_CTTZ_ZERO_UNDEF, {DstTy, SrcTy}})) {
      // If CTTZ_ZERO_UNDEF is legal or custom, emit that and a select with
      // zero.
      auto CttzZU = MIRBuilder.buildCTTZ_ZERO_UNDEF(DstTy, SrcReg);
      auto Zero = MIRBuilder.buildConstant(SrcTy, 0);
      auto ICmp = MIRBuilder.buildICmp(
          CmpInst::ICMP_EQ, DstTy.changeElementSize(1), SrcReg, Zero);
      auto LenConst = MIRBuilder.buildConstant(DstTy, Len);
      MIRBuilder.buildSelect(DstReg, ICmp, LenConst, CttzZU);
      MI.eraseFromParent();
      return Legalized;
    }
    // For now, we use: { return popcount(~x & (x - 1)); }
    // unless the target has ctlz but not ctpop, in which case we use:
    // { return 32 - nlz(~x & (x - 1)); }
    // Ref: "Hacker's Delight" by Henry Warren
    auto MIBCstNeg1 = MIRBuilder.buildConstant(SrcTy, -1);
    auto MIBNot = MIRBuilder.buildXor(SrcTy, SrcReg, MIBCstNeg1);
    auto MIBTmp = MIRBuilder.buildAnd(
        SrcTy, MIBNot, MIRBuilder.buildAdd(SrcTy, SrcReg, MIBCstNeg1));
    if (!isSupported({TargetOpcode::G_CTPOP, {SrcTy, SrcTy}}) &&
        isSupported({TargetOpcode::G_CTLZ, {SrcTy, SrcTy}})) {
      auto MIBCstLen = MIRBuilder.buildConstant(SrcTy, Len);
      MIRBuilder.buildSub(MI.getOperand(0), MIBCstLen,
                          MIRBuilder.buildCTLZ(SrcTy, MIBTmp));
      MI.eraseFromParent();
      return Legalized;
    }
    MI.setDesc(TII.get(TargetOpcode::G_CTPOP));
    MI.getOperand(1).setReg(MIBTmp.getReg(0));
    return Legalized;
  }
  case TargetOpcode::G_CTPOP: {
    Register SrcReg = MI.getOperand(1).getReg();
    LLT Ty = MRI.getType(SrcReg);
    unsigned Size = Ty.getSizeInBits();
    MachineIRBuilder &B = MIRBuilder;

    // Count set bits in blocks of 2 bits. The default approach would be
    // B2Count = { val & 0x55555555 } + { (val >> 1) & 0x55555555 }
    // We use the following formula instead:
    // B2Count = val - { (val >> 1) & 0x55555555 }
    // since it gives the same result in blocks of 2 with one instruction less.
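    //
    // An illustrative walk-through of the whole expansion (example values
    // chosen here, not taken from the original source): for a 32-bit input
    // val = 0xF0F0F0F0 (16 set bits),
    //   B2Count = 0xF0F0F0F0 - 0x50505050 = 0xA0A0A0A0  (2,2,0,0 per 2 bits)
    //   B4Count = 0x20202020 + 0x20202020 = 0x40404040  (4,0 per nibble)
    //   B8Count = (0x40404040 + 0x04040404) & 0x0F0F0F0F = 0x04040404
    // and multiplying by 0x01010101 leaves 16 in the top byte, which the
    // final shift right by 24 moves down to produce the result 16.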
    auto C_1 = B.buildConstant(Ty, 1);
    auto B2Set1LoTo1Hi = B.buildLShr(Ty, SrcReg, C_1);
    APInt B2Mask1HiTo0 = APInt::getSplat(Size, APInt(8, 0x55));
    auto C_B2Mask1HiTo0 = B.buildConstant(Ty, B2Mask1HiTo0);
    auto B2Count1Hi = B.buildAnd(Ty, B2Set1LoTo1Hi, C_B2Mask1HiTo0);
    auto B2Count = B.buildSub(Ty, SrcReg, B2Count1Hi);

    // To get the count in blocks of 4, add the values from adjacent blocks
    // of 2.
    // B4Count = { B2Count & 0x33333333 } + { (B2Count >> 2) & 0x33333333 }
    auto C_2 = B.buildConstant(Ty, 2);
    auto B4Set2LoTo2Hi = B.buildLShr(Ty, B2Count, C_2);
    APInt B4Mask2HiTo0 = APInt::getSplat(Size, APInt(8, 0x33));
    auto C_B4Mask2HiTo0 = B.buildConstant(Ty, B4Mask2HiTo0);
    auto B4HiB2Count = B.buildAnd(Ty, B4Set2LoTo2Hi, C_B4Mask2HiTo0);
    auto B4LoB2Count = B.buildAnd(Ty, B2Count, C_B4Mask2HiTo0);
    auto B4Count = B.buildAdd(Ty, B4HiB2Count, B4LoB2Count);

    // For the count in blocks of 8 bits, we don't have to mask the high 4
    // bits before the addition, since each count sits in the range {0,...,8}
    // and 4 bits are enough to hold it. After the addition the high 4 bits
    // still hold the count of set bits in the high 4-bit block; set them to
    // zero to get the 8-bit result.
    // B8Count = { B4Count + (B4Count >> 4) } & 0x0F0F0F0F
    auto C_4 = B.buildConstant(Ty, 4);
    auto B8HiB4Count = B.buildLShr(Ty, B4Count, C_4);
    auto B8CountDirty4Hi = B.buildAdd(Ty, B8HiB4Count, B4Count);
    APInt B8Mask4HiTo0 = APInt::getSplat(Size, APInt(8, 0x0F));
    auto C_B8Mask4HiTo0 = B.buildConstant(Ty, B8Mask4HiTo0);
    auto B8Count = B.buildAnd(Ty, B8CountDirty4Hi, C_B8Mask4HiTo0);

    assert(Size <= 128 && "Scalar size is too large for CTPOP lower algorithm");
    // 8 bits can hold the CTPOP result of a 128-bit integer or smaller, so
    // multiplying by this bitmask sets the 8 most significant bits of ResTmp
    // to the sum of all the B8Counts in the 8-bit blocks.
    auto MulMask = B.buildConstant(Ty, APInt::getSplat(Size, APInt(8, 0x01)));
    auto ResTmp = B.buildMul(Ty, B8Count, MulMask);

    // Shift the count result from the 8 high bits down to the low bits.
    auto C_SizeM8 = B.buildConstant(Ty, Size - 8);
    B.buildLShr(MI.getOperand(0).getReg(), ResTmp, C_SizeM8);

    MI.eraseFromParent();
    return Legalized;
  }
  }
}

// Check that (every element of) Reg is undef or not an exact multiple of BW.
static bool isNonZeroModBitWidthOrUndef(const MachineRegisterInfo &MRI,
                                        Register Reg, unsigned BW) {
  return matchUnaryPredicate(
      MRI, Reg,
      [=](const Constant *C) {
        // Null constant here means an undef.
        const ConstantInt *CI = dyn_cast_or_null<ConstantInt>(C);
        return !CI || CI->getValue().urem(BW) != 0;
      },
      /*AllowUndefs*/ true);
}

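// Lower a funnel shift via the opposite-direction funnel shift. When the
// shift amount is known to be non-zero modulo the bit width,
//   fshl X, Y, Z == fshr X, Y, -Z  (and vice versa);
// otherwise the operands are pre-shifted by one so the inverse amount ~Z
// (which is BW - 1 - Z modulo BW) can be used without risking an undefined
// shift by BW.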
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerFunnelShiftWithInverse(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register X = MI.getOperand(1).getReg();
  Register Y = MI.getOperand(2).getReg();
  Register Z = MI.getOperand(3).getReg();
  LLT Ty = MRI.getType(Dst);
  LLT ShTy = MRI.getType(Z);

  unsigned BW = Ty.getScalarSizeInBits();

  if (!isPowerOf2_32(BW))
    return UnableToLegalize;

  const bool IsFSHL = MI.getOpcode() == TargetOpcode::G_FSHL;
  unsigned RevOpcode = IsFSHL ? TargetOpcode::G_FSHR : TargetOpcode::G_FSHL;

  if (isNonZeroModBitWidthOrUndef(MRI, Z, BW)) {
    // fshl X, Y, Z -> fshr X, Y, -Z
    // fshr X, Y, Z -> fshl X, Y, -Z
    auto Zero = MIRBuilder.buildConstant(ShTy, 0);
    Z = MIRBuilder.buildSub(ShTy, Zero, Z).getReg(0);
  } else {
    // fshl X, Y, Z -> fshr (srl X, 1), (fshr X, Y, 1), ~Z
    // fshr X, Y, Z -> fshl (fshl X, Y, 1), (shl Y, 1), ~Z
    auto One = MIRBuilder.buildConstant(ShTy, 1);
    if (IsFSHL) {
      Y = MIRBuilder.buildInstr(RevOpcode, {Ty}, {X, Y, One}).getReg(0);
      X = MIRBuilder.buildLShr(Ty, X, One).getReg(0);
    } else {
      X = MIRBuilder.buildInstr(RevOpcode, {Ty}, {X, Y, One}).getReg(0);
      Y = MIRBuilder.buildShl(Ty, Y, One).getReg(0);
    }

    Z = MIRBuilder.buildNot(ShTy, Z).getReg(0);
  }

  MIRBuilder.buildInstr(RevOpcode, {Dst}, {X, Y, Z});
  MI.eraseFromParent();
  return Legalized;
}

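// Lower a funnel shift to plain shifts and an OR, reducing the amount modulo
// the bit width. When the amount is not known to be non-zero, the second
// shift is split in two (>> 1 >> ...) so that no single shift can reach the
// full bit width.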
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerFunnelShiftAsShifts(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register X = MI.getOperand(1).getReg();
  Register Y = MI.getOperand(2).getReg();
  Register Z = MI.getOperand(3).getReg();
  LLT Ty = MRI.getType(Dst);
  LLT ShTy = MRI.getType(Z);

  const unsigned BW = Ty.getScalarSizeInBits();
  const bool IsFSHL = MI.getOpcode() == TargetOpcode::G_FSHL;

  Register ShX, ShY;
  Register ShAmt, InvShAmt;

  // FIXME: Emit optimized urem by constant instead of letting it expand later.
  if (isNonZeroModBitWidthOrUndef(MRI, Z, BW)) {
    // fshl: X << C | Y >> (BW - C)
    // fshr: X << (BW - C) | Y >> C
    // where C = Z % BW is not zero
    auto BitWidthC = MIRBuilder.buildConstant(ShTy, BW);
    ShAmt = MIRBuilder.buildURem(ShTy, Z, BitWidthC).getReg(0);
    InvShAmt = MIRBuilder.buildSub(ShTy, BitWidthC, ShAmt).getReg(0);
    ShX = MIRBuilder.buildShl(Ty, X, IsFSHL ? ShAmt : InvShAmt).getReg(0);
    ShY = MIRBuilder.buildLShr(Ty, Y, IsFSHL ? InvShAmt : ShAmt).getReg(0);
  } else {
    // fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW))
    // fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW)
    auto Mask = MIRBuilder.buildConstant(ShTy, BW - 1);
    if (isPowerOf2_32(BW)) {
      // Z % BW -> Z & (BW - 1)
      ShAmt = MIRBuilder.buildAnd(ShTy, Z, Mask).getReg(0);
      // (BW - 1) - (Z % BW) -> ~Z & (BW - 1)
      auto NotZ = MIRBuilder.buildNot(ShTy, Z);
      InvShAmt = MIRBuilder.buildAnd(ShTy, NotZ, Mask).getReg(0);
    } else {
      auto BitWidthC = MIRBuilder.buildConstant(ShTy, BW);
      ShAmt = MIRBuilder.buildURem(ShTy, Z, BitWidthC).getReg(0);
      InvShAmt = MIRBuilder.buildSub(ShTy, Mask, ShAmt).getReg(0);
    }

    auto One = MIRBuilder.buildConstant(ShTy, 1);
    if (IsFSHL) {
      ShX = MIRBuilder.buildShl(Ty, X, ShAmt).getReg(0);
      auto ShY1 = MIRBuilder.buildLShr(Ty, Y, One);
      ShY = MIRBuilder.buildLShr(Ty, ShY1, InvShAmt).getReg(0);
    } else {
      auto ShX1 = MIRBuilder.buildShl(Ty, X, One);
      ShX = MIRBuilder.buildShl(Ty, ShX1, InvShAmt).getReg(0);
      ShY = MIRBuilder.buildLShr(Ty, Y, ShAmt).getReg(0);
    }
  }

  MIRBuilder.buildOr(Dst, ShX, ShY);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerFunnelShift(MachineInstr &MI) {
  // These operations approximately do the following (while avoiding undefined
  // shifts by BW):
  // G_FSHL: (X << (Z % BW)) | (Y >> (BW - (Z % BW)))
  // G_FSHR: (X << (BW - (Z % BW))) | (Y >> (Z % BW))
  Register Dst = MI.getOperand(0).getReg();
  LLT Ty = MRI.getType(Dst);
  LLT ShTy = MRI.getType(MI.getOperand(3).getReg());

  bool IsFSHL = MI.getOpcode() == TargetOpcode::G_FSHL;
  unsigned RevOpcode = IsFSHL ? TargetOpcode::G_FSHR : TargetOpcode::G_FSHL;

  // TODO: Use smarter heuristic that accounts for vector legalization.
  if (LI.getAction({RevOpcode, {Ty, ShTy}}).Action == Lower)
    return lowerFunnelShiftAsShifts(MI);

  // This only works for powers of 2; fall back to shifts if it fails.
  LegalizerHelper::LegalizeResult Result = lowerFunnelShiftWithInverse(MI);
  if (Result == UnableToLegalize)
    return lowerFunnelShiftAsShifts(MI);
  return Result;
}

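// Lower a rotate via the opposite-direction rotate with a negated amount:
// rotl X, Amt == rotr X, -Amt (and vice versa). The caller checks that the
// bit width is a power of 2 so the negated amount wraps correctly.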
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerRotateWithReverseRotate(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  Register Amt = MI.getOperand(2).getReg();
  LLT AmtTy = MRI.getType(Amt);
  auto Zero = MIRBuilder.buildConstant(AmtTy, 0);
  bool IsLeft = MI.getOpcode() == TargetOpcode::G_ROTL;
  unsigned RevRot = IsLeft ? TargetOpcode::G_ROTR : TargetOpcode::G_ROTL;
  auto Neg = MIRBuilder.buildSub(AmtTy, Zero, Amt);
  MIRBuilder.buildInstr(RevRot, {Dst}, {Src, Neg});
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerRotate(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  Register Amt = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);
  LLT AmtTy = MRI.getType(Amt);

  unsigned EltSizeInBits = DstTy.getScalarSizeInBits();
  bool IsLeft = MI.getOpcode() == TargetOpcode::G_ROTL;

  MIRBuilder.setInstrAndDebugLoc(MI);

  // If a rotate in the other direction is supported, use it.
  unsigned RevRot = IsLeft ? TargetOpcode::G_ROTR : TargetOpcode::G_ROTL;
  if (LI.isLegalOrCustom({RevRot, {DstTy, SrcTy}}) &&
      isPowerOf2_32(EltSizeInBits))
    return lowerRotateWithReverseRotate(MI);

  // If a funnel shift is supported, use it.
  unsigned FShOpc = IsLeft ? TargetOpcode::G_FSHL : TargetOpcode::G_FSHR;
  unsigned RevFsh = !IsLeft ? TargetOpcode::G_FSHL : TargetOpcode::G_FSHR;
  bool IsFShLegal = false;
  if ((IsFShLegal = LI.isLegalOrCustom({FShOpc, {DstTy, AmtTy}})) ||
      LI.isLegalOrCustom({RevFsh, {DstTy, AmtTy}})) {
    auto buildFunnelShift = [&](unsigned Opc, Register R1, Register R2,
                                Register R3) {
      MIRBuilder.buildInstr(Opc, {R1}, {R2, R2, R3});
      MI.eraseFromParent();
      return Legalized;
    };
    // Use the same-direction funnel shift if it is legal; otherwise negate
    // the amount and use the funnel shift in the other direction.
    if (IsFShLegal) {
      return buildFunnelShift(FShOpc, Dst, Src, Amt);
    } else if (isPowerOf2_32(EltSizeInBits)) {
      Amt = MIRBuilder.buildNeg(DstTy, Amt).getReg(0);
      return buildFunnelShift(RevFsh, Dst, Src, Amt);
    }
  }

  auto Zero = MIRBuilder.buildConstant(AmtTy, 0);
  unsigned ShOpc = IsLeft ? TargetOpcode::G_SHL : TargetOpcode::G_LSHR;
  unsigned RevShiftOpc = IsLeft ? TargetOpcode::G_LSHR : TargetOpcode::G_SHL;
  auto BitWidthMinusOneC = MIRBuilder.buildConstant(AmtTy, EltSizeInBits - 1);
  Register ShVal;
  Register RevShiftVal;
  if (isPowerOf2_32(EltSizeInBits)) {
    // (rotl x, c) -> x << (c & (w - 1)) | x >> (-c & (w - 1))
    // (rotr x, c) -> x >> (c & (w - 1)) | x << (-c & (w - 1))
    auto NegAmt = MIRBuilder.buildSub(AmtTy, Zero, Amt);
    auto ShAmt = MIRBuilder.buildAnd(AmtTy, Amt, BitWidthMinusOneC);
    ShVal = MIRBuilder.buildInstr(ShOpc, {DstTy}, {Src, ShAmt}).getReg(0);
    auto RevAmt = MIRBuilder.buildAnd(AmtTy, NegAmt, BitWidthMinusOneC);
    RevShiftVal =
        MIRBuilder.buildInstr(RevShiftOpc, {DstTy}, {Src, RevAmt}).getReg(0);
  } else {
    // (rotl x, c) -> x << (c % w) | x >> 1 >> (w - 1 - (c % w))
    // (rotr x, c) -> x >> (c % w) | x << 1 << (w - 1 - (c % w))
    auto BitWidthC = MIRBuilder.buildConstant(AmtTy, EltSizeInBits);
    auto ShAmt = MIRBuilder.buildURem(AmtTy, Amt, BitWidthC);
    ShVal = MIRBuilder.buildInstr(ShOpc, {DstTy}, {Src, ShAmt}).getReg(0);
    auto RevAmt = MIRBuilder.buildSub(AmtTy, BitWidthMinusOneC, ShAmt);
    auto One = MIRBuilder.buildConstant(AmtTy, 1);
    auto Inner = MIRBuilder.buildInstr(RevShiftOpc, {DstTy}, {Src, One});
    RevShiftVal =
        MIRBuilder.buildInstr(RevShiftOpc, {DstTy}, {Inner, RevAmt}).getReg(0);
  }
  MIRBuilder.buildOr(Dst, ShVal, RevShiftVal);
  MI.eraseFromParent();
  return Legalized;
}

// Expand s32 = G_UITOFP s64 using bit operations to an IEEE float
// representation.
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerU64ToF32BitOps(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  const LLT S64 = LLT::scalar(64);
  const LLT S32 = LLT::scalar(32);
  const LLT S1 = LLT::scalar(1);

  assert(MRI.getType(Src) == S64 && MRI.getType(Dst) == S32);

  // unsigned cul2f(ulong u) {
  //   uint lz = clz(u);
  //   uint e = (u != 0) ? 127U + 63U - lz : 0;
  //   u = (u << lz) & 0x7fffffffffffffffUL;
  //   ulong t = u & 0xffffffffffUL;
  //   uint v = (e << 23) | (uint)(u >> 40);
  //   uint r = t > 0x8000000000UL ? 1U : (t == 0x8000000000UL ? v & 1U : 0U);
  //   return as_float(v + r);
  // }

  auto Zero32 = MIRBuilder.buildConstant(S32, 0);
  auto Zero64 = MIRBuilder.buildConstant(S64, 0);

  auto LZ = MIRBuilder.buildCTLZ_ZERO_UNDEF(S32, Src);

  auto K = MIRBuilder.buildConstant(S32, 127U + 63U);
  auto Sub = MIRBuilder.buildSub(S32, K, LZ);

  auto NotZero = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1, Src, Zero64);
  auto E = MIRBuilder.buildSelect(S32, NotZero, Sub, Zero32);

  auto Mask0 = MIRBuilder.buildConstant(S64, (-1ULL) >> 1);
  auto ShlLZ = MIRBuilder.buildShl(S64, Src, LZ);

  auto U = MIRBuilder.buildAnd(S64, ShlLZ, Mask0);

  auto Mask1 = MIRBuilder.buildConstant(S64, 0xffffffffffULL);
  auto T = MIRBuilder.buildAnd(S64, U, Mask1);

  auto UShl = MIRBuilder.buildLShr(S64, U, MIRBuilder.buildConstant(S64, 40));
  auto ShlE = MIRBuilder.buildShl(S32, E, MIRBuilder.buildConstant(S32, 23));
  auto V = MIRBuilder.buildOr(S32, ShlE, MIRBuilder.buildTrunc(S32, UShl));

  auto C = MIRBuilder.buildConstant(S64, 0x8000000000ULL);
  auto RCmp = MIRBuilder.buildICmp(CmpInst::ICMP_UGT, S1, T, C);
  auto TCmp = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, S1, T, C);
  auto One = MIRBuilder.buildConstant(S32, 1);

  auto VTrunc1 = MIRBuilder.buildAnd(S32, V, One);
  auto Select0 = MIRBuilder.buildSelect(S32, TCmp, VTrunc1, Zero32);
  auto R = MIRBuilder.buildSelect(S32, RCmp, One, Select0);
  MIRBuilder.buildAdd(Dst, V, R);

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerUITOFP(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);

  if (SrcTy == LLT::scalar(1)) {
    auto True = MIRBuilder.buildFConstant(DstTy, 1.0);
    auto False = MIRBuilder.buildFConstant(DstTy, 0.0);
    MIRBuilder.buildSelect(Dst, Src, True, False);
    MI.eraseFromParent();
    return Legalized;
  }

  if (SrcTy != LLT::scalar(64))
    return UnableToLegalize;

  if (DstTy == LLT::scalar(32)) {
    // TODO: SelectionDAG has several alternative expansions to port which may
    // be more reasonable depending on the available instructions. If a target
    // has sitofp, does not have CTLZ, or can efficiently use f64 as an
    // intermediate type, this is probably worse.
    return lowerU64ToF32BitOps(MI);
  }

  return UnableToLegalize;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerSITOFP(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);

  const LLT S64 = LLT::scalar(64);
  const LLT S32 = LLT::scalar(32);
  const LLT S1 = LLT::scalar(1);

  if (SrcTy == S1) {
    auto True = MIRBuilder.buildFConstant(DstTy, -1.0);
    auto False = MIRBuilder.buildFConstant(DstTy, 0.0);
    MIRBuilder.buildSelect(Dst, Src, True, False);
    MI.eraseFromParent();
    return Legalized;
  }

  if (SrcTy != S64)
    return UnableToLegalize;

  if (DstTy == S32) {
    // signed cl2f(long l) {
    //   long s = l >> 63;
    //   float r = cul2f((l + s) ^ s);
    //   return s ? -r : r;
    // }
    Register L = Src;
    auto SignBit = MIRBuilder.buildConstant(S64, 63);
    auto S = MIRBuilder.buildAShr(S64, L, SignBit);

    auto LPlusS = MIRBuilder.buildAdd(S64, L, S);
    auto Xor = MIRBuilder.buildXor(S64, LPlusS, S);
    auto R = MIRBuilder.buildUITOFP(S32, Xor);

    auto RNeg = MIRBuilder.buildFNeg(S32, R);
    auto SignNotZero = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1, S,
                                            MIRBuilder.buildConstant(S64, 0));
    MIRBuilder.buildSelect(Dst, SignNotZero, RNeg, R);
    MI.eraseFromParent();
    return Legalized;
  }

  return UnableToLegalize;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerFPTOUI(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);
  const LLT S64 = LLT::scalar(64);
  const LLT S32 = LLT::scalar(32);

  if (SrcTy != S64 && SrcTy != S32)
    return UnableToLegalize;
  if (DstTy != S32 && DstTy != S64)
    return UnableToLegalize;

  // FPTOSI gives the same result as FPTOUI for positive signed integers.
  // FPTOUI additionally needs to handle fp values that convert to unsigned
  // integers greater than or equal to 2^(DstSize-1): 2^31 for an i32 result,
  // 2^63 for an i64 result. For brevity, call this threshold 2^Exp.
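  //
  // Worked example (values chosen here for illustration, not from the
  // original source): converting the f32 value 3.0e9 to i32, the threshold
  // is 2^31 = 2147483648.0. Since 3.0e9 >= 2^31, the result is
  //   FPTOSI(3.0e9 - 2^31) ^ 0x80000000 = 852516352 ^ 0x80000000
  //   = 3000000000.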

  APInt TwoPExpInt = APInt::getSignMask(DstTy.getSizeInBits());
  APFloat TwoPExpFP(SrcTy.getSizeInBits() == 32 ? APFloat::IEEEsingle()
                                                : APFloat::IEEEdouble(),
                    APInt::getZero(SrcTy.getSizeInBits()));
  TwoPExpFP.convertFromAPInt(TwoPExpInt, false, APFloat::rmNearestTiesToEven);

  MachineInstrBuilder FPTOSI = MIRBuilder.buildFPTOSI(DstTy, Src);

  MachineInstrBuilder Threshold = MIRBuilder.buildFConstant(SrcTy, TwoPExpFP);
  // For fp values greater than or equal to Threshold (2^Exp), use FPTOSI on
  // (Value - 2^Exp) and add 2^Exp back by setting the highest bit of the
  // result.
  MachineInstrBuilder FSub = MIRBuilder.buildFSub(SrcTy, Src, Threshold);
  MachineInstrBuilder ResLowBits = MIRBuilder.buildFPTOSI(DstTy, FSub);
  MachineInstrBuilder ResHighBit = MIRBuilder.buildConstant(DstTy, TwoPExpInt);
  MachineInstrBuilder Res = MIRBuilder.buildXor(DstTy, ResLowBits, ResHighBit);

  const LLT S1 = LLT::scalar(1);

  MachineInstrBuilder FCMP =
      MIRBuilder.buildFCmp(CmpInst::FCMP_ULT, S1, Src, Threshold);
  MIRBuilder.buildSelect(Dst, FCMP, FPTOSI, Res);

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerFPTOSI(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);
  const LLT S64 = LLT::scalar(64);
  const LLT S32 = LLT::scalar(32);

  // FIXME: Only f32 to i64 conversions are supported.
  if (SrcTy.getScalarType() != S32 || DstTy.getScalarType() != S64)
    return UnableToLegalize;

  // Expand f32 -> i64 conversion
  // This algorithm comes from compiler-rt's implementation of fixsfdi:
  // https://github.com/llvm/llvm-project/blob/main/compiler-rt/lib/builtins/fixsfdi.c

  unsigned SrcEltBits = SrcTy.getScalarSizeInBits();

  auto ExponentMask = MIRBuilder.buildConstant(SrcTy, 0x7F800000);
  auto ExponentLoBit = MIRBuilder.buildConstant(SrcTy, 23);

  auto AndExpMask = MIRBuilder.buildAnd(SrcTy, Src, ExponentMask);
  auto ExponentBits = MIRBuilder.buildLShr(SrcTy, AndExpMask, ExponentLoBit);

  auto SignMask = MIRBuilder.buildConstant(SrcTy,
                                           APInt::getSignMask(SrcEltBits));
  auto AndSignMask = MIRBuilder.buildAnd(SrcTy, Src, SignMask);
  auto SignLowBit = MIRBuilder.buildConstant(SrcTy, SrcEltBits - 1);
  auto Sign = MIRBuilder.buildAShr(SrcTy, AndSignMask, SignLowBit);
  Sign = MIRBuilder.buildSExt(DstTy, Sign);

  auto MantissaMask = MIRBuilder.buildConstant(SrcTy, 0x007FFFFF);
  auto AndMantissaMask = MIRBuilder.buildAnd(SrcTy, Src, MantissaMask);
  auto K = MIRBuilder.buildConstant(SrcTy, 0x00800000);

  auto R = MIRBuilder.buildOr(SrcTy, AndMantissaMask, K);
  R = MIRBuilder.buildZExt(DstTy, R);

  auto Bias = MIRBuilder.buildConstant(SrcTy, 127);
  auto Exponent = MIRBuilder.buildSub(SrcTy, ExponentBits, Bias);
  auto SubExponent = MIRBuilder.buildSub(SrcTy, Exponent, ExponentLoBit);
  auto ExponentSub = MIRBuilder.buildSub(SrcTy, ExponentLoBit, Exponent);

  auto Shl = MIRBuilder.buildShl(DstTy, R, SubExponent);
  auto Srl = MIRBuilder.buildLShr(DstTy, R, ExponentSub);

  const LLT S1 = LLT::scalar(1);
  auto CmpGt = MIRBuilder.buildICmp(CmpInst::ICMP_SGT,
                                    S1, Exponent, ExponentLoBit);

  R = MIRBuilder.buildSelect(DstTy, CmpGt, Shl, Srl);

  auto XorSign = MIRBuilder.buildXor(DstTy, R, Sign);
  auto Ret = MIRBuilder.buildSub(DstTy, XorSign, Sign);

  auto ZeroSrcTy = MIRBuilder.buildConstant(SrcTy, 0);

  auto ExponentLt0 = MIRBuilder.buildICmp(CmpInst::ICMP_SLT,
                                          S1, Exponent, ZeroSrcTy);

  auto ZeroDstTy = MIRBuilder.buildConstant(DstTy, 0);
  MIRBuilder.buildSelect(Dst, ExponentLt0, ZeroDstTy, Ret);

  MI.eraseFromParent();
  return Legalized;
}

// f64 -> f16 conversion using round-to-nearest-even rounding mode.
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerFPTRUNC_F64_TO_F16(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();

  if (MRI.getType(Src).isVector()) // TODO: Handle vectors directly.
    return UnableToLegalize;

  const unsigned ExpMask = 0x7ff;
  const unsigned ExpBiasf64 = 1023;
  const unsigned ExpBiasf16 = 15;
  const LLT S32 = LLT::scalar(32);
  const LLT S1 = LLT::scalar(1);

  auto Unmerge = MIRBuilder.buildUnmerge(S32, Src);
  Register U = Unmerge.getReg(0);
  Register UH = Unmerge.getReg(1);

  auto E = MIRBuilder.buildLShr(S32, UH, MIRBuilder.buildConstant(S32, 20));
  E = MIRBuilder.buildAnd(S32, E, MIRBuilder.buildConstant(S32, ExpMask));

  // Subtract the fp64 exponent bias (1023) to get the real exponent and
  // add the f16 bias (15) to get the biased exponent for the f16 format.
  E = MIRBuilder.buildAdd(
    S32, E, MIRBuilder.buildConstant(S32, -ExpBiasf64 + ExpBiasf16));

  auto M = MIRBuilder.buildLShr(S32, UH, MIRBuilder.buildConstant(S32, 8));
  M = MIRBuilder.buildAnd(S32, M, MIRBuilder.buildConstant(S32, 0xffe));

  auto MaskedSig = MIRBuilder.buildAnd(S32, UH,
                                       MIRBuilder.buildConstant(S32, 0x1ff));
  MaskedSig = MIRBuilder.buildOr(S32, MaskedSig, U);

  auto Zero = MIRBuilder.buildConstant(S32, 0);
  auto SigCmpNE0 = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1, MaskedSig, Zero);
  auto Lo40Set = MIRBuilder.buildZExt(S32, SigCmpNE0);
  M = MIRBuilder.buildOr(S32, M, Lo40Set);

  // (M != 0 ? 0x0200 : 0) | 0x7c00;
  auto Bits0x200 = MIRBuilder.buildConstant(S32, 0x0200);
  auto CmpM_NE0 = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1, M, Zero);
  auto SelectCC = MIRBuilder.buildSelect(S32, CmpM_NE0, Bits0x200, Zero);

  auto Bits0x7c00 = MIRBuilder.buildConstant(S32, 0x7c00);
  auto I = MIRBuilder.buildOr(S32, SelectCC, Bits0x7c00);

  // N = M | (E << 12);
  auto EShl12 = MIRBuilder.buildShl(S32, E, MIRBuilder.buildConstant(S32, 12));
  auto N = MIRBuilder.buildOr(S32, M, EShl12);

  // B = clamp(1-E, 0, 13);
  auto One = MIRBuilder.buildConstant(S32, 1);
  auto OneSubExp = MIRBuilder.buildSub(S32, One, E);
  auto B = MIRBuilder.buildSMax(S32, OneSubExp, Zero);
  B = MIRBuilder.buildSMin(S32, B, MIRBuilder.buildConstant(S32, 13));

  auto SigSetHigh = MIRBuilder.buildOr(S32, M,
                                       MIRBuilder.buildConstant(S32, 0x1000));

  auto D = MIRBuilder.buildLShr(S32, SigSetHigh, B);
  auto D0 = MIRBuilder.buildShl(S32, D, B);

  auto D0_NE_SigSetHigh = MIRBuilder.buildICmp(CmpInst::ICMP_NE, S1,
                                               D0, SigSetHigh);
  auto D1 = MIRBuilder.buildZExt(S32, D0_NE_SigSetHigh);
  D = MIRBuilder.buildOr(S32, D, D1);

  auto CmpELtOne = MIRBuilder.buildICmp(CmpInst::ICMP_SLT, S1, E, One);
  auto V = MIRBuilder.buildSelect(S32, CmpELtOne, D, N);

  auto VLow3 = MIRBuilder.buildAnd(S32, V, MIRBuilder.buildConstant(S32, 7));
  V = MIRBuilder.buildLShr(S32, V, MIRBuilder.buildConstant(S32, 2));

  auto VLow3Eq3 = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, S1, VLow3,
                                       MIRBuilder.buildConstant(S32, 3));
  auto V0 = MIRBuilder.buildZExt(S32, VLow3Eq3);

  auto VLow3Gt5 = MIRBuilder.buildICmp(CmpInst::ICMP_SGT, S1, VLow3,
                                       MIRBuilder.buildConstant(S32, 5));
  auto V1 = MIRBuilder.buildZExt(S32, VLow3Gt5);

  V1 = MIRBuilder.buildOr(S32, V0, V1);
  V = MIRBuilder.buildAdd(S32, V, V1);

  auto CmpEGt30 = MIRBuilder.buildICmp(CmpInst::ICMP_SGT, S1,
                                       E, MIRBuilder.buildConstant(S32, 30));
  V = MIRBuilder.buildSelect(S32, CmpEGt30,
                             MIRBuilder.buildConstant(S32, 0x7c00), V);

  auto CmpEGt1039 = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, S1,
                                         E, MIRBuilder.buildConstant(S32, 1039));
  V = MIRBuilder.buildSelect(S32, CmpEGt1039, I, V);

  // Extract the sign bit.
  auto Sign = MIRBuilder.buildLShr(S32, UH, MIRBuilder.buildConstant(S32, 16));
  Sign = MIRBuilder.buildAnd(S32, Sign, MIRBuilder.buildConstant(S32, 0x8000));

  // Insert the sign bit.
  V = MIRBuilder.buildOr(S32, Sign, V);

  MIRBuilder.buildTrunc(Dst, V);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerFPTRUNC(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();

  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);
  const LLT S64 = LLT::scalar(64);
  const LLT S16 = LLT::scalar(16);

  if (DstTy.getScalarType() == S16 && SrcTy.getScalarType() == S64)
    return lowerFPTRUNC_F64_TO_F16(MI);

  return UnableToLegalize;
}

// TODO: If RHS is a constant, SelectionDAGBuilder expands this into a
// multiplication tree.
LegalizerHelper::LegalizeResult LegalizerHelper::lowerFPOWI(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src0 = MI.getOperand(1).getReg();
  Register Src1 = MI.getOperand(2).getReg();
  LLT Ty = MRI.getType(Dst);

  auto CvtSrc1 = MIRBuilder.buildSITOFP(Ty, Src1);
  MIRBuilder.buildFPow(Dst, Src0, CvtSrc1, MI.getFlags());
  MI.eraseFromParent();
  return Legalized;
}

static CmpInst::Predicate minMaxToCompare(unsigned Opc) {
  switch (Opc) {
  case TargetOpcode::G_SMIN:
    return CmpInst::ICMP_SLT;
  case TargetOpcode::G_SMAX:
    return CmpInst::ICMP_SGT;
  case TargetOpcode::G_UMIN:
    return CmpInst::ICMP_ULT;
  case TargetOpcode::G_UMAX:
    return CmpInst::ICMP_UGT;
  default:
    llvm_unreachable("not in integer min/max");
  }
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerMinMax(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src0 = MI.getOperand(1).getReg();
  Register Src1 = MI.getOperand(2).getReg();

  const CmpInst::Predicate Pred = minMaxToCompare(MI.getOpcode());
  LLT CmpType = MRI.getType(Dst).changeElementSize(1);

  auto Cmp = MIRBuilder.buildICmp(Pred, CmpType, Src0, Src1);
  MIRBuilder.buildSelect(Dst, Cmp, Src0, Src1);

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerFCopySign(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src0 = MI.getOperand(1).getReg();
  Register Src1 = MI.getOperand(2).getReg();

  const LLT Src0Ty = MRI.getType(Src0);
  const LLT Src1Ty = MRI.getType(Src1);

  const int Src0Size = Src0Ty.getScalarSizeInBits();
  const int Src1Size = Src1Ty.getScalarSizeInBits();

  auto SignBitMask = MIRBuilder.buildConstant(
    Src0Ty, APInt::getSignMask(Src0Size));

  auto NotSignBitMask = MIRBuilder.buildConstant(
    Src0Ty, APInt::getLowBitsSet(Src0Size, Src0Size - 1));

  Register And0 = MIRBuilder.buildAnd(Src0Ty, Src0, NotSignBitMask).getReg(0);
  Register And1;
  if (Src0Ty == Src1Ty) {
    And1 = MIRBuilder.buildAnd(Src1Ty, Src1, SignBitMask).getReg(0);
  } else if (Src0Size > Src1Size) {
    auto ShiftAmt = MIRBuilder.buildConstant(Src0Ty, Src0Size - Src1Size);
    auto Zext = MIRBuilder.buildZExt(Src0Ty, Src1);
    auto Shift = MIRBuilder.buildShl(Src0Ty, Zext, ShiftAmt);
    And1 = MIRBuilder.buildAnd(Src0Ty, Shift, SignBitMask).getReg(0);
  } else {
    auto ShiftAmt = MIRBuilder.buildConstant(Src1Ty, Src1Size - Src0Size);
    auto Shift = MIRBuilder.buildLShr(Src1Ty, Src1, ShiftAmt);
    auto Trunc = MIRBuilder.buildTrunc(Src0Ty, Shift);
    And1 = MIRBuilder.buildAnd(Src0Ty, Trunc, SignBitMask).getReg(0);
  }

  // Be careful about setting nsz/nnan/ninf on every instruction, since the
  // constants are a nan and -0.0, but the final result should preserve
  // everything.
  unsigned Flags = MI.getFlags();
  MIRBuilder.buildOr(Dst, And0, And1, Flags);

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerFMinNumMaxNum(MachineInstr &MI) {
  unsigned NewOp = MI.getOpcode() == TargetOpcode::G_FMINNUM ?
    TargetOpcode::G_FMINNUM_IEEE : TargetOpcode::G_FMAXNUM_IEEE;

  Register Dst = MI.getOperand(0).getReg();
  Register Src0 = MI.getOperand(1).getReg();
  Register Src1 = MI.getOperand(2).getReg();
  LLT Ty = MRI.getType(Dst);

  if (!MI.getFlag(MachineInstr::FmNoNans)) {
    // Insert canonicalizations if we may need to quiet signaling NaNs to get
    // the correct behavior.

    // Note this must be done here, and not as an optimization combine, in the
    // absence of a dedicated quiet-sNaN instruction as we're using an
    // omni-purpose G_FCANONICALIZE.
    if (!isKnownNeverSNaN(Src0, MRI))
      Src0 = MIRBuilder.buildFCanonicalize(Ty, Src0, MI.getFlags()).getReg(0);

    if (!isKnownNeverSNaN(Src1, MRI))
      Src1 = MIRBuilder.buildFCanonicalize(Ty, Src1, MI.getFlags()).getReg(0);
  }

  // If there are no nans, it's safe to simply replace this with the non-IEEE
  // version.
  MIRBuilder.buildInstr(NewOp, {Dst}, {Src0, Src1}, MI.getFlags());
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerFMad(MachineInstr &MI) {
  // Expand G_FMAD a, b, c -> G_FADD (G_FMUL a, b), c
  Register DstReg = MI.getOperand(0).getReg();
  LLT Ty = MRI.getType(DstReg);
  unsigned Flags = MI.getFlags();

  auto Mul = MIRBuilder.buildFMul(Ty, MI.getOperand(1), MI.getOperand(2),
                                  Flags);
  MIRBuilder.buildFAdd(DstReg, Mul, MI.getOperand(3), Flags);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerIntrinsicRound(MachineInstr &MI) {
  Register DstReg = MI.getOperand(0).getReg();
  Register X = MI.getOperand(1).getReg();
  const unsigned Flags = MI.getFlags();
  const LLT Ty = MRI.getType(DstReg);
  const LLT CondTy = Ty.changeElementSize(1);

  // round(x) =>
  //  t = trunc(x);
  //  d = fabs(x - t);
  //  o = copysign(1.0f, x);
  //  return t + (d >= 0.5 ? o : 0.0);

  auto T = MIRBuilder.buildIntrinsicTrunc(Ty, X, Flags);

  auto Diff = MIRBuilder.buildFSub(Ty, X, T, Flags);
  auto AbsDiff = MIRBuilder.buildFAbs(Ty, Diff, Flags);
  auto Zero = MIRBuilder.buildFConstant(Ty, 0.0);
  auto One = MIRBuilder.buildFConstant(Ty, 1.0);
  auto Half = MIRBuilder.buildFConstant(Ty, 0.5);
  auto SignOne = MIRBuilder.buildFCopysign(Ty, One, X);

  auto Cmp = MIRBuilder.buildFCmp(CmpInst::FCMP_OGE, CondTy, AbsDiff, Half,
                                  Flags);
  auto Sel = MIRBuilder.buildSelect(Ty, Cmp, SignOne, Zero, Flags);

  MIRBuilder.buildFAdd(DstReg, T, Sel, Flags);

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerFFloor(MachineInstr &MI) {
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  unsigned Flags = MI.getFlags();
  LLT Ty = MRI.getType(DstReg);
  const LLT CondTy = Ty.changeElementSize(1);

  // result = trunc(src);
  // if (src < 0.0 && src != result)
  //   result += -1.0.

  auto Trunc = MIRBuilder.buildIntrinsicTrunc(Ty, SrcReg, Flags);
  auto Zero = MIRBuilder.buildFConstant(Ty, 0.0);

  auto Lt0 = MIRBuilder.buildFCmp(CmpInst::FCMP_OLT, CondTy,
                                  SrcReg, Zero, Flags);
  auto NeTrunc = MIRBuilder.buildFCmp(CmpInst::FCMP_ONE, CondTy,
                                      SrcReg, Trunc, Flags);
  auto And = MIRBuilder.buildAnd(CondTy, Lt0, NeTrunc);
  auto AddVal = MIRBuilder.buildSITOFP(Ty, And);

  MIRBuilder.buildFAdd(DstReg, Trunc, AddVal, Flags);
  MI.eraseFromParent();
  return Legalized;
}

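/// Lower G_MERGE_VALUES to a chain of zero-extends, shifts, and ORs, shown
/// schematically for a hypothetical 2 x s16 -> s32 case:
///
/// %dst:_(s32) = G_MERGE_VALUES %lo:_(s16), %hi:_(s16)
///  =>
/// %dst = (zext %lo) | ((zext %hi) << 16)
///
/// Pointer results go through an intermediate integer of the same size,
/// followed by G_INTTOPTR.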
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMergeValues(MachineInstr &MI) {
  const unsigned NumOps = MI.getNumOperands();
  Register DstReg = MI.getOperand(0).getReg();
  Register Src0Reg = MI.getOperand(1).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT SrcTy = MRI.getType(Src0Reg);
  unsigned PartSize = SrcTy.getSizeInBits();

  LLT WideTy = LLT::scalar(DstTy.getSizeInBits());
  Register ResultReg = MIRBuilder.buildZExt(WideTy, Src0Reg).getReg(0);

  for (unsigned I = 2; I != NumOps; ++I) {
    const unsigned Offset = (I - 1) * PartSize;

    Register SrcReg = MI.getOperand(I).getReg();
    auto ZextInput = MIRBuilder.buildZExt(WideTy, SrcReg);

    Register NextResult = I + 1 == NumOps && WideTy == DstTy ? DstReg :
      MRI.createGenericVirtualRegister(WideTy);

    auto ShiftAmt = MIRBuilder.buildConstant(WideTy, Offset);
    auto Shl = MIRBuilder.buildShl(WideTy, ZextInput, ShiftAmt);
    MIRBuilder.buildOr(NextResult, ResultReg, Shl);
    ResultReg = NextResult;
  }

  if (DstTy.isPointer()) {
    if (MIRBuilder.getDataLayout().isNonIntegralAddressSpace(
          DstTy.getAddressSpace())) {
      LLVM_DEBUG(dbgs() << "Not casting nonintegral address space\n");
      return UnableToLegalize;
    }

    MIRBuilder.buildIntToPtr(DstReg, ResultReg);
  }

  MI.eraseFromParent();
  return Legalized;
}

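/// Lower G_UNMERGE_VALUES to a truncate plus shift/truncate pairs, shown
/// schematically for a hypothetical s32 -> 2 x s16 case:
///
/// %lo:_(s16), %hi:_(s16) = G_UNMERGE_VALUES %src:_(s32)
///  =>
/// %lo = G_TRUNC %src
/// %hi = G_TRUNC (G_LSHR %src, 16)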
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerUnmergeValues(MachineInstr &MI) {
  const unsigned NumDst = MI.getNumOperands() - 1;
  Register SrcReg = MI.getOperand(NumDst).getReg();
  Register Dst0Reg = MI.getOperand(0).getReg();
  LLT DstTy = MRI.getType(Dst0Reg);
  if (DstTy.isPointer())
    return UnableToLegalize; // TODO

  SrcReg = coerceToScalar(SrcReg);
  if (!SrcReg)
    return UnableToLegalize;

  // Expand scalarizing unmerge as bitcast to integer and shift.
  LLT IntTy = MRI.getType(SrcReg);

  MIRBuilder.buildTrunc(Dst0Reg, SrcReg);

  const unsigned DstSize = DstTy.getSizeInBits();
  unsigned Offset = DstSize;
  for (unsigned I = 1; I != NumDst; ++I, Offset += DstSize) {
    auto ShiftAmt = MIRBuilder.buildConstant(IntTy, Offset);
    auto Shift = MIRBuilder.buildLShr(IntTy, SrcReg, ShiftAmt);
    MIRBuilder.buildTrunc(MI.getOperand(I), Shift);
  }

  MI.eraseFromParent();
  return Legalized;
}

/// Lower a vector extract or insert by writing the vector to a stack temporary
/// and reloading the element or vector.
///
/// %dst = G_EXTRACT_VECTOR_ELT %vec, %idx
///  =>
///  %stack_temp = G_FRAME_INDEX
///  G_STORE %vec, %stack_temp
///  %idx = clamp(%idx, %vec.getNumElements())
///  %element_ptr = G_PTR_ADD %stack_temp, %idx
///  %dst = G_LOAD %element_ptr
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerExtractInsertVectorElt(MachineInstr &MI) {
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcVec = MI.getOperand(1).getReg();
  Register InsertVal;
  if (MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT)
    InsertVal = MI.getOperand(2).getReg();

  Register Idx = MI.getOperand(MI.getNumOperands() - 1).getReg();

  LLT VecTy = MRI.getType(SrcVec);
  LLT EltTy = VecTy.getElementType();
  unsigned NumElts = VecTy.getNumElements();

  int64_t IdxVal;
  if (mi_match(Idx, MRI, m_ICst(IdxVal)) && IdxVal < NumElts) {
    SmallVector<Register, 8> SrcRegs;
    extractParts(SrcVec, EltTy, NumElts, SrcRegs);

    if (InsertVal) {
      SrcRegs[IdxVal] = MI.getOperand(2).getReg();
      MIRBuilder.buildMergeLikeInstr(DstReg, SrcRegs);
    } else {
      MIRBuilder.buildCopy(DstReg, SrcRegs[IdxVal]);
    }

    MI.eraseFromParent();
    return Legalized;
  }

  if (!EltTy.isByteSized()) { // Not implemented.
    LLVM_DEBUG(dbgs() << "Can't handle non-byte element vectors yet\n");
    return UnableToLegalize;
  }

  unsigned EltBytes = EltTy.getSizeInBytes();
  Align VecAlign = getStackTemporaryAlignment(VecTy);
  Align EltAlign;

  MachinePointerInfo PtrInfo;
  auto StackTemp = createStackTemporary(TypeSize::Fixed(VecTy.getSizeInBytes()),
                                        VecAlign, PtrInfo);
  MIRBuilder.buildStore(SrcVec, StackTemp, PtrInfo, VecAlign);

  // Get the pointer to the element, and be sure not to hit undefined behavior
  // if the index is out of bounds.
  Register EltPtr = getVectorElementPointer(StackTemp.getReg(0), VecTy, Idx);

  if (mi_match(Idx, MRI, m_ICst(IdxVal))) {
    int64_t Offset = IdxVal * EltBytes;
    PtrInfo = PtrInfo.getWithOffset(Offset);
    EltAlign = commonAlignment(VecAlign, Offset);
  } else {
    // We lose information with a variable offset.
    EltAlign = getStackTemporaryAlignment(EltTy);
    PtrInfo = MachinePointerInfo(MRI.getType(EltPtr).getAddressSpace());
  }

  if (InsertVal) {
    // Write the inserted element
    MIRBuilder.buildStore(InsertVal, EltPtr, PtrInfo, EltAlign);

    // Reload the whole vector.
    MIRBuilder.buildLoad(DstReg, StackTemp, PtrInfo, VecAlign);
  } else {
    MIRBuilder.buildLoad(DstReg, EltPtr, PtrInfo, EltAlign);
  }

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerShuffleVector(MachineInstr &MI) {
  Register DstReg = MI.getOperand(0).getReg();
  Register Src0Reg = MI.getOperand(1).getReg();
  Register Src1Reg = MI.getOperand(2).getReg();
  LLT Src0Ty = MRI.getType(Src0Reg);
  LLT DstTy = MRI.getType(DstReg);
  LLT IdxTy = LLT::scalar(32);

  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();

  if (DstTy.isScalar()) {
    if (Src0Ty.isVector())
      return UnableToLegalize;

    // This is just a SELECT.
    assert(Mask.size() == 1 && "Expected a single mask element");
    Register Val;
    if (Mask[0] < 0 || Mask[0] > 1)
      Val = MIRBuilder.buildUndef(DstTy).getReg(0);
    else
      Val = Mask[0] == 0 ? Src0Reg : Src1Reg;
    MIRBuilder.buildCopy(DstReg, Val);
    MI.eraseFromParent();
    return Legalized;
  }

  Register Undef;
  SmallVector<Register, 32> BuildVec;
  LLT EltTy = DstTy.getElementType();

  for (int Idx : Mask) {
    if (Idx < 0) {
      if (!Undef.isValid())
        Undef = MIRBuilder.buildUndef(EltTy).getReg(0);
      BuildVec.push_back(Undef);
      continue;
    }

    if (Src0Ty.isScalar()) {
      BuildVec.push_back(Idx == 0 ? Src0Reg : Src1Reg);
    } else {
      int NumElts = Src0Ty.getNumElements();
      Register SrcVec = Idx < NumElts ? Src0Reg : Src1Reg;
      int ExtractIdx = Idx < NumElts ? Idx : Idx - NumElts;
      auto IdxK = MIRBuilder.buildConstant(IdxTy, ExtractIdx);
      auto Extract = MIRBuilder.buildExtractVectorElement(EltTy, SrcVec, IdxK);
      BuildVec.push_back(Extract.getReg(0));
    }
  }

  MIRBuilder.buildBuildVector(DstReg, BuildVec);
  MI.eraseFromParent();
  return Legalized;
}

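/// Lower G_DYN_STACKALLOC by adjusting the stack pointer directly: copy SP,
/// subtract the allocation size, round the result down to the requested
/// alignment, and write it back. Stacks that grow up are not supported.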
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerDynStackAlloc(MachineInstr &MI) {
  const auto &MF = *MI.getMF();
  const auto &TFI = *MF.getSubtarget().getFrameLowering();
  if (TFI.getStackGrowthDirection() == TargetFrameLowering::StackGrowsUp)
    return UnableToLegalize;

  Register Dst = MI.getOperand(0).getReg();
  Register AllocSize = MI.getOperand(1).getReg();
  Align Alignment = assumeAligned(MI.getOperand(2).getImm());

  LLT PtrTy = MRI.getType(Dst);
  LLT IntPtrTy = LLT::scalar(PtrTy.getSizeInBits());

  Register SPReg = TLI.getStackPointerRegisterToSaveRestore();
  auto SPTmp = MIRBuilder.buildCopy(PtrTy, SPReg);
  SPTmp = MIRBuilder.buildCast(IntPtrTy, SPTmp);

  // Subtract the final alloc from the SP. We use G_PTRTOINT here so we don't
  // have to generate an extra instruction to negate the alloc and then use
  // G_PTR_ADD to add the negative offset.
  auto Alloc = MIRBuilder.buildSub(IntPtrTy, SPTmp, AllocSize);
  if (Alignment > Align(1)) {
    APInt AlignMask(IntPtrTy.getSizeInBits(), Alignment.value(), true);
    AlignMask.negate();
    auto AlignCst = MIRBuilder.buildConstant(IntPtrTy, AlignMask);
    Alloc = MIRBuilder.buildAnd(IntPtrTy, Alloc, AlignCst);
  }

  SPTmp = MIRBuilder.buildCast(PtrTy, Alloc);
  MIRBuilder.buildCopy(SPReg, SPTmp);
  MIRBuilder.buildCopy(Dst, SPTmp);

  MI.eraseFromParent();
  return Legalized;
}

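/// Lower G_EXTRACT either by unmerging whole source elements, when the offset
/// and width line up with element boundaries, or by shifting the source right
/// and truncating to the destination type.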
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerExtract(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  unsigned Offset = MI.getOperand(2).getImm();

  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);

  // Extract sub-vector or one element
  if (SrcTy.isVector()) {
    unsigned SrcEltSize = SrcTy.getElementType().getSizeInBits();
    unsigned DstSize = DstTy.getSizeInBits();

    if ((Offset % SrcEltSize == 0) && (DstSize % SrcEltSize == 0) &&
        (Offset + DstSize <= SrcTy.getSizeInBits())) {
      // Unmerge and allow access to each Src element for the artifact combiner.
      auto Unmerge = MIRBuilder.buildUnmerge(SrcTy.getElementType(), Src);

      // Take element(s) we need to extract and copy it (merge them).
      SmallVector<Register, 8> SubVectorElts;
      for (unsigned Idx = Offset / SrcEltSize;
           Idx < (Offset + DstSize) / SrcEltSize; ++Idx) {
        SubVectorElts.push_back(Unmerge.getReg(Idx));
      }
      if (SubVectorElts.size() == 1)
        MIRBuilder.buildCopy(Dst, SubVectorElts[0]);
      else
        MIRBuilder.buildMergeLikeInstr(Dst, SubVectorElts);

      MI.eraseFromParent();
      return Legalized;
    }
  }

  if (DstTy.isScalar() &&
      (SrcTy.isScalar() ||
       (SrcTy.isVector() && DstTy == SrcTy.getElementType()))) {
    LLT SrcIntTy = SrcTy;
    if (!SrcTy.isScalar()) {
      SrcIntTy = LLT::scalar(SrcTy.getSizeInBits());
      Src = MIRBuilder.buildBitcast(SrcIntTy, Src).getReg(0);
    }

    if (Offset == 0)
      MIRBuilder.buildTrunc(Dst, Src);
    else {
      auto ShiftAmt = MIRBuilder.buildConstant(SrcIntTy, Offset);
      auto Shr = MIRBuilder.buildLShr(SrcIntTy, Src, ShiftAmt);
      MIRBuilder.buildTrunc(Dst, Shr);
    }

    MI.eraseFromParent();
    return Legalized;
  }

  return UnableToLegalize;
}

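/// Lower G_INSERT either by splicing whole elements, when the offset and width
/// line up with element boundaries, or by masking out the destination bits and
/// ORing in the shifted, zero-extended inserted value.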
LegalizerHelper::LegalizeResult LegalizerHelper::lowerInsert(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  Register InsertSrc = MI.getOperand(2).getReg();
  uint64_t Offset = MI.getOperand(3).getImm();

  LLT DstTy = MRI.getType(Src);
  LLT InsertTy = MRI.getType(InsertSrc);

  // Insert sub-vector or one element
  if (DstTy.isVector() && !InsertTy.isPointer()) {
    LLT EltTy = DstTy.getElementType();
    unsigned EltSize = EltTy.getSizeInBits();
    unsigned InsertSize = InsertTy.getSizeInBits();

    if ((Offset % EltSize == 0) && (InsertSize % EltSize == 0) &&
        (Offset + InsertSize <= DstTy.getSizeInBits())) {
      auto UnmergeSrc = MIRBuilder.buildUnmerge(EltTy, Src);
      SmallVector<Register, 8> DstElts;
      unsigned Idx = 0;
      // Elements from Src before insert start Offset
      for (; Idx < Offset / EltSize; ++Idx) {
        DstElts.push_back(UnmergeSrc.getReg(Idx));
      }

      // Replace elements in Src with elements from InsertSrc
      if (InsertTy.getSizeInBits() > EltSize) {
        auto UnmergeInsertSrc = MIRBuilder.buildUnmerge(EltTy, InsertSrc);
        for (unsigned i = 0; Idx < (Offset + InsertSize) / EltSize;
             ++Idx, ++i) {
          DstElts.push_back(UnmergeInsertSrc.getReg(i));
        }
      } else {
        DstElts.push_back(InsertSrc);
        ++Idx;
      }

      // Remaining elements from Src after insert
      for (; Idx < DstTy.getNumElements(); ++Idx) {
        DstElts.push_back(UnmergeSrc.getReg(Idx));
      }

      MIRBuilder.buildMergeLikeInstr(Dst, DstElts);
      MI.eraseFromParent();
      return Legalized;
    }
  }

  if (InsertTy.isVector() ||
      (DstTy.isVector() && DstTy.getElementType() != InsertTy))
    return UnableToLegalize;

  const DataLayout &DL = MIRBuilder.getDataLayout();
  if ((DstTy.isPointer() &&
       DL.isNonIntegralAddressSpace(DstTy.getAddressSpace())) ||
      (InsertTy.isPointer() &&
       DL.isNonIntegralAddressSpace(InsertTy.getAddressSpace()))) {
    LLVM_DEBUG(dbgs() << "Not casting non-integral address space integer\n");
    return UnableToLegalize;
  }

  LLT IntDstTy = DstTy;

  if (!DstTy.isScalar()) {
    IntDstTy = LLT::scalar(DstTy.getSizeInBits());
    Src = MIRBuilder.buildCast(IntDstTy, Src).getReg(0);
  }

  if (!InsertTy.isScalar()) {
    const LLT IntInsertTy = LLT::scalar(InsertTy.getSizeInBits());
    InsertSrc = MIRBuilder.buildPtrToInt(IntInsertTy, InsertSrc).getReg(0);
  }

  Register ExtInsSrc = MIRBuilder.buildZExt(IntDstTy, InsertSrc).getReg(0);
  if (Offset != 0) {
    auto ShiftAmt = MIRBuilder.buildConstant(IntDstTy, Offset);
    ExtInsSrc = MIRBuilder.buildShl(IntDstTy, ExtInsSrc, ShiftAmt).getReg(0);
  }

  APInt MaskVal = APInt::getBitsSetWithWrap(
      DstTy.getSizeInBits(), Offset + InsertTy.getSizeInBits(), Offset);

  auto Mask = MIRBuilder.buildConstant(IntDstTy, MaskVal);
  auto MaskedSrc = MIRBuilder.buildAnd(IntDstTy, Src, Mask);
  auto Or = MIRBuilder.buildOr(IntDstTy, MaskedSrc, ExtInsSrc);

  MIRBuilder.buildCast(Dst, Or);
  MI.eraseFromParent();
  return Legalized;
}

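/// Lower G_SADDO/G_SSUBO by computing the plain add/sub and deriving the
/// overflow bit from sign comparisons of the result and the operands.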
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerSADDO_SSUBO(MachineInstr &MI) {
  Register Dst0 = MI.getOperand(0).getReg();
  Register Dst1 = MI.getOperand(1).getReg();
  Register LHS = MI.getOperand(2).getReg();
  Register RHS = MI.getOperand(3).getReg();
  const bool IsAdd = MI.getOpcode() == TargetOpcode::G_SADDO;

  LLT Ty = MRI.getType(Dst0);
  LLT BoolTy = MRI.getType(Dst1);

  if (IsAdd)
    MIRBuilder.buildAdd(Dst0, LHS, RHS);
  else
    MIRBuilder.buildSub(Dst0, LHS, RHS);

  // TODO: If SADDSAT/SSUBSAT is legal, compare results to detect overflow.

  auto Zero = MIRBuilder.buildConstant(Ty, 0);

  // For an addition, the result should be less than one of the operands (LHS)
  // if and only if the other operand (RHS) is negative, otherwise there will
  // be overflow.
  // For a subtraction, the result should be less than one of the operands
  // (LHS) if and only if the other operand (RHS) is (non-zero) positive,
  // otherwise there will be overflow.
  auto ResultLowerThanLHS =
      MIRBuilder.buildICmp(CmpInst::ICMP_SLT, BoolTy, Dst0, LHS);
  auto ConditionRHS = MIRBuilder.buildICmp(
      IsAdd ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGT, BoolTy, RHS, Zero);

  MIRBuilder.buildXor(Dst1, ConditionRHS, ResultLowerThanLHS);
  MI.eraseFromParent();
  return Legalized;
}

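/// Lower G_[SU]ADDSAT/G_[SU]SUBSAT by clamping the RHS with min/max into the
/// range that cannot overflow before emitting the plain add/sub.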
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerAddSubSatToMinMax(MachineInstr &MI) {
  Register Res = MI.getOperand(0).getReg();
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  LLT Ty = MRI.getType(Res);
  bool IsSigned;
  bool IsAdd;
  unsigned BaseOp;
  switch (MI.getOpcode()) {
  default:
    llvm_unreachable("unexpected addsat/subsat opcode");
  case TargetOpcode::G_UADDSAT:
    IsSigned = false;
    IsAdd = true;
    BaseOp = TargetOpcode::G_ADD;
    break;
  case TargetOpcode::G_SADDSAT:
    IsSigned = true;
    IsAdd = true;
    BaseOp = TargetOpcode::G_ADD;
    break;
  case TargetOpcode::G_USUBSAT:
    IsSigned = false;
    IsAdd = false;
    BaseOp = TargetOpcode::G_SUB;
    break;
  case TargetOpcode::G_SSUBSAT:
    IsSigned = true;
    IsAdd = false;
    BaseOp = TargetOpcode::G_SUB;
    break;
  }

  if (IsSigned) {
    // sadd.sat(a, b) ->
    //   hi = 0x7fffffff - smax(a, 0)
    //   lo = 0x80000000 - smin(a, 0)
    //   a + smin(smax(lo, b), hi)
    // ssub.sat(a, b) ->
    //   lo = smax(a, -1) - 0x7fffffff
    //   hi = smin(a, -1) - 0x80000000
    //   a - smin(smax(lo, b), hi)
    // TODO: AMDGPU can use a "median of 3" instruction here:
    //   a +/- med3(lo, b, hi)
    uint64_t NumBits = Ty.getScalarSizeInBits();
    auto MaxVal =
        MIRBuilder.buildConstant(Ty, APInt::getSignedMaxValue(NumBits));
    auto MinVal =
        MIRBuilder.buildConstant(Ty, APInt::getSignedMinValue(NumBits));
    MachineInstrBuilder Hi, Lo;
    if (IsAdd) {
      auto Zero = MIRBuilder.buildConstant(Ty, 0);
      Hi = MIRBuilder.buildSub(Ty, MaxVal, MIRBuilder.buildSMax(Ty, LHS, Zero));
      Lo = MIRBuilder.buildSub(Ty, MinVal, MIRBuilder.buildSMin(Ty, LHS, Zero));
    } else {
      auto NegOne = MIRBuilder.buildConstant(Ty, -1);
      Lo = MIRBuilder.buildSub(Ty, MIRBuilder.buildSMax(Ty, LHS, NegOne),
                               MaxVal);
      Hi = MIRBuilder.buildSub(Ty, MIRBuilder.buildSMin(Ty, LHS, NegOne),
                               MinVal);
    }
    auto RHSClamped =
        MIRBuilder.buildSMin(Ty, MIRBuilder.buildSMax(Ty, Lo, RHS), Hi);
    MIRBuilder.buildInstr(BaseOp, {Res}, {LHS, RHSClamped});
  } else {
    // uadd.sat(a, b) -> a + umin(~a, b)
    // usub.sat(a, b) -> a - umin(a, b)
    Register Not = IsAdd ? MIRBuilder.buildNot(Ty, LHS).getReg(0) : LHS;
    auto Min = MIRBuilder.buildUMin(Ty, Not, RHS);
    MIRBuilder.buildInstr(BaseOp, {Res}, {LHS, Min});
  }

  MI.eraseFromParent();
  return Legalized;
}

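/// Lower G_[SU]ADDSAT/G_[SU]SUBSAT via the corresponding overflow opcode,
/// selecting between the raw result and the saturation value based on the
/// overflow bit.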
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerAddSubSatToAddoSubo(MachineInstr &MI) {
  Register Res = MI.getOperand(0).getReg();
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  LLT Ty = MRI.getType(Res);
  LLT BoolTy = Ty.changeElementSize(1);
  bool IsSigned;
  bool IsAdd;
  unsigned OverflowOp;
  switch (MI.getOpcode()) {
  default:
    llvm_unreachable("unexpected addsat/subsat opcode");
  case TargetOpcode::G_UADDSAT:
    IsSigned = false;
    IsAdd = true;
    OverflowOp = TargetOpcode::G_UADDO;
    break;
  case TargetOpcode::G_SADDSAT:
    IsSigned = true;
    IsAdd = true;
    OverflowOp = TargetOpcode::G_SADDO;
    break;
  case TargetOpcode::G_USUBSAT:
    IsSigned = false;
    IsAdd = false;
    OverflowOp = TargetOpcode::G_USUBO;
    break;
  case TargetOpcode::G_SSUBSAT:
    IsSigned = true;
    IsAdd = false;
    OverflowOp = TargetOpcode::G_SSUBO;
    break;
  }

  auto OverflowRes =
      MIRBuilder.buildInstr(OverflowOp, {Ty, BoolTy}, {LHS, RHS});
  Register Tmp = OverflowRes.getReg(0);
  Register Ov = OverflowRes.getReg(1);
  MachineInstrBuilder Clamp;
  if (IsSigned) {
    // sadd.sat(a, b) ->
    //   {tmp, ov} = saddo(a, b)
    //   ov ? (tmp >>s 31) + 0x80000000 : tmp
    // ssub.sat(a, b) ->
    //   {tmp, ov} = ssubo(a, b)
    //   ov ? (tmp >>s 31) + 0x80000000 : tmp
    uint64_t NumBits = Ty.getScalarSizeInBits();
    auto ShiftAmount = MIRBuilder.buildConstant(Ty, NumBits - 1);
    auto Sign = MIRBuilder.buildAShr(Ty, Tmp, ShiftAmount);
    auto MinVal =
        MIRBuilder.buildConstant(Ty, APInt::getSignedMinValue(NumBits));
    Clamp = MIRBuilder.buildAdd(Ty, Sign, MinVal);
  } else {
    // uadd.sat(a, b) ->
    //   {tmp, ov} = uaddo(a, b)
    //   ov ? 0xffffffff : tmp
    // usub.sat(a, b) ->
    //   {tmp, ov} = usubo(a, b)
    //   ov ? 0 : tmp
    Clamp = MIRBuilder.buildConstant(Ty, IsAdd ? -1 : 0);
  }
  MIRBuilder.buildSelect(Res, Ov, Clamp, Tmp);

  MI.eraseFromParent();
  return Legalized;
}

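/// Lower G_SSHLSAT/G_USHLSAT by shifting the shifted result back and selecting
/// the saturation constant whenever the round trip does not reproduce the
/// operand.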
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerShlSat(MachineInstr &MI) {
  assert((MI.getOpcode() == TargetOpcode::G_SSHLSAT ||
          MI.getOpcode() == TargetOpcode::G_USHLSAT) &&
         "Expected shlsat opcode!");
  bool IsSigned = MI.getOpcode() == TargetOpcode::G_SSHLSAT;
  Register Res = MI.getOperand(0).getReg();
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  LLT Ty = MRI.getType(Res);
  LLT BoolTy = Ty.changeElementSize(1);

  unsigned BW = Ty.getScalarSizeInBits();
  auto Result = MIRBuilder.buildShl(Ty, LHS, RHS);
  auto Orig = IsSigned ? MIRBuilder.buildAShr(Ty, Result, RHS)
                       : MIRBuilder.buildLShr(Ty, Result, RHS);

  MachineInstrBuilder SatVal;
  if (IsSigned) {
    auto SatMin = MIRBuilder.buildConstant(Ty, APInt::getSignedMinValue(BW));
    auto SatMax = MIRBuilder.buildConstant(Ty, APInt::getSignedMaxValue(BW));
    auto Cmp = MIRBuilder.buildICmp(CmpInst::ICMP_SLT, BoolTy, LHS,
                                    MIRBuilder.buildConstant(Ty, 0));
    SatVal = MIRBuilder.buildSelect(Ty, Cmp, SatMin, SatMax);
  } else {
    SatVal = MIRBuilder.buildConstant(Ty, APInt::getMaxValue(BW));
  }
  auto Ov = MIRBuilder.buildICmp(CmpInst::ICMP_NE, BoolTy, LHS, Orig);
  MIRBuilder.buildSelect(Res, Ov, SatVal, Result);

  MI.eraseFromParent();
  return Legalized;
}

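/// Lower G_BSWAP by masking and shifting each byte of the source into its
/// mirrored position and ORing the partial results together.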
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerBswap(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  const LLT Ty = MRI.getType(Src);
  unsigned SizeInBytes = (Ty.getScalarSizeInBits() + 7) / 8;
  unsigned BaseShiftAmt = (SizeInBytes - 1) * 8;

  // Swap most and least significant byte, set remaining bytes in Res to zero.
  auto ShiftAmt = MIRBuilder.buildConstant(Ty, BaseShiftAmt);
  auto LSByteShiftedLeft = MIRBuilder.buildShl(Ty, Src, ShiftAmt);
  auto MSByteShiftedRight = MIRBuilder.buildLShr(Ty, Src, ShiftAmt);
  auto Res = MIRBuilder.buildOr(Ty, MSByteShiftedRight, LSByteShiftedLeft);

  // Set i-th high/low byte in Res to i-th low/high byte from Src.
  for (unsigned i = 1; i < SizeInBytes / 2; ++i) {
    // AND with Mask leaves byte i unchanged and sets remaining bytes to 0.
    APInt APMask = APInt(SizeInBytes * 8, 0xFF).shl(i * 8);
    auto Mask = MIRBuilder.buildConstant(Ty, APMask);
    auto ShiftAmt = MIRBuilder.buildConstant(Ty, BaseShiftAmt - 16 * i);
    // Low byte shifted left to place of high byte: (Src & Mask) << ShiftAmt.
    auto LoByte = MIRBuilder.buildAnd(Ty, Src, Mask);
    auto LoShiftedLeft = MIRBuilder.buildShl(Ty, LoByte, ShiftAmt);
    Res = MIRBuilder.buildOr(Ty, Res, LoShiftedLeft);
    // High byte shifted right to place of low byte: (Src >> ShiftAmt) & Mask.
    auto SrcShiftedRight = MIRBuilder.buildLShr(Ty, Src, ShiftAmt);
    auto HiShiftedRight = MIRBuilder.buildAnd(Ty, SrcShiftedRight, Mask);
    Res = MIRBuilder.buildOr(Ty, Res, HiShiftedRight);
  }
  Res.getInstr()->getOperand(0).setReg(Dst);

  MI.eraseFromParent();
  return Legalized;
}

// { (Src & Mask) >> N } | { (Src << N) & Mask }
static MachineInstrBuilder SwapN(unsigned N, DstOp Dst, MachineIRBuilder &B,
                                 MachineInstrBuilder Src, APInt Mask) {
  const LLT Ty = Dst.getLLTTy(*B.getMRI());
  MachineInstrBuilder C_N = B.buildConstant(Ty, N);
  MachineInstrBuilder MaskLoNTo0 = B.buildConstant(Ty, Mask);
  auto LHS = B.buildLShr(Ty, B.buildAnd(Ty, Src, MaskLoNTo0), C_N);
  auto RHS = B.buildAnd(Ty, B.buildShl(Ty, Src, C_N), MaskLoNTo0);
  return B.buildOr(Dst, LHS, RHS);
}

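/// Lower G_BITREVERSE as a byte swap followed by swapping the nibbles within
/// each byte, the bit pairs within each nibble, and finally adjacent bits.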
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerBitreverse(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  const LLT Ty = MRI.getType(Src);
  unsigned Size = Ty.getSizeInBits();

  MachineInstrBuilder BSWAP =
      MIRBuilder.buildInstr(TargetOpcode::G_BSWAP, {Ty}, {Src});

  // swap high and low 4 bits in 8 bit blocks 7654|3210 -> 3210|7654
  //    [(val & 0xF0F0F0F0) >> 4] | [(val & 0x0F0F0F0F) << 4]
  // -> [(val & 0xF0F0F0F0) >> 4] | [(val << 4) & 0xF0F0F0F0]
  MachineInstrBuilder Swap4 =
      SwapN(4, Ty, MIRBuilder, BSWAP, APInt::getSplat(Size, APInt(8, 0xF0)));

  // swap high and low 2 bits in 4 bit blocks 32|10 76|54 -> 10|32 54|76
  //    [(val & 0xCCCCCCCC) >> 2] | [(val & 0x33333333) << 2]
  // -> [(val & 0xCCCCCCCC) >> 2] | [(val << 2) & 0xCCCCCCCC]
  MachineInstrBuilder Swap2 =
      SwapN(2, Ty, MIRBuilder, Swap4, APInt::getSplat(Size, APInt(8, 0xCC)));

  // swap high and low 1 bit in 2 bit blocks 1|0 3|2 5|4 7|6 -> 0|1 2|3 4|5 6|7
  //    [(val & 0xAAAAAAAA) >> 1] | [(val & 0x55555555) << 1]
  // -> [(val & 0xAAAAAAAA) >> 1] | [(val << 1) & 0xAAAAAAAA]
  SwapN(1, Dst, MIRBuilder, Swap2, APInt::getSplat(Size, APInt(8, 0xAA)));

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerReadWriteRegister(MachineInstr &MI) {
  MachineFunction &MF = MIRBuilder.getMF();

  bool IsRead = MI.getOpcode() == TargetOpcode::G_READ_REGISTER;
  int NameOpIdx = IsRead ? 1 : 0;
  int ValRegIndex = IsRead ? 0 : 1;

  Register ValReg = MI.getOperand(ValRegIndex).getReg();
  const LLT Ty = MRI.getType(ValReg);
  const MDString *RegStr = cast<MDString>(
    cast<MDNode>(MI.getOperand(NameOpIdx).getMetadata())->getOperand(0));

  Register PhysReg = TLI.getRegisterByName(RegStr->getString().data(), Ty, MF);
  if (!PhysReg.isValid())
    return UnableToLegalize;

  if (IsRead)
    MIRBuilder.buildCopy(ValReg, PhysReg);
  else
    MIRBuilder.buildCopy(PhysReg, ValReg);

  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerSMULH_UMULH(MachineInstr &MI) {
  bool IsSigned = MI.getOpcode() == TargetOpcode::G_SMULH;
  unsigned ExtOp = IsSigned ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
  Register Result = MI.getOperand(0).getReg();
  LLT OrigTy = MRI.getType(Result);
  auto SizeInBits = OrigTy.getScalarSizeInBits();
  LLT WideTy = OrigTy.changeElementSize(SizeInBits * 2);

  auto LHS = MIRBuilder.buildInstr(ExtOp, {WideTy}, {MI.getOperand(1)});
  auto RHS = MIRBuilder.buildInstr(ExtOp, {WideTy}, {MI.getOperand(2)});
  auto Mul = MIRBuilder.buildMul(WideTy, LHS, RHS);
  unsigned ShiftOp = IsSigned ? TargetOpcode::G_ASHR : TargetOpcode::G_LSHR;

  auto ShiftAmt = MIRBuilder.buildConstant(WideTy, SizeInBits);
  auto Shifted = MIRBuilder.buildInstr(ShiftOp, {WideTy}, {Mul, ShiftAmt});
  MIRBuilder.buildTrunc(Result, Shifted);

  MI.eraseFromParent();
  return Legalized;
}

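/// Lower G_IS_FPCLASS by reinterpreting the value as an integer and testing
/// the requested classes with integer comparisons against the sign, exponent
/// and mantissa patterns of the source float type.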
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerISFPCLASS(MachineInstr &MI) {
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT SrcTy = MRI.getType(SrcReg);
  uint64_t Mask = MI.getOperand(2).getImm();

  if (Mask == 0) {
    MIRBuilder.buildConstant(DstReg, 0);
    MI.eraseFromParent();
    return Legalized;
  }
  if ((Mask & fcAllFlags) == fcAllFlags) {
    MIRBuilder.buildConstant(DstReg, 1);
    MI.eraseFromParent();
    return Legalized;
  }

  unsigned BitSize = SrcTy.getScalarSizeInBits();
  const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());

  LLT IntTy = LLT::scalar(BitSize);
  if (SrcTy.isVector())
    IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
  auto AsInt = MIRBuilder.buildCopy(IntTy, SrcReg);

  // Various masks.
  APInt SignBit = APInt::getSignMask(BitSize);
  APInt ValueMask = APInt::getSignedMaxValue(BitSize);     // All bits but sign.
  APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
  APInt ExpMask = Inf;
  APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
  APInt QNaNBitMask =
      APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
  APInt InversionMask = APInt::getAllOnes(DstTy.getScalarSizeInBits());

  auto SignBitC = MIRBuilder.buildConstant(IntTy, SignBit);
  auto ValueMaskC = MIRBuilder.buildConstant(IntTy, ValueMask);
  auto InfC = MIRBuilder.buildConstant(IntTy, Inf);
  auto ExpMaskC = MIRBuilder.buildConstant(IntTy, ExpMask);
  auto ZeroC = MIRBuilder.buildConstant(IntTy, 0);

  auto Abs = MIRBuilder.buildAnd(IntTy, AsInt, ValueMaskC);
  auto Sign =
      MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs);

  auto Res = MIRBuilder.buildConstant(DstTy, 0);
  const auto appendToRes = [&](MachineInstrBuilder ToAppend) {
    Res = MIRBuilder.buildOr(DstTy, Res, ToAppend);
  };

  // Tests that involve more than one class should be processed first.
  if ((Mask & fcFinite) == fcFinite) {
    // finite(V) ==> abs(V) u< exp_mask
    appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
                                     ExpMaskC));
    Mask &= ~fcFinite;
  } else if ((Mask & fcFinite) == fcPosFinite) {
    // finite(V) && V > 0 ==> V u< exp_mask
    appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, AsInt,
                                     ExpMaskC));
    Mask &= ~fcPosFinite;
  } else if ((Mask & fcFinite) == fcNegFinite) {
    // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
    auto Cmp = MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
                                    ExpMaskC);
    auto And = MIRBuilder.buildAnd(DstTy, Cmp, Sign);
    appendToRes(And);
    Mask &= ~fcNegFinite;
  }

  // Check for individual classes.
  if (unsigned PartialCheck = Mask & fcZero) {
    if (PartialCheck == fcPosZero)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, ZeroC));
    else if (PartialCheck == fcZero)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC));
    else // fcNegZero
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, SignBitC));
  }

  if (unsigned PartialCheck = Mask & fcInf) {
    if (PartialCheck == fcPosInf)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, InfC));
    else if (PartialCheck == fcInf)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC));
    else { // fcNegInf
      APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
      auto NegInfC = MIRBuilder.buildConstant(IntTy, NegInf);
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
                                       AsInt, NegInfC));
    }
  }

  if (unsigned PartialCheck = Mask & fcNan) {
    auto InfWithQnanBitC = MIRBuilder.buildConstant(IntTy, Inf | QNaNBitMask);
    if (PartialCheck == fcNan) {
      // isnan(V) ==> abs(V) u> int(inf)
      appendToRes(
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
    } else if (PartialCheck == fcQNan) {
      // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGE, DstTy, Abs,
                                       InfWithQnanBitC));
    } else { // fcSNan
      // issignaling(V) ==> abs(V) u> unsigned(Inf) &&
      //                    abs(V) u< (unsigned(Inf) | quiet_bit)
      auto IsNan =
          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC);
      auto IsNotQnan = MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy,
                                            Abs, InfWithQnanBitC);
      appendToRes(MIRBuilder.buildAnd(DstTy, IsNan, IsNotQnan));
    }
  }

  if (unsigned PartialCheck = Mask & fcSubnormal) {
    // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
    // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
    auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
    auto OneC = MIRBuilder.buildConstant(IntTy, 1);
    auto VMinusOne = MIRBuilder.buildSub(IntTy, V, OneC);
    auto SubnormalRes =
        MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne,
                             MIRBuilder.buildConstant(IntTy, AllOneMantissa));
    if (PartialCheck == fcNegSubnormal)
      SubnormalRes = MIRBuilder.buildAnd(DstTy, SubnormalRes, Sign);
    appendToRes(SubnormalRes);
  }

  if (unsigned PartialCheck = Mask & fcNormal) {
    // isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u<
    // (max_exp-1))
    APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
    auto ExpMinusOne = MIRBuilder.buildSub(
        IntTy, Abs, MIRBuilder.buildConstant(IntTy, ExpLSB));
    APInt MaxExpMinusOne = ExpMask - ExpLSB;
    auto NormalRes =
        MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne,
                             MIRBuilder.buildConstant(IntTy, MaxExpMinusOne));
    if (PartialCheck == fcNegNormal)
      NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, Sign);
    else if (PartialCheck == fcPosNormal) {
      auto PosSign = MIRBuilder.buildXor(
          DstTy, Sign, MIRBuilder.buildConstant(DstTy, InversionMask));
      NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, PosSign);
    }
    appendToRes(NormalRes);
  }

  MIRBuilder.buildCopy(DstReg, Res);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerSelect(MachineInstr &MI) {
  // Implement vector G_SELECT in terms of XOR, AND, OR.
  Register DstReg = MI.getOperand(0).getReg();
  Register MaskReg = MI.getOperand(1).getReg();
  Register Op1Reg = MI.getOperand(2).getReg();
  Register Op2Reg = MI.getOperand(3).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT MaskTy = MRI.getType(MaskReg);
  if (!DstTy.isVector())
    return UnableToLegalize;

  bool IsEltPtr = DstTy.getElementType().isPointer();
  if (IsEltPtr) {
    LLT ScalarPtrTy = LLT::scalar(DstTy.getScalarSizeInBits());
    LLT NewTy = DstTy.changeElementType(ScalarPtrTy);
    Op1Reg = MIRBuilder.buildPtrToInt(NewTy, Op1Reg).getReg(0);
    Op2Reg = MIRBuilder.buildPtrToInt(NewTy, Op2Reg).getReg(0);
    DstTy = NewTy;
  }

  if (MaskTy.isScalar()) {
    // Turn the scalar condition into a vector condition mask.

    Register MaskElt = MaskReg;

    // The condition was potentially zero extended before, but we want a sign
    // extended boolean.
    if (MaskTy != LLT::scalar(1))
      MaskElt = MIRBuilder.buildSExtInReg(MaskTy, MaskElt, 1).getReg(0);

    // Continue the sign extension (or truncate) to match the data type.
    MaskElt = MIRBuilder.buildSExtOrTrunc(DstTy.getElementType(),
                                          MaskElt).getReg(0);

    // Generate a vector splat idiom.
    auto ShufSplat = MIRBuilder.buildShuffleSplat(DstTy, MaskElt);
    MaskReg = ShufSplat.getReg(0);
    MaskTy = DstTy;
  }

  if (MaskTy.getSizeInBits() != DstTy.getSizeInBits()) {
    return UnableToLegalize;
  }

  auto NotMask = MIRBuilder.buildNot(MaskTy, MaskReg);
  auto NewOp1 = MIRBuilder.buildAnd(MaskTy, Op1Reg, MaskReg);
  auto NewOp2 = MIRBuilder.buildAnd(MaskTy, Op2Reg, NotMask);
  if (IsEltPtr) {
    auto Or = MIRBuilder.buildOr(DstTy, NewOp1, NewOp2);
    MIRBuilder.buildIntToPtr(DstReg, Or);
  } else {
    MIRBuilder.buildOr(DstReg, NewOp1, NewOp2);
  }
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerDIVREM(MachineInstr &MI) {
  // Split DIVREM into individual instructions.
  unsigned Opcode = MI.getOpcode();

  MIRBuilder.buildInstr(
      Opcode == TargetOpcode::G_SDIVREM ? TargetOpcode::G_SDIV
                                        : TargetOpcode::G_UDIV,
      {MI.getOperand(0).getReg()}, {MI.getOperand(2), MI.getOperand(3)});
  MIRBuilder.buildInstr(
      Opcode == TargetOpcode::G_SDIVREM ? TargetOpcode::G_SREM
                                        : TargetOpcode::G_UREM,
      {MI.getOperand(1).getReg()}, {MI.getOperand(2), MI.getOperand(3)});
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerAbsToAddXor(MachineInstr &MI) {
  // Expand %res = G_ABS %a into:
  // %v1 = G_ASHR %a, scalar_size-1
  // %v2 = G_ADD %a, %v1
  // %res = G_XOR %v2, %v1
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  Register OpReg = MI.getOperand(1).getReg();
  auto ShiftAmt =
      MIRBuilder.buildConstant(DstTy, DstTy.getScalarSizeInBits() - 1);
  auto Shift = MIRBuilder.buildAShr(DstTy, OpReg, ShiftAmt);
  auto Add = MIRBuilder.buildAdd(DstTy, OpReg, Shift);
  MIRBuilder.buildXor(MI.getOperand(0).getReg(), Add, Shift);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerAbsToMaxNeg(MachineInstr &MI) {
  // Expand %res = G_ABS %a into:
  // %v1 = G_CONSTANT 0
  // %v2 = G_SUB %v1, %a
  // %res = G_SMAX %a, %v2
  Register SrcReg = MI.getOperand(1).getReg();
  LLT Ty = MRI.getType(SrcReg);
  auto Zero = MIRBuilder.buildConstant(Ty, 0).getReg(0);
  auto Sub = MIRBuilder.buildSub(Ty, Zero, SrcReg).getReg(0);
  MIRBuilder.buildSMax(MI.getOperand(0), SrcReg, Sub);
  MI.eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerVectorReduction(MachineInstr &MI) {
  Register SrcReg = MI.getOperand(1).getReg();
  LLT SrcTy = MRI.getType(SrcReg);
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  // The source could be a scalar if the IR type was <1 x sN>.
  if (SrcTy.isScalar()) {
    if (DstTy.getSizeInBits() > SrcTy.getSizeInBits())
      return UnableToLegalize; // FIXME: handle extension.
    // This can be just a plain copy.
    Observer.changingInstr(MI);
    MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::COPY));
    Observer.changedInstr(MI);
    return Legalized;
  }
  return UnableToLegalize;
}

static bool shouldLowerMemFuncForSize(const MachineFunction &MF) {
  // On Darwin, -Os means optimize for size without hurting performance, so
  // only really optimize for size when -Oz (MinSize) is used.
  if (MF.getTarget().getTargetTriple().isOSDarwin())
    return MF.getFunction().hasMinSize();
  return MF.getFunction().hasOptSize();
}

// Returns a list of types to use for memory op lowering in MemOps. A partial
// port of findOptimalMemOpLowering in TargetLowering.
static bool findGISelOptimalMemOpLowering(std::vector<LLT> &MemOps,
                                          unsigned Limit, const MemOp &Op,
                                          unsigned DstAS, unsigned SrcAS,
                                          const AttributeList &FuncAttributes,
                                          const TargetLowering &TLI) {
  if (Op.isMemcpyWithFixedDstAlign() && Op.getSrcAlign() < Op.getDstAlign())
    return false;

  LLT Ty = TLI.getOptimalMemOpLLT(Op, FuncAttributes);

  if (Ty == LLT()) {
    // Use the largest scalar type whose alignment constraints are satisfied.
    // We only need to check DstAlign here as SrcAlign is always greater or
    // equal to DstAlign (or zero).
    Ty = LLT::scalar(64);
    if (Op.isFixedDstAlign())
      while (Op.getDstAlign() < Ty.getSizeInBytes() &&
             !TLI.allowsMisalignedMemoryAccesses(Ty, DstAS, Op.getDstAlign()))
        Ty = LLT::scalar(Ty.getSizeInBytes());
    assert(Ty.getSizeInBits() > 0 && "Could not find valid type");
    // FIXME: check for the largest legal type we can load/store to.
  }

  unsigned NumMemOps = 0;
  uint64_t Size = Op.size();
  while (Size) {
    unsigned TySize = Ty.getSizeInBytes();
    while (TySize > Size) {
      // For now, only use non-vector loads / stores for the left-over pieces.
      LLT NewTy = Ty;
      // FIXME: check for mem op safety and legality of the types. Not all of
      // SDAGisms map cleanly to GISel concepts.
      if (NewTy.isVector())
        NewTy = NewTy.getSizeInBits() > 64 ? LLT::scalar(64) : LLT::scalar(32);
      NewTy = LLT::scalar(PowerOf2Floor(NewTy.getSizeInBits() - 1));
      unsigned NewTySize = NewTy.getSizeInBytes();
      assert(NewTySize > 0 && "Could not find appropriate type");

      // If the new LLT cannot cover all of the remaining bits, then consider
      // issuing a (or a pair of) unaligned and overlapping load / store.
      unsigned Fast;
      // Need to get a VT equivalent for allowsMisalignedMemoryAccesses().
      MVT VT = getMVTForLLT(Ty);
      if (NumMemOps && Op.allowOverlap() && NewTySize < Size &&
          TLI.allowsMisalignedMemoryAccesses(
              VT, DstAS, Op.isFixedDstAlign() ? Op.getDstAlign() : Align(1),
              MachineMemOperand::MONone, &Fast) &&
          Fast)
        TySize = Size;
      else {
        Ty = NewTy;
        TySize = NewTySize;
      }
    }

    if (++NumMemOps > Limit)
      return false;

    MemOps.push_back(Ty);
    Size -= TySize;
  }

  return true;
}

static Type *getTypeForLLT(LLT Ty, LLVMContext &C) {
  if (Ty.isVector())
    return FixedVectorType::get(IntegerType::get(C, Ty.getScalarSizeInBits()),
                                Ty.getNumElements());
  return IntegerType::get(C, Ty.getSizeInBits());
}

// Get a vectorized representation of the memset value operand, GISel edition.
static Register getMemsetValue(Register Val, LLT Ty, MachineIRBuilder &MIB) {
  MachineRegisterInfo &MRI = *MIB.getMRI();
  unsigned NumBits = Ty.getScalarSizeInBits();
  auto ValVRegAndVal = getIConstantVRegValWithLookThrough(Val, MRI);
  if (!Ty.isVector() && ValVRegAndVal) {
    APInt Scalar = ValVRegAndVal->Value.trunc(8);
    APInt SplatVal = APInt::getSplat(NumBits, Scalar);
    return MIB.buildConstant(Ty, SplatVal).getReg(0);
  }

  // Extend the byte value to the larger type, and then multiply by a magic
  // value 0x010101... in order to replicate it across every byte.
  // Unless it's zero, in which case just emit a larger G_CONSTANT 0.
  if (ValVRegAndVal && ValVRegAndVal->Value == 0) {
    return MIB.buildConstant(Ty, 0).getReg(0);
  }

  LLT ExtType = Ty.getScalarType();
  auto ZExt = MIB.buildZExtOrTrunc(ExtType, Val);
  if (NumBits > 8) {
    APInt Magic = APInt::getSplat(NumBits, APInt(8, 0x01));
    auto MagicMI = MIB.buildConstant(ExtType, Magic);
    Val = MIB.buildMul(ExtType, ZExt, MagicMI).getReg(0);
  }

  // For vector types create a G_BUILD_VECTOR.
  if (Ty.isVector())
    Val = MIB.buildSplatVector(Ty, Val).getReg(0);

  return Val;
}

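/// Lower G_MEMSET of a known constant length to an inline sequence of stores,
/// splatting the byte value across the widest store types the target allows.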
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemset(MachineInstr &MI, Register Dst, Register Val,
                             uint64_t KnownLen, Align Alignment,
                             bool IsVolatile) {
  auto &MF = *MI.getParent()->getParent();
  const auto &TLI = *MF.getSubtarget().getTargetLowering();
  auto &DL = MF.getDataLayout();
  LLVMContext &C = MF.getFunction().getContext();

  assert(KnownLen != 0 && "Have a zero length memset length!");

  bool DstAlignCanChange = false;
  MachineFrameInfo &MFI = MF.getFrameInfo();
  bool OptSize = shouldLowerMemFuncForSize(MF);

  MachineInstr *FIDef = getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Dst, MRI);
  if (FIDef && !MFI.isFixedObjectIndex(FIDef->getOperand(1).getIndex()))
    DstAlignCanChange = true;

  unsigned Limit = TLI.getMaxStoresPerMemset(OptSize);
  std::vector<LLT> MemOps;

  const auto &DstMMO = **MI.memoperands_begin();
  MachinePointerInfo DstPtrInfo = DstMMO.getPointerInfo();

  auto ValVRegAndVal = getIConstantVRegValWithLookThrough(Val, MRI);
  bool IsZeroVal = ValVRegAndVal && ValVRegAndVal->Value == 0;

  if (!findGISelOptimalMemOpLowering(MemOps, Limit,
                                     MemOp::Set(KnownLen, DstAlignCanChange,
                                                Alignment,
                                                /*IsZeroMemset=*/IsZeroVal,
                                                /*IsVolatile=*/IsVolatile),
                                     DstPtrInfo.getAddrSpace(), ~0u,
                                     MF.getFunction().getAttributes(), TLI))
    return UnableToLegalize;

  if (DstAlignCanChange) {
    // Get an estimate of the type from the LLT.
    Type *IRTy = getTypeForLLT(MemOps[0], C);
    Align NewAlign = DL.getABITypeAlign(IRTy);
    if (NewAlign > Alignment) {
      Alignment = NewAlign;
      unsigned FI = FIDef->getOperand(1).getIndex();
      // Give the stack frame object a larger alignment if needed.
      if (MFI.getObjectAlign(FI) < Alignment)
        MFI.setObjectAlignment(FI, Alignment);
    }
  }

  MachineIRBuilder MIB(MI);
  // Find the largest store and generate the bit pattern for it.
  LLT LargestTy = MemOps[0];
  for (unsigned i = 1; i < MemOps.size(); i++)
    if (MemOps[i].getSizeInBits() > LargestTy.getSizeInBits())
      LargestTy = MemOps[i];

  // The memset stored value is always defined as an s8, so in order to make it
  // work with larger store types we need to repeat the bit pattern across the
  // wider type.
  Register MemSetValue = getMemsetValue(Val, LargestTy, MIB);

  if (!MemSetValue)
    return UnableToLegalize;

  // Generate the stores. For each store type in the list, we generate the
  // matching store of that type to the destination address.
  LLT PtrTy = MRI.getType(Dst);
  unsigned DstOff = 0;
  unsigned Size = KnownLen;
  for (unsigned I = 0; I < MemOps.size(); I++) {
    LLT Ty = MemOps[I];
    unsigned TySize = Ty.getSizeInBytes();
    if (TySize > Size) {
      // Issuing an unaligned load / store pair that overlaps with the previous
      // pair. Adjust the offset accordingly.
      assert(I == MemOps.size() - 1 && I != 0);
      DstOff -= TySize - Size;
    }

    // If this store is smaller than the largest store see whether we can get
    // the smaller value for free with a truncate.
    Register Value = MemSetValue;
    if (Ty.getSizeInBits() < LargestTy.getSizeInBits()) {
      MVT VT = getMVTForLLT(Ty);
      MVT LargestVT = getMVTForLLT(LargestTy);
      if (!LargestTy.isVector() && !Ty.isVector() &&
          TLI.isTruncateFree(LargestVT, VT))
        Value = MIB.buildTrunc(Ty, MemSetValue).getReg(0);
      else
        Value = getMemsetValue(Val, Ty, MIB);
      if (!Value)
        return UnableToLegalize;
    }

    auto *StoreMMO = MF.getMachineMemOperand(&DstMMO, DstOff, Ty);

    Register Ptr = Dst;
    if (DstOff != 0) {
      auto Offset =
          MIB.buildConstant(LLT::scalar(PtrTy.getSizeInBits()), DstOff);
      Ptr = MIB.buildPtrAdd(PtrTy, Dst, Offset).getReg(0);
    }

    MIB.buildStore(Value, Ptr, *StoreMMO);
    DstOff += Ty.getSizeInBytes();
    Size -= TySize;
  }

  MI.eraseFromParent();
  return Legalized;
}

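/// Lower G_MEMCPY_INLINE to loads and stores. The length must currently be a
/// known constant; no store-count limit is applied.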
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemcpyInline(MachineInstr &MI) {
  assert(MI.getOpcode() == TargetOpcode::G_MEMCPY_INLINE);

  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  Register Len = MI.getOperand(2).getReg();

  const auto *MMOIt = MI.memoperands_begin();
  const MachineMemOperand *MemOp = *MMOIt;
  bool IsVolatile = MemOp->isVolatile();

  // See if this is a constant length copy
  auto LenVRegAndVal = getIConstantVRegValWithLookThrough(Len, MRI);
  // FIXME: support dynamically sized G_MEMCPY_INLINE
  assert(LenVRegAndVal &&
         "inline memcpy with dynamic size is not yet supported");
  uint64_t KnownLen = LenVRegAndVal->Value.getZExtValue();
  if (KnownLen == 0) {
    MI.eraseFromParent();
    return Legalized;
  }

  const auto &DstMMO = **MI.memoperands_begin();
  const auto &SrcMMO = **std::next(MI.memoperands_begin());
  Align DstAlign = DstMMO.getBaseAlign();
  Align SrcAlign = SrcMMO.getBaseAlign();

  return lowerMemcpyInline(MI, Dst, Src, KnownLen, DstAlign, SrcAlign,
                           IsVolatile);
}

LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemcpyInline(MachineInstr &MI, Register Dst, Register Src,
                                   uint64_t KnownLen, Align DstAlign,
                                   Align SrcAlign, bool IsVolatile) {
  assert(MI.getOpcode() == TargetOpcode::G_MEMCPY_INLINE);
  return lowerMemcpy(MI, Dst, Src, KnownLen,
                     std::numeric_limits<uint64_t>::max(), DstAlign, SrcAlign,
                     IsVolatile);
}

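/// Lower G_MEMCPY of a known constant length to an inline sequence of load and
/// store pairs, using the types returned by findGISelOptimalMemOpLowering, as
/// long as the copy fits within \p Limit memory operations.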
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemcpy(MachineInstr &MI, Register Dst, Register Src,
                             uint64_t KnownLen, uint64_t Limit, Align DstAlign,
                             Align SrcAlign, bool IsVolatile) {
  auto &MF = *MI.getParent()->getParent();
  const auto &TLI = *MF.getSubtarget().getTargetLowering();
  auto &DL = MF.getDataLayout();
  LLVMContext &C = MF.getFunction().getContext();

  assert(KnownLen != 0 && "Have a zero length memcpy length!");

  bool DstAlignCanChange = false;
  MachineFrameInfo &MFI = MF.getFrameInfo();
  Align Alignment = std::min(DstAlign, SrcAlign);

  MachineInstr *FIDef = getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Dst, MRI);
  if (FIDef && !MFI.isFixedObjectIndex(FIDef->getOperand(1).getIndex()))
    DstAlignCanChange = true;

  // FIXME: infer better src pointer alignment like SelectionDAG does here.
  // FIXME: also use the equivalent of isMemSrcFromConstant and alwaysinlining
  // if the memcpy is in a tail call position.

  std::vector<LLT> MemOps;

  const auto &DstMMO = **MI.memoperands_begin();
  const auto &SrcMMO = **std::next(MI.memoperands_begin());
  MachinePointerInfo DstPtrInfo = DstMMO.getPointerInfo();
  MachinePointerInfo SrcPtrInfo = SrcMMO.getPointerInfo();

  if (!findGISelOptimalMemOpLowering(
          MemOps, Limit,
          MemOp::Copy(KnownLen, DstAlignCanChange, Alignment, SrcAlign,
                      IsVolatile),
          DstPtrInfo.getAddrSpace(), SrcPtrInfo.getAddrSpace(),
          MF.getFunction().getAttributes(), TLI))
    return UnableToLegalize;

  if (DstAlignCanChange) {
    // Get an estimate of the type from the LLT.
    Type *IRTy = getTypeForLLT(MemOps[0], C);
    Align NewAlign = DL.getABITypeAlign(IRTy);

    // Don't promote to an alignment that would require dynamic stack
    // realignment.
    const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
    if (!TRI->hasStackRealignment(MF))
      while (NewAlign > Alignment && DL.exceedsNaturalStackAlignment(NewAlign))
        NewAlign = NewAlign.previous();

    if (NewAlign > Alignment) {
      Alignment = NewAlign;
      unsigned FI = FIDef->getOperand(1).getIndex();
      // Give the stack frame object a larger alignment if needed.
      if (MFI.getObjectAlign(FI) < Alignment)
        MFI.setObjectAlignment(FI, Alignment);
    }
  }

  LLVM_DEBUG(dbgs() << "Inlining memcpy: " << MI << " into loads & stores\n");

  MachineIRBuilder MIB(MI);
  // Now we need to emit a load/store pair for each of the types we've
  // collected. I.e. for each type, generate a load from the source pointer of
  // that type width, and then generate a corresponding store of that loaded
  // value to the dest buffer. This can result in a sequence of loads and
  // stores of mixed types, depending on what the target specifies as good
  // types to use.
  unsigned CurrOffset = 0;
  unsigned Size = KnownLen;
  for (auto CopyTy : MemOps) {
    // Issuing an unaligned load / store pair that overlaps with the previous
    // pair. Adjust the offset accordingly.
    if (CopyTy.getSizeInBytes() > Size)
      CurrOffset -= CopyTy.getSizeInBytes() - Size;

    // Construct MMOs for the accesses.
    auto *LoadMMO =
        MF.getMachineMemOperand(&SrcMMO, CurrOffset, CopyTy.getSizeInBytes());
    auto *StoreMMO =
        MF.getMachineMemOperand(&DstMMO, CurrOffset, CopyTy.getSizeInBytes());

    // Create the load.
    Register LoadPtr = Src;
    Register Offset;
    if (CurrOffset != 0) {
      LLT SrcTy = MRI.getType(Src);
      Offset = MIB.buildConstant(LLT::scalar(SrcTy.getSizeInBits()), CurrOffset)
                   .getReg(0);
      LoadPtr = MIB.buildPtrAdd(SrcTy, Src, Offset).getReg(0);
    }
    auto LdVal = MIB.buildLoad(CopyTy, LoadPtr, *LoadMMO);

    // Create the store.
    Register StorePtr = Dst;
    if (CurrOffset != 0) {
      LLT DstTy = MRI.getType(Dst);
      StorePtr = MIB.buildPtrAdd(DstTy, Dst, Offset).getReg(0);
    }
    MIB.buildStore(LdVal, StorePtr, *StoreMMO);
    CurrOffset += CopyTy.getSizeInBytes();
    Size -= CopyTy.getSizeInBytes();
  }

  MI.eraseFromParent();
  return Legalized;
}

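/// Lower G_MEMMOVE like G_MEMCPY, except that all loads are emitted before any
/// stores so that overlapping source and destination buffers copy correctly.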
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemmove(MachineInstr &MI, Register Dst, Register Src,
                              uint64_t KnownLen, Align DstAlign, Align SrcAlign,
                              bool IsVolatile) {
  auto &MF = *MI.getParent()->getParent();
  const auto &TLI = *MF.getSubtarget().getTargetLowering();
  auto &DL = MF.getDataLayout();
  LLVMContext &C = MF.getFunction().getContext();

  assert(KnownLen != 0 && "Have a zero length memmove length!");

  bool DstAlignCanChange = false;
  MachineFrameInfo &MFI = MF.getFrameInfo();
  bool OptSize = shouldLowerMemFuncForSize(MF);
  Align Alignment = std::min(DstAlign, SrcAlign);

  MachineInstr *FIDef = getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Dst, MRI);
  if (FIDef && !MFI.isFixedObjectIndex(FIDef->getOperand(1).getIndex()))
    DstAlignCanChange = true;

  unsigned Limit = TLI.getMaxStoresPerMemmove(OptSize);
  std::vector<LLT> MemOps;

  const auto &DstMMO = **MI.memoperands_begin();
  const auto &SrcMMO = **std::next(MI.memoperands_begin());
  MachinePointerInfo DstPtrInfo = DstMMO.getPointerInfo();
  MachinePointerInfo SrcPtrInfo = SrcMMO.getPointerInfo();

  // FIXME: SelectionDAG always passes false for 'AllowOverlap', apparently due
  // to a bug in its findOptimalMemOpLowering implementation. For now do the
  // same thing here.
  if (!findGISelOptimalMemOpLowering(
          MemOps, Limit,
          MemOp::Copy(KnownLen, DstAlignCanChange, Alignment, SrcAlign,
                      /*IsVolatile*/ true),
          DstPtrInfo.getAddrSpace(), SrcPtrInfo.getAddrSpace(),
          MF.getFunction().getAttributes(), TLI))
    return UnableToLegalize;

  if (DstAlignCanChange) {
    // Get an estimate of the type from the LLT.
    Type *IRTy = getTypeForLLT(MemOps[0], C);
    Align NewAlign = DL.getABITypeAlign(IRTy);

    // Don't promote to an alignment that would require dynamic stack
    // realignment.
    const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
    if (!TRI->hasStackRealignment(MF))
      while (NewAlign > Alignment && DL.exceedsNaturalStackAlignment(NewAlign))
        NewAlign = NewAlign.previous();

    if (NewAlign > Alignment) {
      Alignment = NewAlign;
      unsigned FI = FIDef->getOperand(1).getIndex();
      // Give the stack frame object a larger alignment if needed.
      if (MFI.getObjectAlign(FI) < Alignment)
        MFI.setObjectAlignment(FI, Alignment);
    }
  }

  LLVM_DEBUG(dbgs() << "Inlining memmove: " << MI << " into loads & stores\n");

  MachineIRBuilder MIB(MI);
  // Memmove requires that we perform the loads first before issuing the stores.
  // Apart from that, this loop is pretty much doing the same thing as the
  // memcpy codegen function.
  unsigned CurrOffset = 0;
  SmallVector<Register, 16> LoadVals;
  for (auto CopyTy : MemOps) {
    // Construct MMO for the load.
    auto *LoadMMO =
        MF.getMachineMemOperand(&SrcMMO, CurrOffset, CopyTy.getSizeInBytes());

    // Create the load.
    Register LoadPtr = Src;
    if (CurrOffset != 0) {
      LLT SrcTy = MRI.getType(Src);
      auto Offset =
          MIB.buildConstant(LLT::scalar(SrcTy.getSizeInBits()), CurrOffset);
      LoadPtr = MIB.buildPtrAdd(SrcTy, Src, Offset).getReg(0);
    }
    LoadVals.push_back(MIB.buildLoad(CopyTy, LoadPtr, *LoadMMO).getReg(0));
    CurrOffset += CopyTy.getSizeInBytes();
  }

  CurrOffset = 0;
  for (unsigned I = 0; I < MemOps.size(); ++I) {
    LLT CopyTy = MemOps[I];
    // Now store the values loaded.
    auto *StoreMMO =
        MF.getMachineMemOperand(&DstMMO, CurrOffset, CopyTy.getSizeInBytes());

    Register StorePtr = Dst;
    if (CurrOffset != 0) {
      LLT DstTy = MRI.getType(Dst);
      auto Offset =
          MIB.buildConstant(LLT::scalar(DstTy.getSizeInBits()), CurrOffset);
      StorePtr = MIB.buildPtrAdd(DstTy, Dst, Offset).getReg(0);
    }
    MIB.buildStore(LoadVals[I], StorePtr, *StoreMMO);
    CurrOffset += CopyTy.getSizeInBytes();
  }
  MI.eraseFromParent();
  return Legalized;
}

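/// Dispatch G_MEMCPY/G_MEMMOVE/G_MEMSET with a known constant length (and, if
/// \p MaxLen is nonzero, no longer than \p MaxLen) to the inline lowering
/// helpers above. Volatile accesses are not optimized.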
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerMemCpyFamily(MachineInstr &MI, unsigned MaxLen) {
  const unsigned Opc = MI.getOpcode();
  // This combine is fairly complex so it's not written with a separate
  // matcher function.
  assert((Opc == TargetOpcode::G_MEMCPY || Opc == TargetOpcode::G_MEMMOVE ||
          Opc == TargetOpcode::G_MEMSET) &&
         "Expected memcpy-like instruction");

  auto MMOIt = MI.memoperands_begin();
  const MachineMemOperand *MemOp = *MMOIt;

  Align DstAlign = MemOp->getBaseAlign();
  Align SrcAlign;
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  Register Len = MI.getOperand(2).getReg();

  if (Opc != TargetOpcode::G_MEMSET) {
    assert(MMOIt != MI.memoperands_end() && "Expected a second MMO on MI");
    MemOp = *(++MMOIt);
    SrcAlign = MemOp->getBaseAlign();
  }

  // See if this is a constant length copy
  auto LenVRegAndVal = getIConstantVRegValWithLookThrough(Len, MRI);
  if (!LenVRegAndVal)
    return UnableToLegalize;
  uint64_t KnownLen = LenVRegAndVal->Value.getZExtValue();

  if (KnownLen == 0) {
    MI.eraseFromParent();
    return Legalized;
  }

  bool IsVolatile = MemOp->isVolatile();
  if (Opc == TargetOpcode::G_MEMCPY_INLINE)
    return lowerMemcpyInline(MI, Dst, Src, KnownLen, DstAlign, SrcAlign,
                             IsVolatile);

  // Don't try to optimize volatile.
  if (IsVolatile)
    return UnableToLegalize;

  if (MaxLen && KnownLen > MaxLen)
    return UnableToLegalize;

  if (Opc == TargetOpcode::G_MEMCPY) {
    auto &MF = *MI.getParent()->getParent();
    const auto &TLI = *MF.getSubtarget().getTargetLowering();
    bool OptSize = shouldLowerMemFuncForSize(MF);
    uint64_t Limit = TLI.getMaxStoresPerMemcpy(OptSize);
    return lowerMemcpy(MI, Dst, Src, KnownLen, Limit, DstAlign, SrcAlign,
                       IsVolatile);
  }
  if (Opc == TargetOpcode::G_MEMMOVE)
    return lowerMemmove(MI, Dst, Src, KnownLen, DstAlign, SrcAlign, IsVolatile);
  if (Opc == TargetOpcode::G_MEMSET)
    return lowerMemset(MI, Dst, Src, KnownLen, DstAlign, IsVolatile);
  return UnableToLegalize;
}