1 //===-- llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h -----*- C++ -*-//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This file contains some helper functions which try to cleanup artifacts
9 // such as G_TRUNCs/G_[ZSA]EXTENDS that were created during legalization to make
// the types match. This file also contains some combines of merges that happen
// at the end of the legalization.
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_CODEGEN_GLOBALISEL_LEGALIZATIONARTIFACTCOMBINER_H
15 #define LLVM_CODEGEN_GLOBALISEL_LEGALIZATIONARTIFACTCOMBINER_H
16 
17 #include "llvm/ADT/SmallBitVector.h"
18 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
19 #include "llvm/CodeGen/GlobalISel/Legalizer.h"
20 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
21 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
22 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
23 #include "llvm/CodeGen/GlobalISel/Utils.h"
24 #include "llvm/CodeGen/MachineRegisterInfo.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "legalizer"
28 using namespace llvm::MIPatternMatch;
29 
30 namespace llvm {
31 class LegalizationArtifactCombiner {
32   MachineIRBuilder &Builder;
33   MachineRegisterInfo &MRI;
34   const LegalizerInfo &LI;
35 
isArtifactCast(unsigned Opc)36   static bool isArtifactCast(unsigned Opc) {
37     switch (Opc) {
38     case TargetOpcode::G_TRUNC:
39     case TargetOpcode::G_SEXT:
40     case TargetOpcode::G_ZEXT:
41     case TargetOpcode::G_ANYEXT:
42       return true;
43     default:
44       return false;
45     }
46   }
47 
48 public:
  /// \param B   Builder used to create replacement instructions; positioned at
  ///            each combined instruction before emitting.
  /// \param MRI The function's virtual register info, both queried and updated.
  /// \param LI  Legality oracle consulted before emitting replacement ops.
  LegalizationArtifactCombiner(MachineIRBuilder &B, MachineRegisterInfo &MRI,
                    const LegalizerInfo &LI)
      : Builder(B), MRI(MRI), LI(LI) {}
52 
  /// Try to combine away a G_ANYEXT artifact.
  ///
  /// Patterns handled, in order:
  ///   - aext(trunc x)    -> aext/copy/trunc x
  ///   - aext([asz]ext x) -> [asz]ext x
  ///   - aext(G_CONSTANT) -> wider G_CONSTANT, when legal for the result type
  ///   - aext(G_IMPLICIT_DEF), delegated to tryFoldImplicitDef.
  ///
  /// \param MI          The G_ANYEXT to combine.
  /// \param DeadInsts   Receives instructions made dead by the combine.
  /// \param UpdatedDefs Receives def registers whose users may themselves
  ///                    become combinable.
  /// \returns true if \p MI was combined away.
  bool tryCombineAnyExt(MachineInstr &MI,
                        SmallVectorImpl<MachineInstr *> &DeadInsts,
                        SmallVectorImpl<Register> &UpdatedDefs) {
    assert(MI.getOpcode() == TargetOpcode::G_ANYEXT);

    Builder.setInstrAndDebugLoc(MI);
    Register DstReg = MI.getOperand(0).getReg();
    Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());

    // aext(trunc x) - > aext/copy/trunc x
    Register TruncSrc;
    if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) {
      LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI;);
      Builder.buildAnyExtOrTrunc(DstReg, TruncSrc);
      UpdatedDefs.push_back(DstReg);
      markInstAndDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
      return true;
    }

    // aext([asz]ext x) -> [asz]ext x
    // The high bits of an anyext are unspecified, so whatever the inner
    // extension produced is acceptable; just widen that extension directly.
    Register ExtSrc;
    MachineInstr *ExtMI;
    if (mi_match(SrcReg, MRI,
                 m_all_of(m_MInstr(ExtMI), m_any_of(m_GAnyExt(m_Reg(ExtSrc)),
                                                    m_GSExt(m_Reg(ExtSrc)),
                                                    m_GZExt(m_Reg(ExtSrc)))))) {
      Builder.buildInstr(ExtMI->getOpcode(), {DstReg}, {ExtSrc});
      UpdatedDefs.push_back(DstReg);
      markInstAndDefDead(MI, *ExtMI, DeadInsts);
      return true;
    }

    // Try to fold aext(g_constant) when the larger constant type is legal.
    auto *SrcMI = MRI.getVRegDef(SrcReg);
    if (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT) {
      const LLT DstTy = MRI.getType(DstReg);
      if (isInstLegal({TargetOpcode::G_CONSTANT, {DstTy}})) {
        auto &CstVal = SrcMI->getOperand(1);
        // Sign-extension is a valid choice for anyext's unspecified high bits.
        Builder.buildConstant(
            DstReg, CstVal.getCImm()->getValue().sext(DstTy.getSizeInBits()));
        UpdatedDefs.push_back(DstReg);
        markInstAndDefDead(MI, *SrcMI, DeadInsts);
        return true;
      }
    }
    return tryFoldImplicitDef(MI, DeadInsts, UpdatedDefs);
  }
100 
  /// Try to combine away a G_ZEXT artifact.
  ///
  /// Patterns handled, in order:
  ///   - zext(trunc x)    -> and (aext/copy/trunc x), mask
  ///   - zext(sext x)     -> and (sext x), mask
  ///   - zext(zext x)     -> zext x (reusing \p MI in place)
  ///   - zext(G_CONSTANT) -> wider zero-extended G_CONSTANT, when legal
  ///   - zext(G_IMPLICIT_DEF), delegated to tryFoldImplicitDef.
  ///
  /// \param MI          The G_ZEXT to combine.
  /// \param DeadInsts   Receives instructions made dead by the combine.
  /// \param UpdatedDefs Receives def registers whose users may themselves
  ///                    become combinable.
  /// \param Observer    Notified when \p MI is modified in place.
  /// \returns true if \p MI was combined away or rewritten.
  bool tryCombineZExt(MachineInstr &MI,
                      SmallVectorImpl<MachineInstr *> &DeadInsts,
                      SmallVectorImpl<Register> &UpdatedDefs,
                      GISelObserverWrapper &Observer) {
    assert(MI.getOpcode() == TargetOpcode::G_ZEXT);

    Builder.setInstrAndDebugLoc(MI);
    Register DstReg = MI.getOperand(0).getReg();
    Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());

    // zext(trunc x) - > and (aext/copy/trunc x), mask
    // zext(sext x) -> and (sext x), mask
    Register TruncSrc;
    Register SextSrc;
    if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc))) ||
        mi_match(SrcReg, MRI, m_GSExt(m_Reg(SextSrc)))) {
      LLT DstTy = MRI.getType(DstReg);
      // Both the G_AND and the mask constant must be supported on DstTy.
      if (isInstUnsupported({TargetOpcode::G_AND, {DstTy}}) ||
          isConstantUnsupported(DstTy))
        return false;
      LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI;);
      LLT SrcTy = MRI.getType(SrcReg);
      // All-ones mask of the narrow source width, widened to DstTy, clears
      // the bits a real zext would have zeroed.
      APInt MaskVal = APInt::getAllOnesValue(SrcTy.getScalarSizeInBits());
      auto Mask = Builder.buildConstant(
        DstTy, MaskVal.zext(DstTy.getScalarSizeInBits()));
      auto Extended = SextSrc ? Builder.buildSExtOrTrunc(DstTy, SextSrc) :
                                Builder.buildAnyExtOrTrunc(DstTy, TruncSrc);
      Builder.buildAnd(DstReg, Extended, Mask);
      markInstAndDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
      return true;
    }

    // zext(zext x) -> (zext x)
    Register ZextSrc;
    if (mi_match(SrcReg, MRI, m_GZExt(m_Reg(ZextSrc)))) {
      LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI);
      // Rewrite MI in place to read the inner zext's source directly.
      Observer.changingInstr(MI);
      MI.getOperand(1).setReg(ZextSrc);
      Observer.changedInstr(MI);
      UpdatedDefs.push_back(DstReg);
      // MI itself survives, so only the inner zext's def is marked dead.
      markDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
      return true;
    }

    // Try to fold zext(g_constant) when the larger constant type is legal.
    auto *SrcMI = MRI.getVRegDef(SrcReg);
    if (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT) {
      const LLT DstTy = MRI.getType(DstReg);
      if (isInstLegal({TargetOpcode::G_CONSTANT, {DstTy}})) {
        auto &CstVal = SrcMI->getOperand(1);
        Builder.buildConstant(
            DstReg, CstVal.getCImm()->getValue().zext(DstTy.getSizeInBits()));
        UpdatedDefs.push_back(DstReg);
        markInstAndDefDead(MI, *SrcMI, DeadInsts);
        return true;
      }
    }
    return tryFoldImplicitDef(MI, DeadInsts, UpdatedDefs);
  }
