1 //===-- llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h -----*- C++ -*-//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
// This file contains some helper functions which try to clean up artifacts
// such as G_TRUNCs/G_[ZSA]EXTENDs that were created during legalization to make
// the types match. This file also contains some combines of merges that happen
// at the end of the legalization.
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/GlobalISel/Legalizer.h"
15 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
16 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
17 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
18 #include "llvm/CodeGen/GlobalISel/Utils.h"
19 #include "llvm/CodeGen/MachineRegisterInfo.h"
20 #include "llvm/Support/Debug.h"
21 
22 #define DEBUG_TYPE "legalizer"
23 using namespace llvm::MIPatternMatch;
24 
25 namespace llvm {
26 class LegalizationArtifactCombiner {
27   MachineIRBuilder &Builder;
28   MachineRegisterInfo &MRI;
29   const LegalizerInfo &LI;
30 
31   static bool isArtifactCast(unsigned Opc) {
32     switch (Opc) {
33     case TargetOpcode::G_TRUNC:
34     case TargetOpcode::G_SEXT:
35     case TargetOpcode::G_ZEXT:
36     case TargetOpcode::G_ANYEXT:
37       return true;
38     default:
39       return false;
40     }
41   }
42 
43 public:
44   LegalizationArtifactCombiner(MachineIRBuilder &B, MachineRegisterInfo &MRI,
45                     const LegalizerInfo &LI)
46       : Builder(B), MRI(MRI), LI(LI) {}
47 
48   bool tryCombineAnyExt(MachineInstr &MI,
49                         SmallVectorImpl<MachineInstr *> &DeadInsts,
50                         SmallVectorImpl<Register> &UpdatedDefs) {
51     assert(MI.getOpcode() == TargetOpcode::G_ANYEXT);
52 
53     Builder.setInstrAndDebugLoc(MI);
54     Register DstReg = MI.getOperand(0).getReg();
55     Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());
56 
57     // aext(trunc x) - > aext/copy/trunc x
58     Register TruncSrc;
59     if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) {
60       LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI;);
61       Builder.buildAnyExtOrTrunc(DstReg, TruncSrc);
62       UpdatedDefs.push_back(DstReg);
63       markInstAndDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
64       return true;
65     }
66 
67     // aext([asz]ext x) -> [asz]ext x
68     Register ExtSrc;
69     MachineInstr *ExtMI;
70     if (mi_match(SrcReg, MRI,
71                  m_all_of(m_MInstr(ExtMI), m_any_of(m_GAnyExt(m_Reg(ExtSrc)),
72                                                     m_GSExt(m_Reg(ExtSrc)),
73                                                     m_GZExt(m_Reg(ExtSrc)))))) {
74       Builder.buildInstr(ExtMI->getOpcode(), {DstReg}, {ExtSrc});
75       UpdatedDefs.push_back(DstReg);
76       markInstAndDefDead(MI, *ExtMI, DeadInsts);
77       return true;
78     }
79 
80     // Try to fold aext(g_constant) when the larger constant type is legal.
81     // Can't use MIPattern because we don't have a specific constant in mind.
82     auto *SrcMI = MRI.getVRegDef(SrcReg);
83     if (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT) {
84       const LLT DstTy = MRI.getType(DstReg);
85       if (isInstLegal({TargetOpcode::G_CONSTANT, {DstTy}})) {
86         auto &CstVal = SrcMI->getOperand(1);
87         Builder.buildConstant(
88             DstReg, CstVal.getCImm()->getValue().sext(DstTy.getSizeInBits()));
89         UpdatedDefs.push_back(DstReg);
90         markInstAndDefDead(MI, *SrcMI, DeadInsts);
91         return true;
92       }
93     }
94     return tryFoldImplicitDef(MI, DeadInsts, UpdatedDefs);
95   }
96 
  /// Combine away a G_ZEXT artifact:
  ///   zext(trunc x)      -> and (aext/copy/trunc x), mask
  ///   zext(zext x)       -> zext x (rewritten in place)
  ///   zext(constant)     -> wider constant, if the wider constant is legal
  ///   zext(implicit_def) -> handled by tryFoldImplicitDef
  /// Returns true if \p MI was combined away. Instructions made dead by the
  /// combine are appended to \p DeadInsts; redefined vregs to \p UpdatedDefs.
  bool tryCombineZExt(MachineInstr &MI,
                      SmallVectorImpl<MachineInstr *> &DeadInsts,
                      SmallVectorImpl<Register> &UpdatedDefs,
                      GISelObserverWrapper &Observer) {
    assert(MI.getOpcode() == TargetOpcode::G_ZEXT);

    Builder.setInstrAndDebugLoc(MI);
    Register DstReg = MI.getOperand(0).getReg();
    Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());

    // zext(trunc x) - > and (aext/copy/trunc x), mask
    Register TruncSrc;
    if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) {
      LLT DstTy = MRI.getType(DstReg);
      // Both the G_AND and the mask constant must be available on the target
      // for this rewrite to be valid.
      if (isInstUnsupported({TargetOpcode::G_AND, {DstTy}}) ||
          isConstantUnsupported(DstTy))
        return false;
      LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI;);
      LLT SrcTy = MRI.getType(SrcReg);
      // The mask keeps exactly the bits that survived the original trunc,
      // zeroing the rest as a zext requires.
      APInt Mask = APInt::getAllOnesValue(SrcTy.getScalarSizeInBits());
      auto MIBMask = Builder.buildConstant(
        DstTy, Mask.zext(DstTy.getScalarSizeInBits()));
      Builder.buildAnd(DstReg, Builder.buildAnyExtOrTrunc(DstTy, TruncSrc),
                       MIBMask);
      markInstAndDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
      return true;
    }

    // zext(zext x) -> (zext x)
    Register ZextSrc;
    if (mi_match(SrcReg, MRI, m_GZExt(m_Reg(ZextSrc)))) {
      LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI);
      // Rewrite MI in place to extend directly from the inner source,
      // notifying the observer around the operand change.
      Observer.changingInstr(MI);
      MI.getOperand(1).setReg(ZextSrc);
      Observer.changedInstr(MI);
      UpdatedDefs.push_back(DstReg);
      // MI itself stays live (it was updated, not replaced), so only mark the
      // inner def chain dead — hence markDefDead, not markInstAndDefDead.
      markDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
      return true;
    }

    // Try to fold zext(g_constant) when the larger constant type is legal.
    // Can't use MIPattern because we don't have a specific constant in mind.
    auto *SrcMI = MRI.getVRegDef(SrcReg);
    if (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT) {
      const LLT DstTy = MRI.getType(DstReg);
      if (isInstLegal({TargetOpcode::G_CONSTANT, {DstTy}})) {
        auto &CstVal = SrcMI->getOperand(1);
        Builder.buildConstant(
            DstReg, CstVal.getCImm()->getValue().zext(DstTy.getSizeInBits()));
        UpdatedDefs.push_back(DstReg);
        markInstAndDefDead(MI, *SrcMI, DeadInsts);
        return true;
      }
    }
    return tryFoldImplicitDef(MI, DeadInsts, UpdatedDefs);
  }