160 
  /// Try to combine away a G_SEXT artifact.
  ///
  /// Patterns handled, in order:
  ///   - sext(trunc x)    -> sext_inreg (aext/copy/trunc x), c
  ///   - sext(zext x)     -> zext x
  ///   - sext(sext x)     -> sext x
  ///   - sext(G_CONSTANT) -> wider sign-extended G_CONSTANT, when legal
  ///   - sext(G_IMPLICIT_DEF), delegated to tryFoldImplicitDef.
  ///
  /// \param MI          The G_SEXT to combine.
  /// \param DeadInsts   Receives instructions made dead by the combine.
  /// \param UpdatedDefs Receives def registers whose users may themselves
  ///                    become combinable.
  /// \returns true if \p MI was combined away.
  bool tryCombineSExt(MachineInstr &MI,
                      SmallVectorImpl<MachineInstr *> &DeadInsts,
                      SmallVectorImpl<Register> &UpdatedDefs) {
    assert(MI.getOpcode() == TargetOpcode::G_SEXT);

    Builder.setInstrAndDebugLoc(MI);
    Register DstReg = MI.getOperand(0).getReg();
    Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());

    // sext(trunc x) - > (sext_inreg (aext/copy/trunc x), c)
    Register TruncSrc;
    if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) {
      LLT DstTy = MRI.getType(DstReg);
      if (isInstUnsupported({TargetOpcode::G_SEXT_INREG, {DstTy}}))
        return false;
      LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI;);
      LLT SrcTy = MRI.getType(SrcReg);
      // Sign-extend from the trunc's (narrow) result width.
      uint64_t SizeInBits = SrcTy.getScalarSizeInBits();
      Builder.buildInstr(
          TargetOpcode::G_SEXT_INREG, {DstReg},
          {Builder.buildAnyExtOrTrunc(DstTy, TruncSrc), SizeInBits});
      markInstAndDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
      return true;
    }

    // sext(zext x) -> (zext x)
    // sext(sext x) -> (sext x)
    // The inner extension already fixes the high bits (zeros resp. sign
    // copies), so sign-extending again is a no-op; widen the inner op.
    Register ExtSrc;
    MachineInstr *ExtMI;
    if (mi_match(SrcReg, MRI,
                 m_all_of(m_MInstr(ExtMI), m_any_of(m_GZExt(m_Reg(ExtSrc)),
                                                    m_GSExt(m_Reg(ExtSrc)))))) {
      LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI);
      Builder.buildInstr(ExtMI->getOpcode(), {DstReg}, {ExtSrc});
      UpdatedDefs.push_back(DstReg);
      markInstAndDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
      return true;
    }

    // Try to fold sext(g_constant) when the larger constant type is legal.
    auto *SrcMI = MRI.getVRegDef(SrcReg);
    if (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT) {
      const LLT DstTy = MRI.getType(DstReg);
      if (isInstLegal({TargetOpcode::G_CONSTANT, {DstTy}})) {
        auto &CstVal = SrcMI->getOperand(1);
        Builder.buildConstant(
            DstReg, CstVal.getCImm()->getValue().sext(DstTy.getSizeInBits()));
        UpdatedDefs.push_back(DstReg);
        markInstAndDefDead(MI, *SrcMI, DeadInsts);
        return true;
      }
    }

    return tryFoldImplicitDef(MI, DeadInsts, UpdatedDefs);
  }
216 
  /// Try to combine away a G_TRUNC artifact.
  ///
  /// Patterns handled, in order:
  ///   - trunc(G_CONSTANT)     -> narrower G_CONSTANT, when legal
  ///   - trunc(merge-like op)  -> trunc of first input, direct register
  ///                              replacement, or a smaller merge, depending
  ///                              on how the sizes relate (scalars only)
  ///   - trunc(trunc x)        -> trunc x
  ///
  /// \param MI          The G_TRUNC to combine.
  /// \param DeadInsts   Receives instructions made dead by the combine.
  /// \param UpdatedDefs Receives def registers whose users may themselves
  ///                    become combinable.
  /// \param Observer    Notified when user instructions are rewritten during
  ///                    register replacement.
  /// \returns true if \p MI was combined away.
  bool tryCombineTrunc(MachineInstr &MI,
                       SmallVectorImpl<MachineInstr *> &DeadInsts,
                       SmallVectorImpl<Register> &UpdatedDefs,
                       GISelObserverWrapper &Observer) {
    assert(MI.getOpcode() == TargetOpcode::G_TRUNC);

    Builder.setInstr(MI);
    Register DstReg = MI.getOperand(0).getReg();
    Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());

    // Try to fold trunc(g_constant) when the smaller constant type is legal.
    auto *SrcMI = MRI.getVRegDef(SrcReg);
    if (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT) {
      const LLT DstTy = MRI.getType(DstReg);
      if (isInstLegal({TargetOpcode::G_CONSTANT, {DstTy}})) {
        auto &CstVal = SrcMI->getOperand(1);
        Builder.buildConstant(
            DstReg, CstVal.getCImm()->getValue().trunc(DstTy.getSizeInBits()));
        UpdatedDefs.push_back(DstReg);
        markInstAndDefDead(MI, *SrcMI, DeadInsts);
        return true;
      }
    }

    // Try to fold trunc(merge) to directly use the source of the merge.
    // This gets rid of large, difficult to legalize, merges
    if (auto *SrcMerge = dyn_cast<GMerge>(SrcMI)) {
      const Register MergeSrcReg = SrcMerge->getSourceReg(0);
      const LLT MergeSrcTy = MRI.getType(MergeSrcReg);
      const LLT DstTy = MRI.getType(DstReg);

      // We can only fold if the types are scalar
      const unsigned DstSize = DstTy.getSizeInBits();
      const unsigned MergeSrcSize = MergeSrcTy.getSizeInBits();
      if (!DstTy.isScalar() || !MergeSrcTy.isScalar())
        return false;

      if (DstSize < MergeSrcSize) {
        // When the merge source is larger than the destination, we can just
        // truncate the merge source directly
        if (isInstUnsupported({TargetOpcode::G_TRUNC, {DstTy, MergeSrcTy}}))
          return false;

        LLVM_DEBUG(dbgs() << "Combining G_TRUNC(G_MERGE_VALUES) to G_TRUNC: "
                          << MI);

        Builder.buildTrunc(DstReg, MergeSrcReg);
        UpdatedDefs.push_back(DstReg);
      } else if (DstSize == MergeSrcSize) {
        // If the sizes match we can simply try to replace the register
        LLVM_DEBUG(
            dbgs() << "Replacing G_TRUNC(G_MERGE_VALUES) with merge input: "
                   << MI);
        replaceRegOrBuildCopy(DstReg, MergeSrcReg, MRI, Builder, UpdatedDefs,
                              Observer);
      } else if (DstSize % MergeSrcSize == 0) {
        // If the trunc size is a multiple of the merge source size we can use
        // a smaller merge instead
        if (isInstUnsupported(
                {TargetOpcode::G_MERGE_VALUES, {DstTy, MergeSrcTy}}))
          return false;

        LLVM_DEBUG(
            dbgs() << "Combining G_TRUNC(G_MERGE_VALUES) to G_MERGE_VALUES: "
                   << MI);

        // Re-merge only the leading inputs that the truncated result covers.
        const unsigned NumSrcs = DstSize / MergeSrcSize;
        assert(NumSrcs < SrcMI->getNumOperands() - 1 &&
               "trunc(merge) should require less inputs than merge");
        SmallVector<Register, 8> SrcRegs(NumSrcs);
        for (unsigned i = 0; i < NumSrcs; ++i)
          SrcRegs[i] = SrcMerge->getSourceReg(i);

        Builder.buildMerge(DstReg, SrcRegs);
        UpdatedDefs.push_back(DstReg);
      } else {
        // Unable to combine
        return false;
      }

      markInstAndDefDead(MI, *SrcMerge, DeadInsts);
      return true;
    }

    // trunc(trunc) -> trunc
    Register TruncSrc;
    if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) {
      // Always combine trunc(trunc) since the eventual resulting trunc must be
      // legal anyway as it must be legal for all outputs of the consumer type
      // set.
      LLVM_DEBUG(dbgs() << ".. Combine G_TRUNC(G_TRUNC): " << MI);

      Builder.buildTrunc(DstReg, TruncSrc);
      UpdatedDefs.push_back(DstReg);
      markInstAndDefDead(MI, *MRI.getVRegDef(TruncSrc), DeadInsts);
      return true;
    }

    return false;
  }
317 
  /// Try to fold G_[ASZ]EXT (G_IMPLICIT_DEF).
  ///
  /// G_ANYEXT of an implicit def is itself an implicit def; G_SEXT/G_ZEXT of
  /// one may be replaced by a zero constant (a legal choice for the undefined
  /// low bits that also satisfies the extension's high-bit contract).
  ///
  /// \param MI          An extension whose source may be an implicit def
  ///                    (possibly through copies).
  /// \param DeadInsts   Receives instructions made dead by the fold.
  /// \param UpdatedDefs Receives def registers whose users may themselves
  ///                    become combinable.
  /// \returns true if \p MI was folded away.
  bool tryFoldImplicitDef(MachineInstr &MI,
                          SmallVectorImpl<MachineInstr *> &DeadInsts,
                          SmallVectorImpl<Register> &UpdatedDefs) {
    unsigned Opcode = MI.getOpcode();
    assert(Opcode == TargetOpcode::G_ANYEXT || Opcode == TargetOpcode::G_ZEXT ||
           Opcode == TargetOpcode::G_SEXT);

    if (MachineInstr *DefMI = getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF,
                                           MI.getOperand(1).getReg(), MRI)) {
      Builder.setInstr(MI);
      Register DstReg = MI.getOperand(0).getReg();
      LLT DstTy = MRI.getType(DstReg);

      if (Opcode == TargetOpcode::G_ANYEXT) {
        // G_ANYEXT (G_IMPLICIT_DEF) -> G_IMPLICIT_DEF
        if (!isInstLegal({TargetOpcode::G_IMPLICIT_DEF, {DstTy}}))
          return false;
        LLVM_DEBUG(dbgs() << ".. Combine G_ANYEXT(G_IMPLICIT_DEF): " << MI;);
        Builder.buildInstr(TargetOpcode::G_IMPLICIT_DEF, {DstReg}, {});
        UpdatedDefs.push_back(DstReg);
      } else {
        // G_[SZ]EXT (G_IMPLICIT_DEF) -> G_CONSTANT 0 because the top
        // bits will be 0 for G_ZEXT and 0/1 for the G_SEXT.
        if (isConstantUnsupported(DstTy))
          return false;
        LLVM_DEBUG(dbgs() << ".. Combine G_[SZ]EXT(G_IMPLICIT_DEF): " << MI;);
        Builder.buildConstant(DstReg, 0);
        UpdatedDefs.push_back(DstReg);
      }

      markInstAndDefDead(MI, *DefMI, DeadInsts);
      return true;
    }
    return false;
  }
354 
  /// Try to fold a G_UNMERGE_VALUES \p MI whose source is produced by the
  /// artifact cast \p CastMI.
  ///
  /// Currently only G_TRUNC sources are handled (see the in-line IR examples
  /// below): either hoisting the unmerge above the trunc, or widening the
  /// unmerge to read directly from the trunc's source.
  ///
  /// \param MI          The G_UNMERGE_VALUES to fold.
  /// \param CastMI      The instruction defining \p MI's source.
  /// \param DeadInsts   Receives instructions made dead by the fold.
  /// \param UpdatedDefs Receives def registers whose users may themselves
  ///                    become combinable.
  /// \returns true if \p MI was folded away.
  bool tryFoldUnmergeCast(MachineInstr &MI, MachineInstr &CastMI,
                          SmallVectorImpl<MachineInstr *> &DeadInsts,
                          SmallVectorImpl<Register> &UpdatedDefs) {

    assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES);

    const unsigned CastOpc = CastMI.getOpcode();

    if (!isArtifactCast(CastOpc))
      return false;

    // An unmerge has NumOperands - 1 defs plus the single source operand.
    const unsigned NumDefs = MI.getNumOperands() - 1;

    const Register CastSrcReg = CastMI.getOperand(1).getReg();
    const LLT CastSrcTy = MRI.getType(CastSrcReg);
    const LLT DestTy = MRI.getType(MI.getOperand(0).getReg());
    const LLT SrcTy = MRI.getType(MI.getOperand(NumDefs).getReg());

    const unsigned CastSrcSize = CastSrcTy.getSizeInBits();
    const unsigned DestSize = DestTy.getSizeInBits();

    if (CastOpc == TargetOpcode::G_TRUNC) {
      if (SrcTy.isVector() && SrcTy.getScalarType() == DestTy.getScalarType()) {
        //  %1:_(<4 x s8>) = G_TRUNC %0(<4 x s32>)
        //  %2:_(s8), %3:_(s8), %4:_(s8), %5:_(s8) = G_UNMERGE_VALUES %1
        // =>
        //  %6:_(s32), %7:_(s32), %8:_(s32), %9:_(s32) = G_UNMERGE_VALUES %0
        //  %2:_(s8) = G_TRUNC %6
        //  %3:_(s8) = G_TRUNC %7
        //  %4:_(s8) = G_TRUNC %8
        //  %5:_(s8) = G_TRUNC %9

        // Each new unmerge result covers the (wide) elements that feed one
        // original def; scalar defs each take a single wide element.
        unsigned UnmergeNumElts =
            DestTy.isVector() ? CastSrcTy.getNumElements() / NumDefs : 1;
        LLT UnmergeTy = CastSrcTy.changeElementCount(
            ElementCount::getFixed(UnmergeNumElts));

        if (isInstUnsupported(
                {TargetOpcode::G_UNMERGE_VALUES, {UnmergeTy, CastSrcTy}}))
          return false;

        Builder.setInstr(MI);
        auto NewUnmerge = Builder.buildUnmerge(UnmergeTy, CastSrcReg);

        // Re-create each original def as a trunc of the matching wide piece.
        for (unsigned I = 0; I != NumDefs; ++I) {
          Register DefReg = MI.getOperand(I).getReg();
          UpdatedDefs.push_back(DefReg);
          Builder.buildTrunc(DefReg, NewUnmerge.getReg(I));
        }

        markInstAndDefDead(MI, CastMI, DeadInsts);
        return true;
      }

      if (CastSrcTy.isScalar() && SrcTy.isScalar() && !DestTy.isVector()) {
        //  %1:_(s16) = G_TRUNC %0(s32)
        //  %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %1
        // =>
        //  %2:_(s8), %3:_(s8), %4:_(s8), %5:_(s8) = G_UNMERGE_VALUES %0

        // Unmerge(trunc) can be combined if the trunc source size is a multiple
        // of the unmerge destination size
        if (CastSrcSize % DestSize != 0)
          return false;

        // Check if the new unmerge is supported
        if (isInstUnsupported(
                {TargetOpcode::G_UNMERGE_VALUES, {DestTy, CastSrcTy}}))
          return false;

        // Gather the original destination registers and create new ones for the
        // unused bits
        const unsigned NewNumDefs = CastSrcSize / DestSize;
        SmallVector<Register, 8> DstRegs(NewNumDefs);
        for (unsigned Idx = 0; Idx < NewNumDefs; ++Idx) {
          if (Idx < NumDefs)
            DstRegs[Idx] = MI.getOperand(Idx).getReg();
          else
            DstRegs[Idx] = MRI.createGenericVirtualRegister(DestTy);
        }

        // Build new unmerge
        Builder.setInstr(MI);
        Builder.buildUnmerge(DstRegs, CastSrcReg);
        UpdatedDefs.append(DstRegs.begin(), DstRegs.begin() + NewNumDefs);
        markInstAndDefDead(MI, CastMI, DeadInsts);
        return true;
      }
    }

    // TODO: support combines with other casts as well
    return false;
  }