153 
154   bool tryCombineSExt(MachineInstr &MI,
155                       SmallVectorImpl<MachineInstr *> &DeadInsts,
156                       SmallVectorImpl<Register> &UpdatedDefs) {
157     assert(MI.getOpcode() == TargetOpcode::G_SEXT);
158 
159     Builder.setInstrAndDebugLoc(MI);
160     Register DstReg = MI.getOperand(0).getReg();
161     Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());
162 
163     // sext(trunc x) - > (sext_inreg (aext/copy/trunc x), c)
164     Register TruncSrc;
165     if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) {
166       LLT DstTy = MRI.getType(DstReg);
167       if (isInstUnsupported({TargetOpcode::G_SEXT_INREG, {DstTy}}))
168         return false;
169       LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI;);
170       LLT SrcTy = MRI.getType(SrcReg);
171       uint64_t SizeInBits = SrcTy.getScalarSizeInBits();
172       Builder.buildInstr(
173           TargetOpcode::G_SEXT_INREG, {DstReg},
174           {Builder.buildAnyExtOrTrunc(DstTy, TruncSrc), SizeInBits});
175       markInstAndDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
176       return true;
177     }
178 
179     // sext(zext x) -> (zext x)
180     // sext(sext x) -> (sext x)
181     Register ExtSrc;
182     MachineInstr *ExtMI;
183     if (mi_match(SrcReg, MRI,
184                  m_all_of(m_MInstr(ExtMI), m_any_of(m_GZExt(m_Reg(ExtSrc)),
185                                                     m_GSExt(m_Reg(ExtSrc)))))) {
186       LLVM_DEBUG(dbgs() << ".. Combine MI: " << MI);
187       Builder.buildInstr(ExtMI->getOpcode(), {DstReg}, {ExtSrc});
188       UpdatedDefs.push_back(DstReg);
189       markInstAndDefDead(MI, *MRI.getVRegDef(SrcReg), DeadInsts);
190       return true;
191     }
192 
193     return tryFoldImplicitDef(MI, DeadInsts, UpdatedDefs);
194   }
195 
196   bool tryCombineTrunc(MachineInstr &MI,
197                        SmallVectorImpl<MachineInstr *> &DeadInsts,
198                        SmallVectorImpl<Register> &UpdatedDefs,
199                        GISelObserverWrapper &Observer) {
200     assert(MI.getOpcode() == TargetOpcode::G_TRUNC);
201 
202     Builder.setInstr(MI);
203     Register DstReg = MI.getOperand(0).getReg();
204     Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());
205 
206     // Try to fold trunc(g_constant) when the smaller constant type is legal.
207     // Can't use MIPattern because we don't have a specific constant in mind.
208     auto *SrcMI = MRI.getVRegDef(SrcReg);
209     if (SrcMI->getOpcode() == TargetOpcode::G_CONSTANT) {
210       const LLT DstTy = MRI.getType(DstReg);
211       if (isInstLegal({TargetOpcode::G_CONSTANT, {DstTy}})) {
212         auto &CstVal = SrcMI->getOperand(1);
213         Builder.buildConstant(
214             DstReg, CstVal.getCImm()->getValue().trunc(DstTy.getSizeInBits()));
215         UpdatedDefs.push_back(DstReg);
216         markInstAndDefDead(MI, *SrcMI, DeadInsts);
217         return true;
218       }
219     }
220 
221     // Try to fold trunc(merge) to directly use the source of the merge.
222     // This gets rid of large, difficult to legalize, merges
223     if (SrcMI->getOpcode() == TargetOpcode::G_MERGE_VALUES) {
224       const Register MergeSrcReg = SrcMI->getOperand(1).getReg();
225       const LLT MergeSrcTy = MRI.getType(MergeSrcReg);
226       const LLT DstTy = MRI.getType(DstReg);
227 
228       // We can only fold if the types are scalar
229       const unsigned DstSize = DstTy.getSizeInBits();
230       const unsigned MergeSrcSize = MergeSrcTy.getSizeInBits();
231       if (!DstTy.isScalar() || !MergeSrcTy.isScalar())
232         return false;
233 
234       if (DstSize < MergeSrcSize) {
235         // When the merge source is larger than the destination, we can just
236         // truncate the merge source directly
237         if (isInstUnsupported({TargetOpcode::G_TRUNC, {DstTy, MergeSrcTy}}))
238           return false;
239 
240         LLVM_DEBUG(dbgs() << "Combining G_TRUNC(G_MERGE_VALUES) to G_TRUNC: "
241                           << MI);
242 
243         Builder.buildTrunc(DstReg, MergeSrcReg);
244         UpdatedDefs.push_back(DstReg);
245       } else if (DstSize == MergeSrcSize) {
246         // If the sizes match we can simply try to replace the register
247         LLVM_DEBUG(
248             dbgs() << "Replacing G_TRUNC(G_MERGE_VALUES) with merge input: "
249                    << MI);
250         replaceRegOrBuildCopy(DstReg, MergeSrcReg, MRI, Builder, UpdatedDefs,
251                               Observer);
252       } else if (DstSize % MergeSrcSize == 0) {
253         // If the trunc size is a multiple of the merge source size we can use
254         // a smaller merge instead
255         if (isInstUnsupported(
256                 {TargetOpcode::G_MERGE_VALUES, {DstTy, MergeSrcTy}}))
257           return false;
258 
259         LLVM_DEBUG(
260             dbgs() << "Combining G_TRUNC(G_MERGE_VALUES) to G_MERGE_VALUES: "
261                    << MI);
262 
263         const unsigned NumSrcs = DstSize / MergeSrcSize;
264         assert(NumSrcs < SrcMI->getNumOperands() - 1 &&
265                "trunc(merge) should require less inputs than merge");
266         SmallVector<Register, 8> SrcRegs(NumSrcs);
267         for (unsigned i = 0; i < NumSrcs; ++i)
268           SrcRegs[i] = SrcMI->getOperand(i + 1).getReg();
269 
270         Builder.buildMerge(DstReg, SrcRegs);
271         UpdatedDefs.push_back(DstReg);
272       } else {
273         // Unable to combine
274         return false;
275       }
276 
277       markInstAndDefDead(MI, *SrcMI, DeadInsts);
278       return true;
279     }
280 
281     // trunc(trunc) -> trunc
282     Register TruncSrc;
283     if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) {
284       // Always combine trunc(trunc) since the eventual resulting trunc must be
285       // legal anyway as it must be legal for all outputs of the consumer type
286       // set.
287       LLVM_DEBUG(dbgs() << ".. Combine G_TRUNC(G_TRUNC): " << MI);
288 
289       Builder.buildTrunc(DstReg, TruncSrc);
290       UpdatedDefs.push_back(DstReg);
291       markInstAndDefDead(MI, *MRI.getVRegDef(TruncSrc), DeadInsts);
292       return true;
293     }
294 
295     return false;
296   }
297 
298   /// Try to fold G_[ASZ]EXT (G_IMPLICIT_DEF).
299   bool tryFoldImplicitDef(MachineInstr &MI,
300                           SmallVectorImpl<MachineInstr *> &DeadInsts,
301                           SmallVectorImpl<Register> &UpdatedDefs) {
302     unsigned Opcode = MI.getOpcode();
303     assert(Opcode == TargetOpcode::G_ANYEXT || Opcode == TargetOpcode::G_ZEXT ||
304            Opcode == TargetOpcode::G_SEXT);
305 
306     if (MachineInstr *DefMI = getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF,
307                                            MI.getOperand(1).getReg(), MRI)) {
308       Builder.setInstr(MI);
309       Register DstReg = MI.getOperand(0).getReg();
310       LLT DstTy = MRI.getType(DstReg);
311 
312       if (Opcode == TargetOpcode::G_ANYEXT) {
313         // G_ANYEXT (G_IMPLICIT_DEF) -> G_IMPLICIT_DEF
314         if (!isInstLegal({TargetOpcode::G_IMPLICIT_DEF, {DstTy}}))
315           return false;
316         LLVM_DEBUG(dbgs() << ".. Combine G_ANYEXT(G_IMPLICIT_DEF): " << MI;);
317         Builder.buildInstr(TargetOpcode::G_IMPLICIT_DEF, {DstReg}, {});
318         UpdatedDefs.push_back(DstReg);
319       } else {
320         // G_[SZ]EXT (G_IMPLICIT_DEF) -> G_CONSTANT 0 because the top
321         // bits will be 0 for G_ZEXT and 0/1 for the G_SEXT.
322         if (isConstantUnsupported(DstTy))
323           return false;
324         LLVM_DEBUG(dbgs() << ".. Combine G_[SZ]EXT(G_IMPLICIT_DEF): " << MI;);
325         Builder.buildConstant(DstReg, 0);
326         UpdatedDefs.push_back(DstReg);
327       }
328 
329       markInstAndDefDead(MI, *DefMI, DeadInsts);
330       return true;
331     }
332     return false;
333   }
334 
  /// Try to fold a G_UNMERGE_VALUES \p MI whose source is produced by the
  /// artifact cast \p CastMI. Currently only G_TRUNC sources are handled:
  /// the unmerge is moved onto the trunc's (wider) source, and the results
  /// are truncated (vector case) or the extra pieces discarded (scalar case).
  /// Returns true if the combine was performed; dead instructions are
  /// appended to \p DeadInsts and redefined vregs to \p UpdatedDefs.
  bool tryFoldUnmergeCast(MachineInstr &MI, MachineInstr &CastMI,
                          SmallVectorImpl<MachineInstr *> &DeadInsts,
                          SmallVectorImpl<Register> &UpdatedDefs) {

    assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES);

    const unsigned CastOpc = CastMI.getOpcode();

    if (!isArtifactCast(CastOpc))
      return false;

    // All operands but the last are defs; the last is the unmerge source.
    const unsigned NumDefs = MI.getNumOperands() - 1;

    const Register CastSrcReg = CastMI.getOperand(1).getReg();
    const LLT CastSrcTy = MRI.getType(CastSrcReg);
    const LLT DestTy = MRI.getType(MI.getOperand(0).getReg());
    const LLT SrcTy = MRI.getType(MI.getOperand(NumDefs).getReg());

    const unsigned CastSrcSize = CastSrcTy.getSizeInBits();
    const unsigned DestSize = DestTy.getSizeInBits();

    if (CastOpc == TargetOpcode::G_TRUNC) {
      if (SrcTy.isVector() && SrcTy.getScalarType() == DestTy.getScalarType()) {
        //  %1:_(<4 x s8>) = G_TRUNC %0(<4 x s32>)
        //  %2:_(s8), %3:_(s8), %4:_(s8), %5:_(s8) = G_UNMERGE_VALUES %1
        // =>
        //  %6:_(s32), %7:_(s32), %8:_(s32), %9:_(s32) = G_UNMERGE_VALUES %0
        //  %2:_(s8) = G_TRUNC %6
        //  %3:_(s8) = G_TRUNC %7
        //  %4:_(s8) = G_TRUNC %8
        //  %5:_(s8) = G_TRUNC %9

        // Each unmerge piece keeps the cast source's element type; it is a
        // (possibly 1-element) slice of the cast source vector.
        unsigned UnmergeNumElts =
            DestTy.isVector() ? CastSrcTy.getNumElements() / NumDefs : 1;
        LLT UnmergeTy = CastSrcTy.changeNumElements(UnmergeNumElts);

        if (isInstUnsupported(
                {TargetOpcode::G_UNMERGE_VALUES, {UnmergeTy, CastSrcTy}}))
          return false;

        Builder.setInstr(MI);
        auto NewUnmerge = Builder.buildUnmerge(UnmergeTy, CastSrcReg);

        // Re-create each original def as a trunc of the wider piece.
        for (unsigned I = 0; I != NumDefs; ++I) {
          Register DefReg = MI.getOperand(I).getReg();
          UpdatedDefs.push_back(DefReg);
          Builder.buildTrunc(DefReg, NewUnmerge.getReg(I));
        }

        markInstAndDefDead(MI, CastMI, DeadInsts);
        return true;
      }

      if (CastSrcTy.isScalar() && SrcTy.isScalar() && !DestTy.isVector()) {
        //  %1:_(s16) = G_TRUNC %0(s32)
        //  %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %1
        // =>
        //  %2:_(s8), %3:_(s8), %4:_(s8), %5:_(s8) = G_UNMERGE_VALUES %0

        // Unmerge(trunc) can be combined if the trunc source size is a multiple
        // of the unmerge destination size
        if (CastSrcSize % DestSize != 0)
          return false;

        // Check if the new unmerge is supported
        if (isInstUnsupported(
                {TargetOpcode::G_UNMERGE_VALUES, {DestTy, CastSrcTy}}))
          return false;

        // Gather the original destination registers and create new ones for the
        // unused bits
        const unsigned NewNumDefs = CastSrcSize / DestSize;
        SmallVector<Register, 8> DstRegs(NewNumDefs);
        for (unsigned Idx = 0; Idx < NewNumDefs; ++Idx) {
          if (Idx < NumDefs)
            DstRegs[Idx] = MI.getOperand(Idx).getReg();
          else
            DstRegs[Idx] = MRI.createGenericVirtualRegister(DestTy);
        }

        // Build new unmerge
        Builder.setInstr(MI);
        Builder.buildUnmerge(DstRegs, CastSrcReg);
        UpdatedDefs.append(DstRegs.begin(), DstRegs.begin() + NewNumDefs);
        markInstAndDefDead(MI, CastMI, DeadInsts);
        return true;
      }
    }

    // TODO: support combines with other casts as well
    return false;
  }