448 
canFoldMergeOpcode(unsigned MergeOp,unsigned ConvertOp,LLT OpTy,LLT DestTy)449   static bool canFoldMergeOpcode(unsigned MergeOp, unsigned ConvertOp,
450                                  LLT OpTy, LLT DestTy) {
451     // Check if we found a definition that is like G_MERGE_VALUES.
452     switch (MergeOp) {
453     default:
454       return false;
455     case TargetOpcode::G_BUILD_VECTOR:
456     case TargetOpcode::G_MERGE_VALUES:
457       // The convert operation that we will need to insert is
458       // going to convert the input of that type of instruction (scalar)
459       // to the destination type (DestTy).
460       // The conversion needs to stay in the same domain (scalar to scalar
461       // and vector to vector), so if we were to allow to fold the merge
462       // we would need to insert some bitcasts.
463       // E.g.,
464       // <2 x s16> = build_vector s16, s16
465       // <2 x s32> = zext <2 x s16>
466       // <2 x s16>, <2 x s16> = unmerge <2 x s32>
467       //
468       // As is the folding would produce:
469       // <2 x s16> = zext s16  <-- scalar to vector
470       // <2 x s16> = zext s16  <-- scalar to vector
471       // Which is invalid.
472       // Instead we would want to generate:
473       // s32 = zext s16
474       // <2 x s16> = bitcast s32
475       // s32 = zext s16
476       // <2 x s16> = bitcast s32
477       //
478       // That is not done yet.
479       if (ConvertOp == 0)
480         return true;
481       return !DestTy.isVector() && OpTy.isVector();
482     case TargetOpcode::G_CONCAT_VECTORS: {
483       if (ConvertOp == 0)
484         return true;
485       if (!DestTy.isVector())
486         return false;
487 
488       const unsigned OpEltSize = OpTy.getElementType().getSizeInBits();
489 
490       // Don't handle scalarization with a cast that isn't in the same
491       // direction as the vector cast. This could be handled, but it would
492       // require more intermediate unmerges.
493       if (ConvertOp == TargetOpcode::G_TRUNC)
494         return DestTy.getSizeInBits() <= OpEltSize;
495       return DestTy.getSizeInBits() >= OpEltSize;
496     }
497     }
498   }
499 
500   /// Try to replace DstReg with SrcReg or build a COPY instruction
501   /// depending on the register constraints.
replaceRegOrBuildCopy(Register DstReg,Register SrcReg,MachineRegisterInfo & MRI,MachineIRBuilder & Builder,SmallVectorImpl<Register> & UpdatedDefs,GISelChangeObserver & Observer)502   static void replaceRegOrBuildCopy(Register DstReg, Register SrcReg,
503                                     MachineRegisterInfo &MRI,
504                                     MachineIRBuilder &Builder,
505                                     SmallVectorImpl<Register> &UpdatedDefs,
506                                     GISelChangeObserver &Observer) {
507     if (!llvm::canReplaceReg(DstReg, SrcReg, MRI)) {
508       Builder.buildCopy(DstReg, SrcReg);
509       UpdatedDefs.push_back(DstReg);
510       return;
511     }
512     SmallVector<MachineInstr *, 4> UseMIs;
513     // Get the users and notify the observer before replacing.
514     for (auto &UseMI : MRI.use_instructions(DstReg)) {
515       UseMIs.push_back(&UseMI);
516       Observer.changingInstr(UseMI);
517     }
518     // Replace the registers.
519     MRI.replaceRegWith(DstReg, SrcReg);
520     UpdatedDefs.push_back(SrcReg);
521     // Notify the observer that we changed the instructions.
522     for (auto *UseMI : UseMIs)
523       Observer.changedInstr(*UseMI);
524   }
525 
526   /// Return the operand index in \p MI that defines \p Def
getDefIndex(const MachineInstr & MI,Register SearchDef)527   static unsigned getDefIndex(const MachineInstr &MI, Register SearchDef) {
528     unsigned DefIdx = 0;
529     for (const MachineOperand &Def : MI.defs()) {
530       if (Def.getReg() == SearchDef)
531         break;
532       ++DefIdx;
533     }
534 
535     return DefIdx;
536   }
537 
  /// This class provides utilities for finding source registers of specific
  /// bit ranges in an artifact. The routines can look through the source
  /// registers if they're other artifacts to try to find a non-artifact source
  /// of a value.
  class ArtifactValueFinder {
    MachineRegisterInfo &MRI;     // Register info of the function being combined.
    MachineIRBuilder &MIB;        // Builder for any new instructions we synthesize.
    const LegalizerInfo &LI;      // Legality oracle for synthesized instructions.

  private:
    /// Given a concat_vector op \p Concat and a start bit and size, try to
    /// find the origin of the value defined by that start position and size.
    ///
    /// \returns A register if a value can be found, otherwise an empty
    /// Register.
    Register findValueFromConcat(GConcatVectors &Concat, unsigned StartBit,
                                 unsigned Size) {
      assert(Size > 0);

      // Find the source operand that provides the bits requested.
      Register Src1Reg = Concat.getSourceReg(0);
      unsigned SrcSize = MRI.getType(Src1Reg).getSizeInBits();

      // Operand index of the source that provides the start of the bit range.
      // (+1 because operand 0 is the concat's destination.)
      unsigned StartSrcIdx = (StartBit / SrcSize) + 1;
      // Offset into the source at which the bit range starts.
      unsigned InRegOffset = StartBit % SrcSize;
      // Check that the bits don't span multiple sources.
      // FIXME: we might be able return multiple sources? Or create an
      // appropriate concat to make it fit.
      if (InRegOffset + Size > SrcSize)
        return Register();

      // If the bits exactly cover a single source, then return the operand as
      // our value reg.
      Register SrcReg = Concat.getReg(StartSrcIdx);
      if (InRegOffset == 0 && Size == SrcSize)
        return SrcReg; // A source operand matches exactly.

      // Otherwise recurse into that source's definition for the sub-range.
      return findValueFromDef(SrcReg, InRegOffset, Size);
    }

    /// Given a build_vector op \p BV and a start bit and size, try to find
    /// the origin of the value defined by that start position and size.
    ///
    /// \returns A register if a value can be found, otherwise an empty
    /// Register.
    Register findValueFromBuildVector(GBuildVector &BV, unsigned StartBit,
                                      unsigned Size) {
      assert(Size > 0);

      // Find the source operand that provides the bits requested.
      Register Src1Reg = BV.getSourceReg(0);
      unsigned SrcSize = MRI.getType(Src1Reg).getSizeInBits();

      // Operand index of the source that provides the start of the bit range.
      // (+1 because operand 0 is the build_vector's destination.)
      unsigned StartSrcIdx = (StartBit / SrcSize) + 1;
      // Offset into the source at which the bit range starts.
      unsigned InRegOffset = StartBit % SrcSize;

      if (InRegOffset != 0)
        return Register(); // Give up, bits don't start at a scalar source.
      if (Size < SrcSize)
        return Register(); // Scalar source is too large for requested bits.

      // If the bits cover multiple sources evenly, then create a new
      // build_vector to synthesize the required size, if that's been requested.
      if (Size > SrcSize) {
        if (Size % SrcSize > 0)
          return Register(); // Isn't covered exactly by sources.

        unsigned NumSrcsUsed = Size / SrcSize;
        LLT SrcTy = MRI.getType(Src1Reg);
        LLT NewBVTy = LLT::fixed_vector(NumSrcsUsed, SrcTy);

        // Check if the resulting build vector would be legal.
        LegalizeActionStep ActionStep =
            LI.getAction({TargetOpcode::G_BUILD_VECTOR, {NewBVTy, SrcTy}});
        if (ActionStep.Action != LegalizeActions::Legal)
          return Register();

        SmallVector<Register> NewSrcs;
        for (unsigned SrcIdx = StartSrcIdx; SrcIdx < StartSrcIdx + NumSrcsUsed;
             ++SrcIdx)
          NewSrcs.push_back(BV.getReg(SrcIdx));
        MIB.setInstrAndDebugLoc(BV);
        return MIB.buildBuildVector(NewBVTy, NewSrcs).getReg(0);
      }
      // A single source is requested, just return it.
      return BV.getReg(StartSrcIdx);
    }

    /// Given a G_INSERT op \p MI and a start bit and size, try to find
    /// the origin of the value defined by that start position and size.
    ///
    /// \returns A register if a value can be found, otherwise an empty
    /// Register.
    Register findValueFromInsert(MachineInstr &MI, unsigned StartBit,
                                 unsigned Size) {
      assert(MI.getOpcode() == TargetOpcode::G_INSERT);
      assert(Size > 0);

      Register ContainerSrcReg = MI.getOperand(1).getReg();
      Register InsertedReg = MI.getOperand(2).getReg();
      LLT InsertedRegTy = MRI.getType(InsertedReg);
      unsigned InsertOffset = MI.getOperand(3).getImm();

      // There are 4 possible container/insertreg + requested bit-range layouts
      // that the instruction and query could be representing.
      // For: %_ = G_INSERT %CONTAINER, %INS, InsOff (abbrev. to 'IO')
      // and a start bit 'SB', with size S, giving an end bit 'EB', we could
      // have...
      // Scenario A:
      //   --------------------------
      //  |  INS    |  CONTAINER     |
      //   --------------------------
      //       |   |
      //       SB  EB
      //
      // Scenario B:
      //   --------------------------
      //  |  INS    |  CONTAINER     |
      //   --------------------------
      //                |    |
      //                SB   EB
      //
      // Scenario C:
      //   --------------------------
      //  |  CONTAINER    |  INS     |
      //   --------------------------
      //       |    |
      //       SB   EB
      //
      // Scenario D:
      //   --------------------------
      //  |  CONTAINER    |  INS     |
      //   --------------------------
      //                     |   |
      //                     SB  EB
      //
      // So therefore, A and D are requesting data from the INS operand, while
      // B and C are requesting from the container operand.

      unsigned InsertedEndBit = InsertOffset + InsertedRegTy.getSizeInBits();
      unsigned EndBit = StartBit + Size;
      unsigned NewStartBit;
      Register SrcRegToUse;
      // Range lies entirely outside the inserted region (scenarios B and C):
      // keep searching inside the container.
      if (EndBit <= InsertOffset || InsertedEndBit <= StartBit) {
        SrcRegToUse = ContainerSrcReg;
        NewStartBit = StartBit;
        return findValueFromDef(SrcRegToUse, NewStartBit, Size);
      }
      // Range lies entirely inside the inserted region (scenarios A and D):
      // rebase the query onto the inserted value.
      if (InsertOffset <= StartBit && EndBit <= InsertedEndBit) {
        SrcRegToUse = InsertedReg;
        NewStartBit = StartBit - InsertOffset;
        return findValueFromDef(SrcRegToUse, NewStartBit, Size);
      }
      // The bit range spans both the inserted and container regions.
      return Register();
    }

  public:
    ArtifactValueFinder(MachineRegisterInfo &Mri, MachineIRBuilder &Builder,
                        const LegalizerInfo &Info)
        : MRI(Mri), MIB(Builder), LI(Info) {}

    /// Try to find a source of the value defined in the def \p DefReg, starting
    /// at position \p StartBit with size \p Size.
    /// \returns an empty Register if no value could be found, or \p DefReg
    /// if that was the best we could do.
    Register findValueFromDef(Register DefReg, unsigned StartBit,
                              unsigned Size) {
      MachineInstr *Def = getDefIgnoringCopies(DefReg, MRI);
      // If the instruction has a single def, then simply delegate the search.
      // For unmerge however with multiple defs, we need to compute the offset
      // into the source of the unmerge.
      switch (Def->getOpcode()) {
      case TargetOpcode::G_CONCAT_VECTORS:
        return findValueFromConcat(cast<GConcatVectors>(*Def), StartBit, Size);
      case TargetOpcode::G_UNMERGE_VALUES: {
        // Sum the widths of the defs that precede DefReg to get its bit
        // offset within the unmerge's source.
        unsigned DefStartBit = 0;
        unsigned DefSize = MRI.getType(DefReg).getSizeInBits();
        for (const auto &MO : Def->defs()) {
          if (MO.getReg() == DefReg)
            break;
          DefStartBit += DefSize;
        }
        Register SrcReg = Def->getOperand(Def->getNumOperands() - 1).getReg();
        Register SrcOriginReg =
            findValueFromDef(SrcReg, StartBit + DefStartBit, Size);
        if (SrcOriginReg)
          return SrcOriginReg;
        // Failed to find a further value. If the StartBit and Size perfectly
        // covered the requested DefReg, return that since it's better than
        // nothing.
        if (StartBit == 0 && Size == DefSize)
          return DefReg;
        return Register();
      }
      case TargetOpcode::G_BUILD_VECTOR:
        return findValueFromBuildVector(cast<GBuildVector>(*Def), StartBit,
                                        Size);
      case TargetOpcode::G_INSERT:
        return findValueFromInsert(*Def, StartBit, Size);
      default:
        return Register();
      }
    }
  };