427 
428   static bool canFoldMergeOpcode(unsigned MergeOp, unsigned ConvertOp,
429                                  LLT OpTy, LLT DestTy) {
430     // Check if we found a definition that is like G_MERGE_VALUES.
431     switch (MergeOp) {
432     default:
433       return false;
434     case TargetOpcode::G_BUILD_VECTOR:
435     case TargetOpcode::G_MERGE_VALUES:
436       // The convert operation that we will need to insert is
437       // going to convert the input of that type of instruction (scalar)
438       // to the destination type (DestTy).
439       // The conversion needs to stay in the same domain (scalar to scalar
440       // and vector to vector), so if we were to allow to fold the merge
441       // we would need to insert some bitcasts.
442       // E.g.,
443       // <2 x s16> = build_vector s16, s16
444       // <2 x s32> = zext <2 x s16>
445       // <2 x s16>, <2 x s16> = unmerge <2 x s32>
446       //
447       // As is the folding would produce:
448       // <2 x s16> = zext s16  <-- scalar to vector
449       // <2 x s16> = zext s16  <-- scalar to vector
450       // Which is invalid.
451       // Instead we would want to generate:
452       // s32 = zext s16
453       // <2 x s16> = bitcast s32
454       // s32 = zext s16
455       // <2 x s16> = bitcast s32
456       //
457       // That is not done yet.
458       if (ConvertOp == 0)
459         return true;
460       return !DestTy.isVector() && OpTy.isVector();
461     case TargetOpcode::G_CONCAT_VECTORS: {
462       if (ConvertOp == 0)
463         return true;
464       if (!DestTy.isVector())
465         return false;
466 
467       const unsigned OpEltSize = OpTy.getElementType().getSizeInBits();
468 
469       // Don't handle scalarization with a cast that isn't in the same
470       // direction as the vector cast. This could be handled, but it would
471       // require more intermediate unmerges.
472       if (ConvertOp == TargetOpcode::G_TRUNC)
473         return DestTy.getSizeInBits() <= OpEltSize;
474       return DestTy.getSizeInBits() >= OpEltSize;
475     }
476     }
477   }
478 
479   /// Try to replace DstReg with SrcReg or build a COPY instruction
480   /// depending on the register constraints.
481   static void replaceRegOrBuildCopy(Register DstReg, Register SrcReg,
482                                     MachineRegisterInfo &MRI,
483                                     MachineIRBuilder &Builder,
484                                     SmallVectorImpl<Register> &UpdatedDefs,
485                                     GISelObserverWrapper &Observer) {
486     if (!llvm::canReplaceReg(DstReg, SrcReg, MRI)) {
487       Builder.buildCopy(DstReg, SrcReg);
488       UpdatedDefs.push_back(DstReg);
489       return;
490     }
491     SmallVector<MachineInstr *, 4> UseMIs;
492     // Get the users and notify the observer before replacing.
493     for (auto &UseMI : MRI.use_instructions(DstReg)) {
494       UseMIs.push_back(&UseMI);
495       Observer.changingInstr(UseMI);
496     }
497     // Replace the registers.
498     MRI.replaceRegWith(DstReg, SrcReg);
499     UpdatedDefs.push_back(SrcReg);
500     // Notify the observer that we changed the instructions.
501     for (auto *UseMI : UseMIs)
502       Observer.changedInstr(*UseMI);
503   }
504 
  /// Combine a G_UNMERGE_VALUES \p MI with a G_MERGE_VALUES-like source
  /// definition (reached through copies and at most one artifact cast),
  /// forwarding the merge sources to the unmerge results. Three shapes are
  /// handled depending on how the source and result counts compare; see the
  /// inline examples. Returns true if a combine was performed.
  bool tryCombineMerges(MachineInstr &MI,
                        SmallVectorImpl<MachineInstr *> &DeadInsts,
                        SmallVectorImpl<Register> &UpdatedDefs,
                        GISelObserverWrapper &Observer) {
    assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES);

    // Operands 0..NumDefs-1 are defs; operand NumDefs is the unmerge source.
    unsigned NumDefs = MI.getNumOperands() - 1;
    MachineInstr *SrcDef =
        getDefIgnoringCopies(MI.getOperand(NumDefs).getReg(), MRI);
    if (!SrcDef)
      return false;

    LLT OpTy = MRI.getType(MI.getOperand(NumDefs).getReg());
    LLT DestTy = MRI.getType(MI.getOperand(0).getReg());
    MachineInstr *MergeI = SrcDef;
    unsigned ConvertOp = 0;

    // Handle intermediate conversions
    unsigned SrcOp = SrcDef->getOpcode();
    if (isArtifactCast(SrcOp)) {
      ConvertOp = SrcOp;
      MergeI = getDefIgnoringCopies(SrcDef->getOperand(1).getReg(), MRI);
    }

    if (!MergeI || !canFoldMergeOpcode(MergeI->getOpcode(),
                                       ConvertOp, OpTy, DestTy)) {
      // We might have a chance to combine later by trying to combine
      // unmerge(cast) first
      return tryFoldUnmergeCast(MI, *SrcDef, DeadInsts, UpdatedDefs);
    }

    const unsigned NumMergeRegs = MergeI->getNumOperands() - 1;

    if (NumMergeRegs < NumDefs) {
      // More unmerge results than merge sources: each merge source produces a
      // group of the results via a smaller unmerge.
      if (NumDefs % NumMergeRegs != 0)
        return false;

      Builder.setInstr(MI);
      // Transform to UNMERGEs, for example
      //   %1 = G_MERGE_VALUES %4, %5
      //   %9, %10, %11, %12 = G_UNMERGE_VALUES %1
      // to
      //   %9, %10 = G_UNMERGE_VALUES %4
      //   %11, %12 = G_UNMERGE_VALUES %5

      const unsigned NewNumDefs = NumDefs / NumMergeRegs;
      for (unsigned Idx = 0; Idx < NumMergeRegs; ++Idx) {
        // Gather the NewNumDefs results served by merge source Idx.
        SmallVector<Register, 8> DstRegs;
        for (unsigned j = 0, DefIdx = Idx * NewNumDefs; j < NewNumDefs;
             ++j, ++DefIdx)
          DstRegs.push_back(MI.getOperand(DefIdx).getReg());

        if (ConvertOp) {
          LLT MergeSrcTy = MRI.getType(MergeI->getOperand(1).getReg());

          // This is a vector that is being split and casted. Extract to the
          // element type, and do the conversion on the scalars (or smaller
          // vectors).
          LLT MergeEltTy = MergeSrcTy.divide(NewNumDefs);

          // Handle split to smaller vectors, with conversions.
          // %2(<8 x s8>) = G_CONCAT_VECTORS %0(<4 x s8>), %1(<4 x s8>)
          // %3(<8 x s16>) = G_SEXT %2
          // %4(<2 x s16>), %5(<2 x s16>), %6(<2 x s16>), %7(<2 x s16>) = G_UNMERGE_VALUES %3
          //
          // =>
          //
          // %8(<2 x s8>), %9(<2 x s8>) = G_UNMERGE_VALUES %0
          // %10(<2 x s8>), %11(<2 x s8>) = G_UNMERGE_VALUES %1
          // %4(<2 x s16>) = G_SEXT %8
          // %5(<2 x s16>) = G_SEXT %9
          // %6(<2 x s16>) = G_SEXT %10
          // %7(<2 x s16>)= G_SEXT %11

          SmallVector<Register, 4> TmpRegs(NewNumDefs);
          for (unsigned k = 0; k < NewNumDefs; ++k)
            TmpRegs[k] = MRI.createGenericVirtualRegister(MergeEltTy);

          Builder.buildUnmerge(TmpRegs, MergeI->getOperand(Idx + 1).getReg());

          for (unsigned k = 0; k < NewNumDefs; ++k)
            Builder.buildInstr(ConvertOp, {DstRegs[k]}, {TmpRegs[k]});
        } else {
          Builder.buildUnmerge(DstRegs, MergeI->getOperand(Idx + 1).getReg());
        }
        UpdatedDefs.append(DstRegs.begin(), DstRegs.end());
      }

    } else if (NumMergeRegs > NumDefs) {
      // More merge sources than unmerge results: each result is a smaller
      // merge of a group of sources. Intermediate conversions are not handled
      // on this path.
      if (ConvertOp != 0 || NumMergeRegs % NumDefs != 0)
        return false;

      Builder.setInstr(MI);
      // Transform to MERGEs
      //   %6 = G_MERGE_VALUES %17, %18, %19, %20
      //   %7, %8 = G_UNMERGE_VALUES %6
      // to
      //   %7 = G_MERGE_VALUES %17, %18
      //   %8 = G_MERGE_VALUES %19, %20

      const unsigned NumRegs = NumMergeRegs / NumDefs;
      for (unsigned DefIdx = 0; DefIdx < NumDefs; ++DefIdx) {
        SmallVector<Register, 8> Regs;
        // Merge operand indices start at 1 (operand 0 is the def).
        for (unsigned j = 0, Idx = NumRegs * DefIdx + 1; j < NumRegs;
             ++j, ++Idx)
          Regs.push_back(MergeI->getOperand(Idx).getReg());

        Register DefReg = MI.getOperand(DefIdx).getReg();
        Builder.buildMerge(DefReg, Regs);
        UpdatedDefs.push_back(DefReg);
      }

    } else {
      // Equal counts: forward each merge source to the corresponding unmerge
      // result, inserting a conversion (or bitcast on type mismatch) if needed.
      LLT MergeSrcTy = MRI.getType(MergeI->getOperand(1).getReg());

      if (!ConvertOp && DestTy != MergeSrcTy)
        ConvertOp = TargetOpcode::G_BITCAST;

      if (ConvertOp) {
        Builder.setInstr(MI);

        for (unsigned Idx = 0; Idx < NumDefs; ++Idx) {
          Register MergeSrc = MergeI->getOperand(Idx + 1).getReg();
          Register DefReg = MI.getOperand(Idx).getReg();
          Builder.buildInstr(ConvertOp, {DefReg}, {MergeSrc});
          UpdatedDefs.push_back(DefReg);
        }

        markInstAndDefDead(MI, *MergeI, DeadInsts);
        return true;
      }

      assert(DestTy == MergeSrcTy &&
             "Bitcast and the other kinds of conversions should "
             "have happened earlier");

      Builder.setInstr(MI);
      for (unsigned Idx = 0; Idx < NumDefs; ++Idx) {
        Register DstReg = MI.getOperand(Idx).getReg();
        Register SrcReg = MergeI->getOperand(Idx + 1).getReg();
        replaceRegOrBuildCopy(DstReg, SrcReg, MRI, Builder, UpdatedDefs,
                              Observer);
      }
    }

    markInstAndDefDead(MI, *MergeI, DeadInsts);
    return true;
  }