747 
  /// Try to fold a G_UNMERGE_VALUES into whatever produced its source:
  /// another unmerge, or a merge-like instruction (possibly behind an
  /// artifact cast), falling back to the ArtifactValueFinder.
  /// \returns true if \p MI was combined away; \p MI and any defs that die
  /// with it are queued in \p DeadInsts, and redefined vregs are recorded in
  /// \p UpdatedDefs for reprocessing.
  bool tryCombineUnmergeValues(GUnmerge &MI,
                               SmallVectorImpl<MachineInstr *> &DeadInsts,
                               SmallVectorImpl<Register> &UpdatedDefs,
                               GISelChangeObserver &Observer) {
    unsigned NumDefs = MI.getNumDefs();
    Register SrcReg = MI.getSourceReg();
    MachineInstr *SrcDef = getDefIgnoringCopies(SrcReg, MRI);
    if (!SrcDef)
      return false;

    LLT OpTy = MRI.getType(SrcReg);
    LLT DestTy = MRI.getType(MI.getReg(0));
    unsigned SrcDefIdx = getDefIndex(*SrcDef, SrcReg);

    Builder.setInstrAndDebugLoc(MI);

    // Fallback path: trace each result of the unmerge back to an existing
    // value of the same type. Only kills MI if *every* result was replaced.
    auto tryCombineViaValueFinder = [&]() {
      ArtifactValueFinder ValueFinder(MRI, Builder, LI);

      SmallBitVector DeadDefs(NumDefs);
      for (unsigned DefIdx = 0; DefIdx < NumDefs; ++DefIdx) {
        Register DefReg = MI.getReg(DefIdx);
        Register FoundVal =
            ValueFinder.findValueFromDef(DefReg, 0, DestTy.getSizeInBits());
        if (!FoundVal || FoundVal == DefReg)
          continue;
        if (MRI.getType(FoundVal) != DestTy)
          continue;

        replaceRegOrBuildCopy(DefReg, FoundVal, MRI, Builder, UpdatedDefs,
                              Observer);
        // We only want to replace the uses, not the def of the old reg.
        Observer.changingInstr(MI);
        MI.getOperand(DefIdx).setReg(DefReg);
        Observer.changedInstr(MI);
        DeadDefs[DefIdx] = true;
      }
      if (DeadDefs.all()) {
        markInstAndDefDead(MI, *SrcDef, DeadInsts, SrcDefIdx);
        return true;
      }
      return false;
    };

    if (auto *SrcUnmerge = dyn_cast<GUnmerge>(SrcDef)) {
      // Fold unmerge(unmerge) into one unmerge of the outer source:
      // %0:_(<4 x s16>) = G_FOO
      // %1:_(<2 x s16>), %2:_(<2 x s16>) = G_UNMERGE_VALUES %0
      // %3:_(s16), %4:_(s16) = G_UNMERGE_VALUES %1
      //
      // %3:_(s16), %4:_(s16), %5:_(s16), %6:_(s16) = G_UNMERGE_VALUES %0
      Register SrcUnmergeSrc = SrcUnmerge->getSourceReg();
      LLT SrcUnmergeSrcTy = MRI.getType(SrcUnmergeSrc);

      // If we need to decrease the number of vector elements in the result type
      // of an unmerge, this would involve the creation of an equivalent unmerge
      // to copy back to the original result registers.
      LegalizeActionStep ActionStep = LI.getAction(
          {TargetOpcode::G_UNMERGE_VALUES, {OpTy, SrcUnmergeSrcTy}});
      switch (ActionStep.Action) {
      case LegalizeActions::Lower:
      case LegalizeActions::Unsupported:
        break;
      case LegalizeActions::FewerElements:
      case LegalizeActions::NarrowScalar:
        // Splitting the source type (TypeIdx 1) would just create extra
        // copy-back unmerges; bail out in that case.
        if (ActionStep.TypeIdx == 1)
          return false;
        break;
      default:
        return tryCombineViaValueFinder();
      }

      auto NewUnmerge = Builder.buildUnmerge(DestTy, SrcUnmergeSrc);

      // TODO: Should we try to process out the other defs now? If the other
      // defs of the source unmerge are also unmerged, we end up with a separate
      // unmerge for each one.
      for (unsigned I = 0; I != NumDefs; ++I) {
        Register Def = MI.getReg(I);
        // Results of the new wide unmerge are laid out per source def, so
        // MI's defs map to the SrcDefIdx-th group of NumDefs results.
        replaceRegOrBuildCopy(Def, NewUnmerge.getReg(SrcDefIdx * NumDefs + I),
                              MRI, Builder, UpdatedDefs, Observer);
      }

      markInstAndDefDead(MI, *SrcUnmerge, DeadInsts, SrcDefIdx);
      return true;
    }

    MachineInstr *MergeI = SrcDef;
    unsigned ConvertOp = 0;

    // Handle intermediate conversions
    unsigned SrcOp = SrcDef->getOpcode();
    if (isArtifactCast(SrcOp)) {
      ConvertOp = SrcOp;
      MergeI = getDefIgnoringCopies(SrcDef->getOperand(1).getReg(), MRI);
    }

    if (!MergeI || !canFoldMergeOpcode(MergeI->getOpcode(),
                                       ConvertOp, OpTy, DestTy)) {
      // We might have a chance to combine later by trying to combine
      // unmerge(cast) first
      if (tryFoldUnmergeCast(MI, *SrcDef, DeadInsts, UpdatedDefs))
        return true;

      // Try using the value finder.
      return tryCombineViaValueFinder();
    }

    const unsigned NumMergeRegs = MergeI->getNumOperands() - 1;

    if (NumMergeRegs < NumDefs) {
      // The merge has fewer (wider) sources than the unmerge has results:
      // split into one small unmerge per merge source.
      if (NumDefs % NumMergeRegs != 0)
        return false;

      Builder.setInstr(MI);
      // Transform to UNMERGEs, for example
      //   %1 = G_MERGE_VALUES %4, %5
      //   %9, %10, %11, %12 = G_UNMERGE_VALUES %1
      // to
      //   %9, %10 = G_UNMERGE_VALUES %4
      //   %11, %12 = G_UNMERGE_VALUES %5

      const unsigned NewNumDefs = NumDefs / NumMergeRegs;
      for (unsigned Idx = 0; Idx < NumMergeRegs; ++Idx) {
        SmallVector<Register, 8> DstRegs;
        for (unsigned j = 0, DefIdx = Idx * NewNumDefs; j < NewNumDefs;
             ++j, ++DefIdx)
          DstRegs.push_back(MI.getReg(DefIdx));

        if (ConvertOp) {
          LLT MergeSrcTy = MRI.getType(MergeI->getOperand(1).getReg());

          // This is a vector that is being split and casted. Extract to the
          // element type, and do the conversion on the scalars (or smaller
          // vectors).
          LLT MergeEltTy = MergeSrcTy.divide(NewNumDefs);

          // Handle split to smaller vectors, with conversions.
          // %2(<8 x s8>) = G_CONCAT_VECTORS %0(<4 x s8>), %1(<4 x s8>)
          // %3(<8 x s16>) = G_SEXT %2
          // %4(<2 x s16>), %5(<2 x s16>), %6(<2 x s16>), %7(<2 x s16>) = G_UNMERGE_VALUES %3
          //
          // =>
          //
          // %8(<2 x s8>), %9(<2 x s8>) = G_UNMERGE_VALUES %0
          // %10(<2 x s8>), %11(<2 x s8>) = G_UNMERGE_VALUES %1
          // %4(<2 x s16>) = G_SEXT %8
          // %5(<2 x s16>) = G_SEXT %9
          // %6(<2 x s16>) = G_SEXT %10
          // %7(<2 x s16>)= G_SEXT %11

          SmallVector<Register, 4> TmpRegs(NewNumDefs);
          for (unsigned k = 0; k < NewNumDefs; ++k)
            TmpRegs[k] = MRI.createGenericVirtualRegister(MergeEltTy);

          Builder.buildUnmerge(TmpRegs, MergeI->getOperand(Idx + 1).getReg());

          for (unsigned k = 0; k < NewNumDefs; ++k)
            Builder.buildInstr(ConvertOp, {DstRegs[k]}, {TmpRegs[k]});
        } else {
          Builder.buildUnmerge(DstRegs, MergeI->getOperand(Idx + 1).getReg());
        }
        UpdatedDefs.append(DstRegs.begin(), DstRegs.end());
      }

    } else if (NumMergeRegs > NumDefs) {
      // The merge has more (narrower) sources than the unmerge has results:
      // rebuild each unmerge result as a merge of a slice of the sources.
      if (ConvertOp != 0 || NumMergeRegs % NumDefs != 0)
        return false;

      Builder.setInstr(MI);
      // Transform to MERGEs
      //   %6 = G_MERGE_VALUES %17, %18, %19, %20
      //   %7, %8 = G_UNMERGE_VALUES %6
      // to
      //   %7 = G_MERGE_VALUES %17, %18
      //   %8 = G_MERGE_VALUES %19, %20

      const unsigned NumRegs = NumMergeRegs / NumDefs;
      for (unsigned DefIdx = 0; DefIdx < NumDefs; ++DefIdx) {
        SmallVector<Register, 8> Regs;
        for (unsigned j = 0, Idx = NumRegs * DefIdx + 1; j < NumRegs;
             ++j, ++Idx)
          Regs.push_back(MergeI->getOperand(Idx).getReg());

        Register DefReg = MI.getReg(DefIdx);
        Builder.buildMerge(DefReg, Regs);
        UpdatedDefs.push_back(DefReg);
      }

    } else {
      // One-to-one: each unmerge result corresponds to exactly one merge
      // source (modulo an optional conversion / bitcast).
      LLT MergeSrcTy = MRI.getType(MergeI->getOperand(1).getReg());

      if (!ConvertOp && DestTy != MergeSrcTy)
        ConvertOp = TargetOpcode::G_BITCAST;

      if (ConvertOp) {
        Builder.setInstr(MI);

        for (unsigned Idx = 0; Idx < NumDefs; ++Idx) {
          Register MergeSrc = MergeI->getOperand(Idx + 1).getReg();
          Register DefReg = MI.getOperand(Idx).getReg();
          Builder.buildInstr(ConvertOp, {DefReg}, {MergeSrc});
          UpdatedDefs.push_back(DefReg);
        }

        markInstAndDefDead(MI, *MergeI, DeadInsts);
        return true;
      }

      assert(DestTy == MergeSrcTy &&
             "Bitcast and the other kinds of conversions should "
             "have happened earlier");

      Builder.setInstr(MI);
      for (unsigned Idx = 0; Idx < NumDefs; ++Idx) {
        Register DstReg = MI.getOperand(Idx).getReg();
        Register SrcReg = MergeI->getOperand(Idx + 1).getReg();
        replaceRegOrBuildCopy(DstReg, SrcReg, MRI, Builder, UpdatedDefs,
                              Observer);
      }
    }

    markInstAndDefDead(MI, *MergeI, DeadInsts);
    return true;
  }