653 
654   static bool isMergeLikeOpcode(unsigned Opc) {
655     switch (Opc) {
656     case TargetOpcode::G_MERGE_VALUES:
657     case TargetOpcode::G_BUILD_VECTOR:
658     case TargetOpcode::G_CONCAT_VECTORS:
659       return true;
660     default:
661       return false;
662     }
663   }
664 
  /// Fold G_EXTRACT of a G_MERGE_VALUES-like source into an extract from a
  /// single merge input, when the extracted bit range lies entirely within
  /// one input. Returns true if the combine was performed.
  bool tryCombineExtract(MachineInstr &MI,
                         SmallVectorImpl<MachineInstr *> &DeadInsts,
                         SmallVectorImpl<Register> &UpdatedDefs) {
    assert(MI.getOpcode() == TargetOpcode::G_EXTRACT);

    // Try to use the source registers from a G_MERGE_VALUES
    //
    // %2 = G_MERGE_VALUES %0, %1
    // %3 = G_EXTRACT %2, N
    // =>
    //
    // for N < %2.getSizeInBits() / 2
    //     %3 = G_EXTRACT %0, N
    //
    // for N >= %2.getSizeInBits() / 2
    //    %3 = G_EXTRACT %1, (N - %0.getSizeInBits()

    Register SrcReg = lookThroughCopyInstrs(MI.getOperand(1).getReg());
    MachineInstr *MergeI = MRI.getVRegDef(SrcReg);
    if (!MergeI || !isMergeLikeOpcode(MergeI->getOpcode()))
      return false;

    Register DstReg = MI.getOperand(0).getReg();
    LLT DstTy = MRI.getType(DstReg);
    LLT SrcTy = MRI.getType(SrcReg);

    // TODO: Do we need to check if the resulting extract is supported?
    unsigned ExtractDstSize = DstTy.getSizeInBits();
    unsigned Offset = MI.getOperand(2).getImm();
    unsigned NumMergeSrcs = MergeI->getNumOperands() - 1;
    // All merge inputs have the same width: the total divided evenly.
    unsigned MergeSrcSize = SrcTy.getSizeInBits() / NumMergeSrcs;
    // Index of the merge input containing the extract's first bit.
    unsigned MergeSrcIdx = Offset / MergeSrcSize;

    // Compute the offset of the last bit the extract needs.
    unsigned EndMergeSrcIdx = (Offset + ExtractDstSize - 1) / MergeSrcSize;

    // Can't handle the case where the extract spans multiple inputs.
    if (MergeSrcIdx != EndMergeSrcIdx)
      return false;

    // TODO: We could modify MI in place in most cases.
    Builder.setInstr(MI);
    // Re-extract at the offset relative to the start of the chosen input.
    Builder.buildExtract(DstReg, MergeI->getOperand(MergeSrcIdx + 1).getReg(),
                         Offset - MergeSrcIdx * MergeSrcSize);
    UpdatedDefs.push_back(DstReg);
    markInstAndDefDead(MI, *MergeI, DeadInsts);
    return true;
  }
713 
714   /// Try to combine away MI.
715   /// Returns true if it combined away the MI.
716   /// Adds instructions that are dead as a result of the combine
717   /// into DeadInsts, which can include MI.
  /// Top-level driver: try to combine away the legalization artifact \p MI,
  /// then follow its def-use chain to re-queue any users that may have become
  /// combinable as a result.
  ///
  /// \param MI the candidate artifact instruction.
  /// \param DeadInsts collects instructions made dead by the combines; if a
  ///        recursive caller left it non-empty, it is flushed on entry.
  /// \param WrapperObserver used to push affected users back onto the
  ///        Legalizer's ArtifactList and to report erasures.
  /// \returns true if \p MI itself was combined away.
  bool tryCombineInstruction(MachineInstr &MI,
                             SmallVectorImpl<MachineInstr *> &DeadInsts,
                             GISelObserverWrapper &WrapperObserver) {
    // This might be a recursive call, and we might have DeadInsts already
    // populated. To avoid bad things happening later with multiple vreg defs
    // etc, process the dead instructions now if any.
    if (!DeadInsts.empty())
      deleteMarkedDeadInsts(DeadInsts, WrapperObserver);

    // Put here every vreg that was redefined in such a way that it's at least
    // possible that one (or more) of its users (immediate or COPY-separated)
    // could become artifact combinable with the new definition (or the
    // instruction reachable from it through a chain of copies if any).
    SmallVector<Register, 4> UpdatedDefs;
    bool Changed = false;
    switch (MI.getOpcode()) {
    default:
      return false;
    case TargetOpcode::G_ANYEXT:
      Changed = tryCombineAnyExt(MI, DeadInsts, UpdatedDefs);
      break;
    case TargetOpcode::G_ZEXT:
      Changed = tryCombineZExt(MI, DeadInsts, UpdatedDefs, WrapperObserver);
      break;
    case TargetOpcode::G_SEXT:
      Changed = tryCombineSExt(MI, DeadInsts, UpdatedDefs);
      break;
    case TargetOpcode::G_UNMERGE_VALUES:
      Changed = tryCombineMerges(MI, DeadInsts, UpdatedDefs, WrapperObserver);
      break;
    case TargetOpcode::G_MERGE_VALUES:
      // If any of the users of this merge are an unmerge, then add them to the
      // artifact worklist in case there's folding that can be done looking up.
      // Note: no combine is attempted on the merge itself here.
      for (MachineInstr &U : MRI.use_instructions(MI.getOperand(0).getReg())) {
        if (U.getOpcode() == TargetOpcode::G_UNMERGE_VALUES ||
            U.getOpcode() == TargetOpcode::G_TRUNC) {
          UpdatedDefs.push_back(MI.getOperand(0).getReg());
          break;
        }
      }
      break;
    case TargetOpcode::G_EXTRACT:
      Changed = tryCombineExtract(MI, DeadInsts, UpdatedDefs);
      break;
    case TargetOpcode::G_TRUNC:
      Changed = tryCombineTrunc(MI, DeadInsts, UpdatedDefs, WrapperObserver);
      if (!Changed) {
        // Try to combine truncates away even if they are legal. As all artifact
        // combines at the moment look only "up" the def-use chains, we achieve
        // that by throwing truncates' users (with look through copies) into the
        // ArtifactList again.
        UpdatedDefs.push_back(MI.getOperand(0).getReg());
      }
      break;
    }
    // If the main loop through the ArtifactList found at least one combinable
    // pair of artifacts, not only combine it away (as done above), but also
    // follow the def-use chain from there to combine everything that can be
    // combined within this def-use chain of artifacts.
    while (!UpdatedDefs.empty()) {
      Register NewDef = UpdatedDefs.pop_back_val();
      assert(NewDef.isVirtual() && "Unexpected redefinition of a physreg");
      for (MachineInstr &Use : MRI.use_instructions(NewDef)) {
        switch (Use.getOpcode()) {
        // Keep this list in sync with the list of all artifact combines.
        case TargetOpcode::G_ANYEXT:
        case TargetOpcode::G_ZEXT:
        case TargetOpcode::G_SEXT:
        case TargetOpcode::G_UNMERGE_VALUES:
        case TargetOpcode::G_EXTRACT:
        case TargetOpcode::G_TRUNC:
          // Adding Use to ArtifactList.
          WrapperObserver.changedInstr(Use);
          break;
        case TargetOpcode::COPY: {
          // Look through copies, but only to virtual registers: the worklist
          // asserts virtual defs above, so physreg copies are not followed.
          Register Copy = Use.getOperand(0).getReg();
          if (Copy.isVirtual())
            UpdatedDefs.push_back(Copy);
          break;
        }
        default:
          // If we do not have an artifact combine for the opcode, there is no
          // point in adding it to the ArtifactList as nothing interesting will
          // be done to it anyway.
          break;
        }
      }
    }
    return Changed;
  }