972 
tryCombineExtract(MachineInstr & MI,SmallVectorImpl<MachineInstr * > & DeadInsts,SmallVectorImpl<Register> & UpdatedDefs)973   bool tryCombineExtract(MachineInstr &MI,
974                          SmallVectorImpl<MachineInstr *> &DeadInsts,
975                          SmallVectorImpl<Register> &UpdatedDefs) {
976     assert(MI.getOpcode() == TargetOpcode::G_EXTRACT);
977 
978     // Try to use the source registers from a G_MERGE_VALUES
979     //
980     // %2 = G_MERGE_VALUES %0, %1
981     // %3 = G_EXTRACT %2, N
982     // =>
983     //
984     // for N < %2.getSizeInBits() / 2
985     //     %3 = G_EXTRACT %0, N
986     //
987     // for N >= %2.getSizeInBits() / 2
988     //    %3 = G_EXTRACT %1, (N - %0.getSizeInBits()
989 
990     Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());
991     MachineInstr *MergeI = MRI.getVRegDef(SrcReg);
992     if (!MergeI || !isa<GMergeLikeOp>(MergeI))
993       return false;
994 
995     Register DstReg = MI.getOperand(0).getReg();
996     LLT DstTy = MRI.getType(DstReg);
997     LLT SrcTy = MRI.getType(SrcReg);
998 
999     // TODO: Do we need to check if the resulting extract is supported?
1000     unsigned ExtractDstSize = DstTy.getSizeInBits();
1001     unsigned Offset = MI.getOperand(2).getImm();
1002     unsigned NumMergeSrcs = MergeI->getNumOperands() - 1;
1003     unsigned MergeSrcSize = SrcTy.getSizeInBits() / NumMergeSrcs;
1004     unsigned MergeSrcIdx = Offset / MergeSrcSize;
1005 
1006     // Compute the offset of the last bit the extract needs.
1007     unsigned EndMergeSrcIdx = (Offset + ExtractDstSize - 1) / MergeSrcSize;
1008 
1009     // Can't handle the case where the extract spans multiple inputs.
1010     if (MergeSrcIdx != EndMergeSrcIdx)
1011       return false;
1012 
1013     // TODO: We could modify MI in place in most cases.
1014     Builder.setInstr(MI);
1015     Builder.buildExtract(DstReg, MergeI->getOperand(MergeSrcIdx + 1).getReg(),
1016                          Offset - MergeSrcIdx * MergeSrcSize);
1017     UpdatedDefs.push_back(DstReg);
1018     markInstAndDefDead(MI, *MergeI, DeadInsts);
1019     return true;
1020   }
1021 
1022   /// Try to combine away MI.
1023   /// Returns true if it combined away the MI.
1024   /// Adds instructions that are dead as a result of the combine
1025   /// into DeadInsts, which can include MI.
  bool tryCombineInstruction(MachineInstr &MI,
                             SmallVectorImpl<MachineInstr *> &DeadInsts,
                             GISelObserverWrapper &WrapperObserver) {
    // This might be a recursive call, and we might have DeadInsts already
    // populated. To avoid bad things happening later with multiple vreg defs
    // etc, process the dead instructions now if any.
    if (!DeadInsts.empty())
      deleteMarkedDeadInsts(DeadInsts, WrapperObserver);

    // Put here every vreg that was redefined in such a way that it's at least
    // possible that one (or more) of its users (immediate or COPY-separated)
    // could become artifact combinable with the new definition (or the
    // instruction reachable from it through a chain of copies if any).
    SmallVector<Register, 4> UpdatedDefs;
    bool Changed = false;
    // Dispatch to the combine for this artifact kind.
    switch (MI.getOpcode()) {
    default:
      return false;
    case TargetOpcode::G_ANYEXT:
      Changed = tryCombineAnyExt(MI, DeadInsts, UpdatedDefs);
      break;
    case TargetOpcode::G_ZEXT:
      Changed = tryCombineZExt(MI, DeadInsts, UpdatedDefs, WrapperObserver);
      break;
    case TargetOpcode::G_SEXT:
      Changed = tryCombineSExt(MI, DeadInsts, UpdatedDefs);
      break;
    case TargetOpcode::G_UNMERGE_VALUES:
      Changed = tryCombineUnmergeValues(cast<GUnmerge>(MI), DeadInsts,
                                        UpdatedDefs, WrapperObserver);
      break;
    case TargetOpcode::G_MERGE_VALUES:
    case TargetOpcode::G_BUILD_VECTOR:
    case TargetOpcode::G_CONCAT_VECTORS:
      // If any of the users of this merge are an unmerge, then add them to the
      // artifact worklist in case there's folding that can be done looking up.
      for (MachineInstr &U : MRI.use_instructions(MI.getOperand(0).getReg())) {
        if (U.getOpcode() == TargetOpcode::G_UNMERGE_VALUES ||
            U.getOpcode() == TargetOpcode::G_TRUNC) {
          UpdatedDefs.push_back(MI.getOperand(0).getReg());
          break;
        }
      }
      break;
    case TargetOpcode::G_EXTRACT:
      Changed = tryCombineExtract(MI, DeadInsts, UpdatedDefs);
      break;
    case TargetOpcode::G_TRUNC:
      Changed = tryCombineTrunc(MI, DeadInsts, UpdatedDefs, WrapperObserver);
      if (!Changed) {
        // Try to combine truncates away even if they are legal. As all artifact
        // combines at the moment look only "up" the def-use chains, we achieve
        // that by throwing truncates' users (with look through copies) into the
        // ArtifactList again.
        UpdatedDefs.push_back(MI.getOperand(0).getReg());
      }
      break;
    }
    // If the main loop through the ArtifactList found at least one combinable
    // pair of artifacts, not only combine it away (as done above), but also
    // follow the def-use chain from there to combine everything that can be
    // combined within this def-use chain of artifacts.
    while (!UpdatedDefs.empty()) {
      Register NewDef = UpdatedDefs.pop_back_val();
      assert(NewDef.isVirtual() && "Unexpected redefinition of a physreg");
      for (MachineInstr &Use : MRI.use_instructions(NewDef)) {
        switch (Use.getOpcode()) {
        // Keep this list in sync with the list of all artifact combines.
        case TargetOpcode::G_ANYEXT:
        case TargetOpcode::G_ZEXT:
        case TargetOpcode::G_SEXT:
        case TargetOpcode::G_UNMERGE_VALUES:
        case TargetOpcode::G_EXTRACT:
        case TargetOpcode::G_TRUNC:
          // Adding Use to ArtifactList.
          WrapperObserver.changedInstr(Use);
          break;
        case TargetOpcode::COPY: {
          // A redefinition propagates through COPYs, so requeue the copy's
          // result (virtual registers only) for another look.
          Register Copy = Use.getOperand(0).getReg();
          if (Copy.isVirtual())
            UpdatedDefs.push_back(Copy);
          break;
        }
        default:
          // If we do not have an artifact combine for the opcode, there is no
          // point in adding it to the ArtifactList as nothing interesting will
          // be done to it anyway.
          break;
        }
      }
    }
    return Changed;
  }