808 
809 private:
810   static Register getArtifactSrcReg(const MachineInstr &MI) {
811     switch (MI.getOpcode()) {
812     case TargetOpcode::COPY:
813     case TargetOpcode::G_TRUNC:
814     case TargetOpcode::G_ZEXT:
815     case TargetOpcode::G_ANYEXT:
816     case TargetOpcode::G_SEXT:
817     case TargetOpcode::G_EXTRACT:
818       return MI.getOperand(1).getReg();
819     case TargetOpcode::G_UNMERGE_VALUES:
820       return MI.getOperand(MI.getNumOperands() - 1).getReg();
821     default:
822       llvm_unreachable("Not a legalization artifact happen");
823     }
824   }
825 
  /// Mark a def of one of MI's original operands, DefMI, as dead if changing MI
  /// (either by killing it or changing operands) results in DefMI being dead
  /// too. In-between COPYs or artifact-casts are also collected if they are
  /// dead.
  /// MI is not marked dead.
  void markDefDead(MachineInstr &MI, MachineInstr &DefMI,
                   SmallVectorImpl<MachineInstr *> &DeadInsts) {
    // Collect all the copy instructions that are made dead, due to deleting
    // this instruction. Collect all of them until the Trunc(DefMI).
    // Eg,
    // %1(s1) = G_TRUNC %0(s32)
    // %2(s1) = COPY %1(s1)
    // %3(s1) = COPY %2(s1)
    // %4(s32) = G_ANYEXT %3(s1)
    // In this case, we would have replaced %4 with a copy of %0,
    // and as a result, %3, %2, %1 are dead.
    MachineInstr *PrevMI = &MI;
    while (PrevMI != &DefMI) {
      // Step one link up the chain through the artifact's source register.
      Register PrevRegSrc = getArtifactSrcReg(*PrevMI);

      MachineInstr *TmpDef = MRI.getVRegDef(PrevRegSrc);
      // Only keep walking while each intermediate result has exactly one use;
      // a second user would keep the intermediate instruction alive.
      if (MRI.hasOneUse(PrevRegSrc)) {
        if (TmpDef != &DefMI) {
          assert((TmpDef->getOpcode() == TargetOpcode::COPY ||
                  isArtifactCast(TmpDef->getOpcode())) &&
                 "Expecting copy or artifact cast here");

          DeadInsts.push_back(TmpDef);
        }
      } else
        break;
      PrevMI = TmpDef;
    }
    // DefMI is dead only if the walk actually reached it and its result has a
    // single user (the chain we just collected / the dying MI).
    if (PrevMI == &DefMI && MRI.hasOneUse(DefMI.getOperand(0).getReg()))
      DeadInsts.push_back(&DefMI);
  }