1119 
1120 private:
getArtifactSrcReg(const MachineInstr & MI)1121   static Register getArtifactSrcReg(const MachineInstr &MI) {
1122     switch (MI.getOpcode()) {
1123     case TargetOpcode::COPY:
1124     case TargetOpcode::G_TRUNC:
1125     case TargetOpcode::G_ZEXT:
1126     case TargetOpcode::G_ANYEXT:
1127     case TargetOpcode::G_SEXT:
1128     case TargetOpcode::G_EXTRACT:
1129       return MI.getOperand(1).getReg();
1130     case TargetOpcode::G_UNMERGE_VALUES:
1131       return MI.getOperand(MI.getNumOperands() - 1).getReg();
1132     default:
1133       llvm_unreachable("Not a legalization artifact happen");
1134     }
1135   }
1136 
1137   /// Mark a def of one of MI's original operands, DefMI, as dead if changing MI
1138   /// (either by killing it or changing operands) results in DefMI being dead
1139   /// too. In-between COPYs or artifact-casts are also collected if they are
1140   /// dead.
1141   /// MI is not marked dead.
1142   void markDefDead(MachineInstr &MI, MachineInstr &DefMI,
1143                    SmallVectorImpl<MachineInstr *> &DeadInsts,
1144                    unsigned DefIdx = 0) {
1145     // Collect all the copy instructions that are made dead, due to deleting
1146     // this instruction. Collect all of them until the Trunc(DefMI).
1147     // Eg,
1148     // %1(s1) = G_TRUNC %0(s32)
1149     // %2(s1) = COPY %1(s1)
1150     // %3(s1) = COPY %2(s1)
1151     // %4(s32) = G_ANYEXT %3(s1)
1152     // In this case, we would have replaced %4 with a copy of %0,
1153     // and as a result, %3, %2, %1 are dead.
1154     MachineInstr *PrevMI = &MI;
1155     while (PrevMI != &DefMI) {
1156       Register PrevRegSrc = getArtifactSrcReg(*PrevMI);
1157 
1158       MachineInstr *TmpDef = MRI.getVRegDef(PrevRegSrc);
1159       if (MRI.hasOneUse(PrevRegSrc)) {
1160         if (TmpDef != &DefMI) {
1161           assert((TmpDef->getOpcode() == TargetOpcode::COPY ||
1162                   isArtifactCast(TmpDef->getOpcode())) &&
1163                  "Expecting copy or artifact cast here");
1164 
1165           DeadInsts.push_back(TmpDef);
1166         }
1167       } else
1168         break;
1169       PrevMI = TmpDef;
1170     }
1171 
1172     if (PrevMI == &DefMI) {
1173       unsigned I = 0;
1174       bool IsDead = true;
1175       for (MachineOperand &Def : DefMI.defs()) {
1176         if (I != DefIdx) {
1177           if (!MRI.use_empty(Def.getReg())) {
1178             IsDead = false;
1179             break;
1180           }
1181         } else {
1182           if (!MRI.hasOneUse(DefMI.getOperand(DefIdx).getReg()))
1183             break;
1184         }
1185 
1186         ++I;
1187       }
1188 
1189       if (IsDead)
1190         DeadInsts.push_back(&DefMI);
1191     }
1192   }
1193 
1194   /// Mark MI as dead. If a def of one of MI's operands, DefMI, would also be
1195   /// dead due to MI being killed, then mark DefMI as dead too.
1196   /// Some of the combines (extends(trunc)), try to walk through redundant
1197   /// copies in between the extends and the truncs, and this attempts to collect
1198   /// the in between copies if they're dead.
  void markInstAndDefDead(MachineInstr &MI, MachineInstr &DefMI,
                          SmallVectorImpl<MachineInstr *> &DeadInsts,
                          unsigned DefIdx = 0) {
    // MI itself is unconditionally dead; whether DefMI (and any chain of
    // copies/casts in between) dies too is decided by markDefDead.
    DeadInsts.push_back(&MI);
    markDefDead(MI, DefMI, DeadInsts, DefIdx);
  }
1205 
1206   /// Erase the dead instructions in the list and call the observer hooks.
1207   /// Normally the Legalizer will deal with erasing instructions that have been
1208   /// marked dead. However, for the trunc(ext(x)) cases we can end up trying to
1209   /// process instructions which have been marked dead, but otherwise break the
1210   /// MIR by introducing multiple vreg defs. For those cases, allow the combines
1211   /// to explicitly delete the instructions before we run into trouble.
deleteMarkedDeadInsts(SmallVectorImpl<MachineInstr * > & DeadInsts,GISelObserverWrapper & WrapperObserver)1212   void deleteMarkedDeadInsts(SmallVectorImpl<MachineInstr *> &DeadInsts,
1213                              GISelObserverWrapper &WrapperObserver) {
1214     for (auto *DeadMI : DeadInsts) {
1215       LLVM_DEBUG(dbgs() << *DeadMI << "Is dead, eagerly deleting\n");
1216       WrapperObserver.erasingInstr(*DeadMI);
1217       DeadMI->eraseFromParentAndMarkDBGValuesForRemoval();
1218     }
1219     DeadInsts.clear();
1220   }
1221 
1222   /// Checks if the target legalizer info has specified anything about the
1223   /// instruction, or if unsupported.
isInstUnsupported(const LegalityQuery & Query)1224   bool isInstUnsupported(const LegalityQuery &Query) const {
1225     using namespace LegalizeActions;
1226     auto Step = LI.getAction(Query);
1227     return Step.Action == Unsupported || Step.Action == NotFound;
1228   }
1229 
isInstLegal(const LegalityQuery & Query)1230   bool isInstLegal(const LegalityQuery &Query) const {
1231     return LI.getAction(Query).Action == LegalizeActions::Legal;
1232   }
1233 
isConstantUnsupported(LLT Ty)1234   bool isConstantUnsupported(LLT Ty) const {
1235     if (!Ty.isVector())
1236       return isInstUnsupported({TargetOpcode::G_CONSTANT, {Ty}});
1237 
1238     LLT EltTy = Ty.getElementType();
1239     return isInstUnsupported({TargetOpcode::G_CONSTANT, {EltTy}}) ||
1240            isInstUnsupported({TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}});
1241   }
1242 
1243   /// Looks through copy instructions and returns the actual
1244   /// source register.
lookThroughCopyInstrs(Register Reg)1245   Register lookThroughCopyInstrs(Register Reg) {
1246     Register TmpReg;
1247     while (mi_match(Reg, MRI, m_Copy(m_Reg(TmpReg)))) {
1248       if (MRI.getType(TmpReg).isValid())
1249         Reg = TmpReg;
1250       else
1251         break;
1252     }
1253     return Reg;
1254   }
1255 };
1256 
1257 } // namespace llvm
1258 
1259 #endif // LLVM_CODEGEN_GLOBALISEL_LEGALIZATIONARTIFACTCOMBINER_H
1260