862 
  /// Mark MI as dead. If a def of one of MI's operands, DefMI, would also be
  /// dead due to MI being killed, then mark DefMI as dead too.
  /// Some of the combines (extends(trunc)), try to walk through redundant
  /// copies in between the extends and the truncs, and this attempts to collect
  /// the in between copies if they're dead.
  void markInstAndDefDead(MachineInstr &MI, MachineInstr &DefMI,
                          SmallVectorImpl<MachineInstr *> &DeadInsts) {
    // MI goes first so it is erased before anything that feeds it.
    DeadInsts.push_back(&MI);
    // Then collect DefMI and any single-use COPY/cast chain between them.
    markDefDead(MI, DefMI, DeadInsts);
  }
873 
874   /// Erase the dead instructions in the list and call the observer hooks.
875   /// Normally the Legalizer will deal with erasing instructions that have been
876   /// marked dead. However, for the trunc(ext(x)) cases we can end up trying to
877   /// process instructions which have been marked dead, but otherwise break the
878   /// MIR by introducing multiple vreg defs. For those cases, allow the combines
879   /// to explicitly delete the instructions before we run into trouble.
880   void deleteMarkedDeadInsts(SmallVectorImpl<MachineInstr *> &DeadInsts,
881                              GISelObserverWrapper &WrapperObserver) {
882     for (auto *DeadMI : DeadInsts) {
883       LLVM_DEBUG(dbgs() << *DeadMI << "Is dead, eagerly deleting\n");
884       WrapperObserver.erasingInstr(*DeadMI);
885       DeadMI->eraseFromParentAndMarkDBGValuesForRemoval();
886     }
887     DeadInsts.clear();
888   }
889 
890   /// Checks if the target legalizer info has specified anything about the
891   /// instruction, or if unsupported.
892   bool isInstUnsupported(const LegalityQuery &Query) const {
893     using namespace LegalizeActions;
894     auto Step = LI.getAction(Query);
895     return Step.Action == Unsupported || Step.Action == NotFound;
896   }
897 
898   bool isInstLegal(const LegalityQuery &Query) const {
899     return LI.getAction(Query).Action == LegalizeActions::Legal;
900   }
901 
902   bool isConstantUnsupported(LLT Ty) const {
903     if (!Ty.isVector())
904       return isInstUnsupported({TargetOpcode::G_CONSTANT, {Ty}});
905 
906     LLT EltTy = Ty.getElementType();
907     return isInstUnsupported({TargetOpcode::G_CONSTANT, {EltTy}}) ||
908            isInstUnsupported({TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}});
909   }
910 
911   /// Looks through copy instructions and returns the actual
912   /// source register.
913   Register lookThroughCopyInstrs(Register Reg) {
914     Register TmpReg;
915     while (mi_match(Reg, MRI, m_Copy(m_Reg(TmpReg)))) {
916       if (MRI.getType(TmpReg).isValid())
917         Reg = TmpReg;
918       else
919         break;
920     }
921     return Reg;
922   }
923 };
924 
925 } // namespace llvm
926