1 //===- AArch64RegisterBankInfo.cpp ----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file implements the targeting of the RegisterBankInfo class for
10 /// AArch64.
11 /// \todo This should be generated by TableGen.
12 //===----------------------------------------------------------------------===//
13 
14 #include "AArch64RegisterBankInfo.h"
15 #include "AArch64InstrInfo.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/CodeGen/GlobalISel/RegisterBank.h"
19 #include "llvm/CodeGen/GlobalISel/RegisterBankInfo.h"
20 #include "llvm/CodeGen/LowLevelType.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineInstr.h"
23 #include "llvm/CodeGen/MachineOperand.h"
24 #include "llvm/CodeGen/MachineRegisterInfo.h"
25 #include "llvm/CodeGen/TargetOpcodes.h"
26 #include "llvm/CodeGen/TargetRegisterInfo.h"
27 #include "llvm/CodeGen/TargetSubtargetInfo.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include <algorithm>
30 #include <cassert>
31 
32 #define GET_TARGET_REGBANK_IMPL
33 #include "AArch64GenRegisterBank.inc"
34 
35 // This file will be TableGen'ed at some point.
36 #include "AArch64GenRegisterBankInfo.def"
37 
38 using namespace llvm;
39 
/// Construct the AArch64 register bank info.
///
/// The constructor itself does no bookkeeping: the register bank tables come
/// from the included AArch64GenRegisterBank.inc / AArch64GenRegisterBankInfo.def.
/// What runs here is a one-time consistency check of those tables, executed
/// through llvm::call_once so that only the first constructed instance performs
/// it. Nearly every statement in the check is an assert(), so with assertions
/// disabled it reduces to a few getRegBank() lookups.
AArch64RegisterBankInfo::AArch64RegisterBankInfo(const TargetRegisterInfo &TRI)
    : AArch64GenRegisterBankInfo() {
  static llvm::once_flag InitializeRegisterBankFlag;

  // NOTE(review): this lambda is a function-local static that captures `this`
  // and `TRI` by reference from the *first* constructor call. That is only
  // sound because call_once below runs it while that first constructor
  // invocation is still active; later constructions never re-run it.
  static auto InitializeRegisterBankOnce = [&]() {
    // We have only one set of register banks, whatever the subtarget
    // is. Therefore, the initialization of the RegBanks table should be
    // done only once. Indeed the table of all register banks
    // (AArch64::RegBanks) is unique in the compiler. At some point, it
    // will get tablegen'ed and the whole constructor becomes empty.

    // Each bank ID must index the matching statically defined bank object.
    const RegisterBank &RBGPR = getRegBank(AArch64::GPRRegBankID);
    (void)RBGPR;
    assert(&AArch64::GPRRegBank == &RBGPR &&
           "The order in RegBanks is messed up");

    const RegisterBank &RBFPR = getRegBank(AArch64::FPRRegBankID);
    (void)RBFPR;
    assert(&AArch64::FPRRegBank == &RBFPR &&
           "The order in RegBanks is messed up");

    const RegisterBank &RBCCR = getRegBank(AArch64::CCRegBankID);
    (void)RBCCR;
    assert(&AArch64::CCRegBank == &RBCCR &&
           "The order in RegBanks is messed up");

    // The GPR register bank must cover the GPR32 register class and be
    // 64 bits wide.
    assert(RBGPR.covers(*TRI.getRegClass(AArch64::GPR32RegClassID)) &&
           "Subclass not added?");
    assert(RBGPR.getSize() == 64 && "GPRs should hold up to 64-bit");

    // The FPR register bank must cover the QQ and FPR64 register classes and
    // be 512 bits wide (enough for a QQQQ tuple).
    assert(RBFPR.covers(*TRI.getRegClass(AArch64::QQRegClassID)) &&
           "Subclass not added?");
    assert(RBFPR.covers(*TRI.getRegClass(AArch64::FPR64RegClassID)) &&
           "Subclass not added?");
    assert(RBFPR.getSize() == 512 &&
           "FPRs should hold up to 512-bit via QQQQ sequence");

    // The condition-code bank must cover CCR and be 32 bits wide.
    assert(RBCCR.covers(*TRI.getRegClass(AArch64::CCRRegClassID)) &&
           "Class not added?");
    assert(RBCCR.getSize() == 32 && "CCR should hold up to 32-bit");

    // Check that the TableGen'ed-like file is in sync with our expectations.
    // First, the Idx: the PMI_* enumerators of each bank must be contiguous
    // and in increasing size order.
    assert(checkPartialMappingIdx(PMI_FirstGPR, PMI_LastGPR,
                                  {PMI_GPR32, PMI_GPR64}) &&
           "PartialMappingIdx's are incorrectly ordered");
    assert(checkPartialMappingIdx(PMI_FirstFPR, PMI_LastFPR,
                                  {PMI_FPR16, PMI_FPR32, PMI_FPR64, PMI_FPR128,
                                   PMI_FPR256, PMI_FPR512}) &&
           "PartialMappingIdx's are incorrectly ordered");
// Now, the content.
// Check partial mapping: each entry starts at bit 0, has the expected
// length, and lives on the expected bank.
#define CHECK_PARTIALMAP(Idx, ValStartIdx, ValLength, RB)                      \
  do {                                                                         \
    assert(                                                                    \
        checkPartialMap(PartialMappingIdx::Idx, ValStartIdx, ValLength, RB) && \
        #Idx " is incorrectly initialized");                                   \
  } while (false)

    CHECK_PARTIALMAP(PMI_GPR32, 0, 32, RBGPR);
    CHECK_PARTIALMAP(PMI_GPR64, 0, 64, RBGPR);
    CHECK_PARTIALMAP(PMI_FPR16, 0, 16, RBFPR);
    CHECK_PARTIALMAP(PMI_FPR32, 0, 32, RBFPR);
    CHECK_PARTIALMAP(PMI_FPR64, 0, 64, RBFPR);
    CHECK_PARTIALMAP(PMI_FPR128, 0, 128, RBFPR);
    CHECK_PARTIALMAP(PMI_FPR256, 0, 256, RBFPR);
    CHECK_PARTIALMAP(PMI_FPR512, 0, 512, RBFPR);

// Check value mapping: the value mapping for (bank, size) at operand
// `Offset` must point at the matching partial mapping.
#define CHECK_VALUEMAP_IMPL(RBName, Size, Offset)                              \
  do {                                                                         \
    assert(checkValueMapImpl(PartialMappingIdx::PMI_##RBName##Size,            \
                             PartialMappingIdx::PMI_First##RBName, Size,       \
                             Offset) &&                                        \
           #RBName #Size " " #Offset " is incorrectly initialized");           \
  } while (false)

#define CHECK_VALUEMAP(RBName, Size) CHECK_VALUEMAP_IMPL(RBName, Size, 0)

    CHECK_VALUEMAP(GPR, 32);
    CHECK_VALUEMAP(GPR, 64);
    CHECK_VALUEMAP(FPR, 16);
    CHECK_VALUEMAP(FPR, 32);
    CHECK_VALUEMAP(FPR, 64);
    CHECK_VALUEMAP(FPR, 128);
    CHECK_VALUEMAP(FPR, 256);
    CHECK_VALUEMAP(FPR, 512);

// Check the value mapping for 3-operands instructions where all the operands
// map to the same value mapping.
#define CHECK_VALUEMAP_3OPS(RBName, Size)                                      \
  do {                                                                         \
    CHECK_VALUEMAP_IMPL(RBName, Size, 0);                                      \
    CHECK_VALUEMAP_IMPL(RBName, Size, 1);                                      \
    CHECK_VALUEMAP_IMPL(RBName, Size, 2);                                      \
  } while (false)

    CHECK_VALUEMAP_3OPS(GPR, 32);
    CHECK_VALUEMAP_3OPS(GPR, 64);
    CHECK_VALUEMAP_3OPS(FPR, 32);
    CHECK_VALUEMAP_3OPS(FPR, 64);
    CHECK_VALUEMAP_3OPS(FPR, 128);
    CHECK_VALUEMAP_3OPS(FPR, 256);
    CHECK_VALUEMAP_3OPS(FPR, 512);

// Check the copy mappings: entry 0 is the destination, entry 1 the source,
// each a single breakdown onto the matching partial mapping.
#define CHECK_VALUEMAP_CROSSREGCPY(RBNameDst, RBNameSrc, Size)                 \
  do {                                                                         \
    unsigned PartialMapDstIdx = PMI_##RBNameDst##Size - PMI_Min;               \
    unsigned PartialMapSrcIdx = PMI_##RBNameSrc##Size - PMI_Min;               \
    (void)PartialMapDstIdx;                                                    \
    (void)PartialMapSrcIdx;                                                    \
    const ValueMapping *Map = getCopyMapping(                                  \
        AArch64::RBNameDst##RegBankID, AArch64::RBNameSrc##RegBankID, Size);  \
    (void)Map;                                                                 \
    assert(Map[0].BreakDown ==                                                 \
               &AArch64GenRegisterBankInfo::PartMappings[PartialMapDstIdx] &&  \
           Map[0].NumBreakDowns == 1 && #RBNameDst #Size                       \
           " Dst is incorrectly initialized");                                 \
    assert(Map[1].BreakDown ==                                                 \
               &AArch64GenRegisterBankInfo::PartMappings[PartialMapSrcIdx] &&  \
           Map[1].NumBreakDowns == 1 && #RBNameSrc #Size                       \
           " Src is incorrectly initialized");                                 \
                                                                               \
  } while (false)

    CHECK_VALUEMAP_CROSSREGCPY(GPR, GPR, 32);
    CHECK_VALUEMAP_CROSSREGCPY(GPR, FPR, 32);
    CHECK_VALUEMAP_CROSSREGCPY(GPR, GPR, 64);
    CHECK_VALUEMAP_CROSSREGCPY(GPR, FPR, 64);
    CHECK_VALUEMAP_CROSSREGCPY(FPR, FPR, 32);
    CHECK_VALUEMAP_CROSSREGCPY(FPR, GPR, 32);
    CHECK_VALUEMAP_CROSSREGCPY(FPR, FPR, 64);
    CHECK_VALUEMAP_CROSSREGCPY(FPR, GPR, 64);

// Check the FP-extension mappings (FPR destination of DstSize bits, FPR
// source of SrcSize bits).
#define CHECK_VALUEMAP_FPEXT(DstSize, SrcSize)                                 \
  do {                                                                         \
    unsigned PartialMapDstIdx = PMI_FPR##DstSize - PMI_Min;                    \
    unsigned PartialMapSrcIdx = PMI_FPR##SrcSize - PMI_Min;                    \
    (void)PartialMapDstIdx;                                                    \
    (void)PartialMapSrcIdx;                                                    \
    const ValueMapping *Map = getFPExtMapping(DstSize, SrcSize);               \
    (void)Map;                                                                 \
    assert(Map[0].BreakDown ==                                                 \
               &AArch64GenRegisterBankInfo::PartMappings[PartialMapDstIdx] &&  \
           Map[0].NumBreakDowns == 1 && "FPR" #DstSize                         \
                                        " Dst is incorrectly initialized");    \
    assert(Map[1].BreakDown ==                                                 \
               &AArch64GenRegisterBankInfo::PartMappings[PartialMapSrcIdx] &&  \
           Map[1].NumBreakDowns == 1 && "FPR" #SrcSize                         \
                                        " Src is incorrectly initialized");    \
                                                                               \
  } while (false)

    CHECK_VALUEMAP_FPEXT(32, 16);
    CHECK_VALUEMAP_FPEXT(64, 16);
    CHECK_VALUEMAP_FPEXT(64, 32);
    CHECK_VALUEMAP_FPEXT(128, 64);

    // Finally, run the generic RegisterBankInfo verifier.
    assert(verify(TRI) && "Invalid register bank information");
  };

  llvm::call_once(InitializeRegisterBankFlag, InitializeRegisterBankOnce);
}
207 
208 unsigned AArch64RegisterBankInfo::copyCost(const RegisterBank &A,
209                                            const RegisterBank &B,
210                                            unsigned Size) const {
211   // What do we do with different size?
212   // copy are same size.
213   // Will introduce other hooks for different size:
214   // * extract cost.
215   // * build_sequence cost.
216 
217   // Copy from (resp. to) GPR to (resp. from) FPR involves FMOV.
218   // FIXME: This should be deduced from the scheduling model.
219   if (&A == &AArch64::GPRRegBank && &B == &AArch64::FPRRegBank)
220     // FMOVXDr or FMOVWSr.
221     return 5;
222   if (&A == &AArch64::FPRRegBank && &B == &AArch64::GPRRegBank)
223     // FMOVDXr or FMOVSWr.
224     return 4;
225 
226   return RegisterBankInfo::copyCost(A, B, Size);
227 }
228 
229 const RegisterBank &
230 AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
231                                                 LLT) const {
232   switch (RC.getID()) {
233   case AArch64::FPR8RegClassID:
234   case AArch64::FPR16RegClassID:
235   case AArch64::FPR16_loRegClassID:
236   case AArch64::FPR32_with_hsub_in_FPR16_loRegClassID:
237   case AArch64::FPR32RegClassID:
238   case AArch64::FPR64RegClassID:
239   case AArch64::FPR64_loRegClassID:
240   case AArch64::FPR128RegClassID:
241   case AArch64::FPR128_loRegClassID:
242   case AArch64::DDRegClassID:
243   case AArch64::DDDRegClassID:
244   case AArch64::DDDDRegClassID:
245   case AArch64::QQRegClassID:
246   case AArch64::QQQRegClassID:
247   case AArch64::QQQQRegClassID:
248     return getRegBank(AArch64::FPRRegBankID);
249   case AArch64::GPR32commonRegClassID:
250   case AArch64::GPR32RegClassID:
251   case AArch64::GPR32spRegClassID:
252   case AArch64::GPR32sponlyRegClassID:
253   case AArch64::GPR32argRegClassID:
254   case AArch64::GPR32allRegClassID:
255   case AArch64::GPR64commonRegClassID:
256   case AArch64::GPR64RegClassID:
257   case AArch64::GPR64spRegClassID:
258   case AArch64::GPR64sponlyRegClassID:
259   case AArch64::GPR64argRegClassID:
260   case AArch64::GPR64allRegClassID:
261   case AArch64::GPR64noipRegClassID:
262   case AArch64::GPR64common_and_GPR64noipRegClassID:
263   case AArch64::GPR64noip_and_tcGPR64RegClassID:
264   case AArch64::tcGPR64RegClassID:
265   case AArch64::rtcGPR64RegClassID:
266   case AArch64::WSeqPairsClassRegClassID:
267   case AArch64::XSeqPairsClassRegClassID:
268     return getRegBank(AArch64::GPRRegBankID);
269   case AArch64::CCRRegClassID:
270     return getRegBank(AArch64::CCRegBankID);
271   default:
272     llvm_unreachable("Register class not supported");
273   }
274 }
275 
276 RegisterBankInfo::InstructionMappings
277 AArch64RegisterBankInfo::getInstrAlternativeMappings(
278     const MachineInstr &MI) const {
279   const MachineFunction &MF = *MI.getParent()->getParent();
280   const TargetSubtargetInfo &STI = MF.getSubtarget();
281   const TargetRegisterInfo &TRI = *STI.getRegisterInfo();
282   const MachineRegisterInfo &MRI = MF.getRegInfo();
283 
284   switch (MI.getOpcode()) {
285   case TargetOpcode::G_OR: {
286     // 32 and 64-bit or can be mapped on either FPR or
287     // GPR for the same cost.
288     unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
289     if (Size != 32 && Size != 64)
290       break;
291 
292     // If the instruction has any implicit-defs or uses,
293     // do not mess with it.
294     if (MI.getNumOperands() != 3)
295       break;
296     InstructionMappings AltMappings;
297     const InstructionMapping &GPRMapping = getInstructionMapping(
298         /*ID*/ 1, /*Cost*/ 1, getValueMapping(PMI_FirstGPR, Size),
299         /*NumOperands*/ 3);
300     const InstructionMapping &FPRMapping = getInstructionMapping(
301         /*ID*/ 2, /*Cost*/ 1, getValueMapping(PMI_FirstFPR, Size),
302         /*NumOperands*/ 3);
303 
304     AltMappings.push_back(&GPRMapping);
305     AltMappings.push_back(&FPRMapping);
306     return AltMappings;
307   }
308   case TargetOpcode::G_BITCAST: {
309     unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
310     if (Size != 32 && Size != 64)
311       break;
312 
313     // If the instruction has any implicit-defs or uses,
314     // do not mess with it.
315     if (MI.getNumOperands() != 2)
316       break;
317 
318     InstructionMappings AltMappings;
319     const InstructionMapping &GPRMapping = getInstructionMapping(
320         /*ID*/ 1, /*Cost*/ 1,
321         getCopyMapping(AArch64::GPRRegBankID, AArch64::GPRRegBankID, Size),
322         /*NumOperands*/ 2);
323     const InstructionMapping &FPRMapping = getInstructionMapping(
324         /*ID*/ 2, /*Cost*/ 1,
325         getCopyMapping(AArch64::FPRRegBankID, AArch64::FPRRegBankID, Size),
326         /*NumOperands*/ 2);
327     const InstructionMapping &GPRToFPRMapping = getInstructionMapping(
328         /*ID*/ 3,
329         /*Cost*/ copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank, Size),
330         getCopyMapping(AArch64::FPRRegBankID, AArch64::GPRRegBankID, Size),
331         /*NumOperands*/ 2);
332     const InstructionMapping &FPRToGPRMapping = getInstructionMapping(
333         /*ID*/ 3,
334         /*Cost*/ copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank, Size),
335         getCopyMapping(AArch64::GPRRegBankID, AArch64::FPRRegBankID, Size),
336         /*NumOperands*/ 2);
337 
338     AltMappings.push_back(&GPRMapping);
339     AltMappings.push_back(&FPRMapping);
340     AltMappings.push_back(&GPRToFPRMapping);
341     AltMappings.push_back(&FPRToGPRMapping);
342     return AltMappings;
343   }
344   case TargetOpcode::G_LOAD: {
345     unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
346     if (Size != 64)
347       break;
348 
349     // If the instruction has any implicit-defs or uses,
350     // do not mess with it.
351     if (MI.getNumOperands() != 2)
352       break;
353 
354     InstructionMappings AltMappings;
355     const InstructionMapping &GPRMapping = getInstructionMapping(
356         /*ID*/ 1, /*Cost*/ 1,
357         getOperandsMapping({getValueMapping(PMI_FirstGPR, Size),
358                             // Addresses are GPR 64-bit.
359                             getValueMapping(PMI_FirstGPR, 64)}),
360         /*NumOperands*/ 2);
361     const InstructionMapping &FPRMapping = getInstructionMapping(
362         /*ID*/ 2, /*Cost*/ 1,
363         getOperandsMapping({getValueMapping(PMI_FirstFPR, Size),
364                             // Addresses are GPR 64-bit.
365                             getValueMapping(PMI_FirstGPR, 64)}),
366         /*NumOperands*/ 2);
367 
368     AltMappings.push_back(&GPRMapping);
369     AltMappings.push_back(&FPRMapping);
370     return AltMappings;
371   }
372   default:
373     break;
374   }
375   return RegisterBankInfo::getInstrAlternativeMappings(MI);
376 }
377 
378 void AArch64RegisterBankInfo::applyMappingImpl(
379     const OperandsMapper &OpdMapper) const {
380   switch (OpdMapper.getMI().getOpcode()) {
381   case TargetOpcode::G_OR:
382   case TargetOpcode::G_BITCAST:
383   case TargetOpcode::G_LOAD:
384     // Those ID must match getInstrAlternativeMappings.
385     assert((OpdMapper.getInstrMapping().getID() >= 1 &&
386             OpdMapper.getInstrMapping().getID() <= 4) &&
387            "Don't know how to handle that ID");
388     return applyDefaultMapping(OpdMapper);
389   default:
390     llvm_unreachable("Don't know how to handle that operation");
391   }
392 }
393 
394 /// Returns whether opcode \p Opc is a pre-isel generic floating-point opcode,
395 /// having only floating-point operands.
396 static bool isPreISelGenericFloatingPointOpcode(unsigned Opc) {
397   switch (Opc) {
398   case TargetOpcode::G_FADD:
399   case TargetOpcode::G_FSUB:
400   case TargetOpcode::G_FMUL:
401   case TargetOpcode::G_FMA:
402   case TargetOpcode::G_FDIV:
403   case TargetOpcode::G_FCONSTANT:
404   case TargetOpcode::G_FPEXT:
405   case TargetOpcode::G_FPTRUNC:
406   case TargetOpcode::G_FCEIL:
407   case TargetOpcode::G_FFLOOR:
408   case TargetOpcode::G_FNEARBYINT:
409   case TargetOpcode::G_FNEG:
410   case TargetOpcode::G_FCOS:
411   case TargetOpcode::G_FSIN:
412   case TargetOpcode::G_FLOG10:
413   case TargetOpcode::G_FLOG:
414   case TargetOpcode::G_FLOG2:
415   case TargetOpcode::G_FSQRT:
416   case TargetOpcode::G_FABS:
417   case TargetOpcode::G_FEXP:
418   case TargetOpcode::G_FRINT:
419   case TargetOpcode::G_INTRINSIC_TRUNC:
420   case TargetOpcode::G_INTRINSIC_ROUND:
421     return true;
422   }
423   return false;
424 }
425 
426 const RegisterBankInfo::InstructionMapping &
427 AArch64RegisterBankInfo::getSameKindOfOperandsMapping(
428     const MachineInstr &MI) const {
429   const unsigned Opc = MI.getOpcode();
430   const MachineFunction &MF = *MI.getParent()->getParent();
431   const MachineRegisterInfo &MRI = MF.getRegInfo();
432 
433   unsigned NumOperands = MI.getNumOperands();
434   assert(NumOperands <= 3 &&
435          "This code is for instructions with 3 or less operands");
436 
437   LLT Ty = MRI.getType(MI.getOperand(0).getReg());
438   unsigned Size = Ty.getSizeInBits();
439   bool IsFPR = Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc);
440 
441   PartialMappingIdx RBIdx = IsFPR ? PMI_FirstFPR : PMI_FirstGPR;
442 
443 #ifndef NDEBUG
444   // Make sure all the operands are using similar size and type.
445   // Should probably be checked by the machine verifier.
446   // This code won't catch cases where the number of lanes is
447   // different between the operands.
448   // If we want to go to that level of details, it is probably
449   // best to check that the types are the same, period.
450   // Currently, we just check that the register banks are the same
451   // for each types.
452   for (unsigned Idx = 1; Idx != NumOperands; ++Idx) {
453     LLT OpTy = MRI.getType(MI.getOperand(Idx).getReg());
454     assert(
455         AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(
456             RBIdx, OpTy.getSizeInBits()) ==
457             AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(RBIdx, Size) &&
458         "Operand has incompatible size");
459     bool OpIsFPR = OpTy.isVector() || isPreISelGenericFloatingPointOpcode(Opc);
460     (void)OpIsFPR;
461     assert(IsFPR == OpIsFPR && "Operand has incompatible type");
462   }
463 #endif // End NDEBUG.
464 
465   return getInstructionMapping(DefaultMappingID, 1,
466                                getValueMapping(RBIdx, Size), NumOperands);
467 }
468 
469 bool AArch64RegisterBankInfo::hasFPConstraints(const MachineInstr &MI,
470                                                const MachineRegisterInfo &MRI,
471                                                const TargetRegisterInfo &TRI,
472                                                unsigned Depth) const {
473   unsigned Op = MI.getOpcode();
474 
475   // Do we have an explicit floating point instruction?
476   if (isPreISelGenericFloatingPointOpcode(Op))
477     return true;
478 
479   // No. Check if we have a copy-like instruction. If we do, then we could
480   // still be fed by floating point instructions.
481   if (Op != TargetOpcode::COPY && !MI.isPHI())
482     return false;
483 
484   // Check if we already know the register bank.
485   auto *RB = getRegBank(MI.getOperand(0).getReg(), MRI, TRI);
486   if (RB == &AArch64::FPRRegBank)
487     return true;
488   if (RB == &AArch64::GPRRegBank)
489     return false;
490 
491   // We don't know anything.
492   //
493   // If we have a phi, we may be able to infer that it will be assigned a FPR
494   // based off of its inputs.
495   if (!MI.isPHI() || Depth > MaxFPRSearchDepth)
496     return false;
497 
498   return any_of(MI.explicit_uses(), [&](const MachineOperand &Op) {
499     return Op.isReg() &&
500            onlyDefinesFP(*MRI.getVRegDef(Op.getReg()), MRI, TRI, Depth + 1);
501   });
502 }
503 
504 bool AArch64RegisterBankInfo::onlyUsesFP(const MachineInstr &MI,
505                                          const MachineRegisterInfo &MRI,
506                                          const TargetRegisterInfo &TRI,
507                                          unsigned Depth) const {
508   switch (MI.getOpcode()) {
509   case TargetOpcode::G_FPTOSI:
510   case TargetOpcode::G_FPTOUI:
511   case TargetOpcode::G_FCMP:
512     return true;
513   default:
514     break;
515   }
516   return hasFPConstraints(MI, MRI, TRI, Depth);
517 }
518 
519 bool AArch64RegisterBankInfo::onlyDefinesFP(const MachineInstr &MI,
520                                             const MachineRegisterInfo &MRI,
521                                             const TargetRegisterInfo &TRI,
522                                             unsigned Depth) const {
523   switch (MI.getOpcode()) {
524   case AArch64::G_DUP:
525   case TargetOpcode::G_SITOFP:
526   case TargetOpcode::G_UITOFP:
527   case TargetOpcode::G_EXTRACT_VECTOR_ELT:
528   case TargetOpcode::G_INSERT_VECTOR_ELT:
529     return true;
530   default:
531     break;
532   }
533   return hasFPConstraints(MI, MRI, TRI, Depth);
534 }
535 
536 const RegisterBankInfo::InstructionMapping &
537 AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
538   const unsigned Opc = MI.getOpcode();
539 
540   // Try the default logic for non-generic instructions that are either copies
541   // or already have some operands assigned to banks.
542   if ((Opc != TargetOpcode::COPY && !isPreISelGenericOpcode(Opc)) ||
543       Opc == TargetOpcode::G_PHI) {
544     const RegisterBankInfo::InstructionMapping &Mapping =
545         getInstrMappingImpl(MI);
546     if (Mapping.isValid())
547       return Mapping;
548   }
549 
550   const MachineFunction &MF = *MI.getParent()->getParent();
551   const MachineRegisterInfo &MRI = MF.getRegInfo();
552   const TargetSubtargetInfo &STI = MF.getSubtarget();
553   const TargetRegisterInfo &TRI = *STI.getRegisterInfo();
554 
555   switch (Opc) {
556     // G_{F|S|U}REM are not listed because they are not legal.
557     // Arithmetic ops.
558   case TargetOpcode::G_ADD:
559   case TargetOpcode::G_SUB:
560   case TargetOpcode::G_PTR_ADD:
561   case TargetOpcode::G_MUL:
562   case TargetOpcode::G_SDIV:
563   case TargetOpcode::G_UDIV:
564     // Bitwise ops.
565   case TargetOpcode::G_AND:
566   case TargetOpcode::G_OR:
567   case TargetOpcode::G_XOR:
568     // Floating point ops.
569   case TargetOpcode::G_FADD:
570   case TargetOpcode::G_FSUB:
571   case TargetOpcode::G_FMUL:
572   case TargetOpcode::G_FDIV:
573     return getSameKindOfOperandsMapping(MI);
574   case TargetOpcode::G_FPEXT: {
575     LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
576     LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
577     return getInstructionMapping(
578         DefaultMappingID, /*Cost*/ 1,
579         getFPExtMapping(DstTy.getSizeInBits(), SrcTy.getSizeInBits()),
580         /*NumOperands*/ 2);
581   }
582     // Shifts.
583   case TargetOpcode::G_SHL:
584   case TargetOpcode::G_LSHR:
585   case TargetOpcode::G_ASHR: {
586     LLT ShiftAmtTy = MRI.getType(MI.getOperand(2).getReg());
587     LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
588     if (ShiftAmtTy.getSizeInBits() == 64 && SrcTy.getSizeInBits() == 32)
589       return getInstructionMapping(DefaultMappingID, 1,
590                                    &ValMappings[Shift64Imm], 3);
591     return getSameKindOfOperandsMapping(MI);
592   }
593   case TargetOpcode::COPY: {
594     Register DstReg = MI.getOperand(0).getReg();
595     Register SrcReg = MI.getOperand(1).getReg();
596     // Check if one of the register is not a generic register.
597     if ((Register::isPhysicalRegister(DstReg) ||
598          !MRI.getType(DstReg).isValid()) ||
599         (Register::isPhysicalRegister(SrcReg) ||
600          !MRI.getType(SrcReg).isValid())) {
601       const RegisterBank *DstRB = getRegBank(DstReg, MRI, TRI);
602       const RegisterBank *SrcRB = getRegBank(SrcReg, MRI, TRI);
603       if (!DstRB)
604         DstRB = SrcRB;
605       else if (!SrcRB)
606         SrcRB = DstRB;
607       // If both RB are null that means both registers are generic.
608       // We shouldn't be here.
609       assert(DstRB && SrcRB && "Both RegBank were nullptr");
610       unsigned Size = getSizeInBits(DstReg, MRI, TRI);
611       return getInstructionMapping(
612           DefaultMappingID, copyCost(*DstRB, *SrcRB, Size),
613           getCopyMapping(DstRB->getID(), SrcRB->getID(), Size),
614           // We only care about the mapping of the destination.
615           /*NumOperands*/ 1);
616     }
617     // Both registers are generic, use G_BITCAST.
618     LLVM_FALLTHROUGH;
619   }
620   case TargetOpcode::G_BITCAST: {
621     LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
622     LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
623     unsigned Size = DstTy.getSizeInBits();
624     bool DstIsGPR = !DstTy.isVector() && DstTy.getSizeInBits() <= 64;
625     bool SrcIsGPR = !SrcTy.isVector() && SrcTy.getSizeInBits() <= 64;
626     const RegisterBank &DstRB =
627         DstIsGPR ? AArch64::GPRRegBank : AArch64::FPRRegBank;
628     const RegisterBank &SrcRB =
629         SrcIsGPR ? AArch64::GPRRegBank : AArch64::FPRRegBank;
630     return getInstructionMapping(
631         DefaultMappingID, copyCost(DstRB, SrcRB, Size),
632         getCopyMapping(DstRB.getID(), SrcRB.getID(), Size),
633         // We only care about the mapping of the destination for COPY.
634         /*NumOperands*/ Opc == TargetOpcode::G_BITCAST ? 2 : 1);
635   }
636   default:
637     break;
638   }
639 
640   unsigned NumOperands = MI.getNumOperands();
641 
642   // Track the size and bank of each register.  We don't do partial mappings.
643   SmallVector<unsigned, 4> OpSize(NumOperands);
644   SmallVector<PartialMappingIdx, 4> OpRegBankIdx(NumOperands);
645   for (unsigned Idx = 0; Idx < NumOperands; ++Idx) {
646     auto &MO = MI.getOperand(Idx);
647     if (!MO.isReg() || !MO.getReg())
648       continue;
649 
650     LLT Ty = MRI.getType(MO.getReg());
651     OpSize[Idx] = Ty.getSizeInBits();
652 
653     // As a top-level guess, vectors go in FPRs, scalars and pointers in GPRs.
654     // For floating-point instructions, scalars go in FPRs.
655     if (Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc) ||
656         Ty.getSizeInBits() > 64)
657       OpRegBankIdx[Idx] = PMI_FirstFPR;
658     else
659       OpRegBankIdx[Idx] = PMI_FirstGPR;
660   }
661 
662   unsigned Cost = 1;
663   // Some of the floating-point instructions have mixed GPR and FPR operands:
664   // fine-tune the computed mapping.
665   switch (Opc) {
666   case AArch64::G_DUP: {
667     Register ScalarReg = MI.getOperand(1).getReg();
668     auto ScalarDef = MRI.getVRegDef(ScalarReg);
669     if (getRegBank(ScalarReg, MRI, TRI) == &AArch64::FPRRegBank ||
670         onlyDefinesFP(*ScalarDef, MRI, TRI))
671       OpRegBankIdx = {PMI_FirstFPR, PMI_FirstFPR};
672     else
673       OpRegBankIdx = {PMI_FirstFPR, PMI_FirstGPR};
674     break;
675   }
676   case TargetOpcode::G_TRUNC: {
677     LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
678     if (!SrcTy.isVector() && SrcTy.getSizeInBits() == 128)
679       OpRegBankIdx = {PMI_FirstFPR, PMI_FirstFPR};
680     break;
681   }
682   case TargetOpcode::G_SITOFP:
683   case TargetOpcode::G_UITOFP: {
684     if (MRI.getType(MI.getOperand(0).getReg()).isVector())
685       break;
686     // Integer to FP conversions don't necessarily happen between GPR -> FPR
687     // regbanks. They can also be done within an FPR register.
688     Register SrcReg = MI.getOperand(1).getReg();
689     if (getRegBank(SrcReg, MRI, TRI) == &AArch64::FPRRegBank)
690       OpRegBankIdx = {PMI_FirstFPR, PMI_FirstFPR};
691     else
692       OpRegBankIdx = {PMI_FirstFPR, PMI_FirstGPR};
693     break;
694   }
695   case TargetOpcode::G_FPTOSI:
696   case TargetOpcode::G_FPTOUI:
697     if (MRI.getType(MI.getOperand(0).getReg()).isVector())
698       break;
699     OpRegBankIdx = {PMI_FirstGPR, PMI_FirstFPR};
700     break;
701   case TargetOpcode::G_FCMP:
702     OpRegBankIdx = {PMI_FirstGPR,
703                     /* Predicate */ PMI_None, PMI_FirstFPR, PMI_FirstFPR};
704     break;
705   case TargetOpcode::G_BITCAST:
706     // This is going to be a cross register bank copy and this is expensive.
707     if (OpRegBankIdx[0] != OpRegBankIdx[1])
708       Cost = copyCost(
709           *AArch64GenRegisterBankInfo::PartMappings[OpRegBankIdx[0]].RegBank,
710           *AArch64GenRegisterBankInfo::PartMappings[OpRegBankIdx[1]].RegBank,
711           OpSize[0]);
712     break;
713   case TargetOpcode::G_LOAD:
714     // Loading in vector unit is slightly more expensive.
715     // This is actually only true for the LD1R and co instructions,
716     // but anyway for the fast mode this number does not matter and
717     // for the greedy mode the cost of the cross bank copy will
718     // offset this number.
719     // FIXME: Should be derived from the scheduling model.
720     if (OpRegBankIdx[0] != PMI_FirstGPR)
721       Cost = 2;
722     else
723       // Check if that load feeds fp instructions.
724       // In that case, we want the default mapping to be on FPR
725       // instead of blind map every scalar to GPR.
726       for (const MachineInstr &UseMI :
727            MRI.use_nodbg_instructions(MI.getOperand(0).getReg())) {
728         // If we have at least one direct use in a FP instruction,
729         // assume this was a floating point load in the IR.
730         // If it was not, we would have had a bitcast before
731         // reaching that instruction.
732         // Int->FP conversion operations are also captured in onlyDefinesFP().
733         if (onlyUsesFP(UseMI, MRI, TRI) || onlyDefinesFP(UseMI, MRI, TRI)) {
734           OpRegBankIdx[0] = PMI_FirstFPR;
735           break;
736         }
737       }
738     break;
739   case TargetOpcode::G_STORE:
740     // Check if that store is fed by fp instructions.
741     if (OpRegBankIdx[0] == PMI_FirstGPR) {
742       Register VReg = MI.getOperand(0).getReg();
743       if (!VReg)
744         break;
745       MachineInstr *DefMI = MRI.getVRegDef(VReg);
746       if (onlyDefinesFP(*DefMI, MRI, TRI))
747         OpRegBankIdx[0] = PMI_FirstFPR;
748       break;
749     }
750     break;
751   case TargetOpcode::G_SELECT: {
752     // If the destination is FPR, preserve that.
753     if (OpRegBankIdx[0] != PMI_FirstGPR)
754       break;
755 
756     // If we're taking in vectors, we have no choice but to put everything on
757     // FPRs, except for the condition. The condition must always be on a GPR.
758     LLT SrcTy = MRI.getType(MI.getOperand(2).getReg());
759     if (SrcTy.isVector()) {
760       OpRegBankIdx = {PMI_FirstFPR, PMI_FirstGPR, PMI_FirstFPR, PMI_FirstFPR};
761       break;
762     }
763 
764     // Try to minimize the number of copies. If we have more floating point
765     // constrained values than not, then we'll put everything on FPR. Otherwise,
766     // everything has to be on GPR.
767     unsigned NumFP = 0;
768 
769     // Check if the uses of the result always produce floating point values.
770     //
771     // For example:
772     //
773     // %z = G_SELECT %cond %x %y
774     // fpr = G_FOO %z ...
775     if (any_of(MRI.use_nodbg_instructions(MI.getOperand(0).getReg()),
776                [&](MachineInstr &MI) { return onlyUsesFP(MI, MRI, TRI); }))
777       ++NumFP;
778 
779     // Check if the defs of the source values always produce floating point
780     // values.
781     //
782     // For example:
783     //
784     // %x = G_SOMETHING_ALWAYS_FLOAT %a ...
785     // %z = G_SELECT %cond %x %y
786     //
787     // Also check whether or not the sources have already been decided to be
788     // FPR. Keep track of this.
789     //
790     // This doesn't check the condition, since it's just whatever is in NZCV.
791     // This isn't passed explicitly in a register to fcsel/csel.
792     for (unsigned Idx = 2; Idx < 4; ++Idx) {
793       Register VReg = MI.getOperand(Idx).getReg();
794       MachineInstr *DefMI = MRI.getVRegDef(VReg);
795       if (getRegBank(VReg, MRI, TRI) == &AArch64::FPRRegBank ||
796           onlyDefinesFP(*DefMI, MRI, TRI))
797         ++NumFP;
798     }
799 
800     // If we have more FP constraints than not, then move everything over to
801     // FPR.
802     if (NumFP >= 2)
803       OpRegBankIdx = {PMI_FirstFPR, PMI_FirstGPR, PMI_FirstFPR, PMI_FirstFPR};
804 
805     break;
806   }
807   case TargetOpcode::G_UNMERGE_VALUES: {
808     // If the first operand belongs to a FPR register bank, then make sure that
809     // we preserve that.
810     if (OpRegBankIdx[0] != PMI_FirstGPR)
811       break;
812 
813     LLT SrcTy = MRI.getType(MI.getOperand(MI.getNumOperands()-1).getReg());
814     // UNMERGE into scalars from a vector should always use FPR.
815     // Likewise if any of the uses are FP instructions.
816     if (SrcTy.isVector() || SrcTy == LLT::scalar(128) ||
817         any_of(MRI.use_nodbg_instructions(MI.getOperand(0).getReg()),
818                [&](MachineInstr &MI) { return onlyUsesFP(MI, MRI, TRI); })) {
819       // Set the register bank of every operand to FPR.
820       for (unsigned Idx = 0, NumOperands = MI.getNumOperands();
821            Idx < NumOperands; ++Idx)
822         OpRegBankIdx[Idx] = PMI_FirstFPR;
823     }
824     break;
825   }
826   case TargetOpcode::G_EXTRACT_VECTOR_ELT:
827     // Destination and source need to be FPRs.
828     OpRegBankIdx[0] = PMI_FirstFPR;
829     OpRegBankIdx[1] = PMI_FirstFPR;
830 
831     // Index needs to be a GPR.
832     OpRegBankIdx[2] = PMI_FirstGPR;
833     break;
834   case TargetOpcode::G_INSERT_VECTOR_ELT:
835     OpRegBankIdx[0] = PMI_FirstFPR;
836     OpRegBankIdx[1] = PMI_FirstFPR;
837 
838     // The element may be either a GPR or FPR. Preserve that behaviour.
839     if (getRegBank(MI.getOperand(2).getReg(), MRI, TRI) == &AArch64::FPRRegBank)
840       OpRegBankIdx[2] = PMI_FirstFPR;
841     else
842       OpRegBankIdx[2] = PMI_FirstGPR;
843 
844     // Index needs to be a GPR.
845     OpRegBankIdx[3] = PMI_FirstGPR;
846     break;
847   case TargetOpcode::G_EXTRACT: {
848     // For s128 sources we have to use fpr.
849     LLT SrcTy = MRI.getType(MI.getOperand(1).getReg());
850     if (SrcTy.getSizeInBits() == 128) {
851       OpRegBankIdx[0] = PMI_FirstFPR;
852       OpRegBankIdx[1] = PMI_FirstFPR;
853     }
854     break;
855   }
856   case TargetOpcode::G_BUILD_VECTOR: {
857     // If the first source operand belongs to a FPR register bank, then make
858     // sure that we preserve that.
859     if (OpRegBankIdx[1] != PMI_FirstGPR)
860       break;
861     Register VReg = MI.getOperand(1).getReg();
862     if (!VReg)
863       break;
864 
865     // Get the instruction that defined the source operand reg, and check if
866     // it's a floating point operation. Or, if it's a type like s16 which
867     // doesn't have a exact size gpr register class. The exception is if the
868     // build_vector has all constant operands, which may be better to leave as
869     // gpr without copies, so it can be matched in imported patterns.
870     MachineInstr *DefMI = MRI.getVRegDef(VReg);
871     unsigned DefOpc = DefMI->getOpcode();
872     const LLT SrcTy = MRI.getType(VReg);
873     if (all_of(MI.operands(), [&](const MachineOperand &Op) {
874           return Op.isDef() || MRI.getVRegDef(Op.getReg())->getOpcode() ==
875                                    TargetOpcode::G_CONSTANT;
876         }))
877       break;
878     if (isPreISelGenericFloatingPointOpcode(DefOpc) ||
879         SrcTy.getSizeInBits() < 32) {
880       // Have a floating point op.
881       // Make sure every operand gets mapped to a FPR register class.
882       unsigned NumOperands = MI.getNumOperands();
883       for (unsigned Idx = 0; Idx < NumOperands; ++Idx)
884         OpRegBankIdx[Idx] = PMI_FirstFPR;
885     }
886     break;
887   }
888   case TargetOpcode::G_VECREDUCE_FADD:
889   case TargetOpcode::G_VECREDUCE_FMUL:
890   case TargetOpcode::G_VECREDUCE_FMAX:
891   case TargetOpcode::G_VECREDUCE_FMIN:
892   case TargetOpcode::G_VECREDUCE_ADD:
893   case TargetOpcode::G_VECREDUCE_MUL:
894   case TargetOpcode::G_VECREDUCE_AND:
895   case TargetOpcode::G_VECREDUCE_OR:
896   case TargetOpcode::G_VECREDUCE_XOR:
897   case TargetOpcode::G_VECREDUCE_SMAX:
898   case TargetOpcode::G_VECREDUCE_SMIN:
899   case TargetOpcode::G_VECREDUCE_UMAX:
900   case TargetOpcode::G_VECREDUCE_UMIN:
901     // Reductions produce a scalar value from a vector, the scalar should be on
902     // FPR bank.
903     OpRegBankIdx = {PMI_FirstFPR, PMI_FirstFPR};
904     break;
905   case TargetOpcode::G_VECREDUCE_SEQ_FADD:
906   case TargetOpcode::G_VECREDUCE_SEQ_FMUL:
907     // These reductions also take a scalar accumulator input.
908     // Assign them FPR for now.
909     OpRegBankIdx = {PMI_FirstFPR, PMI_FirstFPR, PMI_FirstFPR};
910     break;
911   }
912 
913   // Finally construct the computed mapping.
914   SmallVector<const ValueMapping *, 8> OpdsMapping(NumOperands);
915   for (unsigned Idx = 0; Idx < NumOperands; ++Idx) {
916     if (MI.getOperand(Idx).isReg() && MI.getOperand(Idx).getReg()) {
917       auto Mapping = getValueMapping(OpRegBankIdx[Idx], OpSize[Idx]);
918       if (!Mapping->isValid())
919         return getInvalidInstructionMapping();
920 
921       OpdsMapping[Idx] = Mapping;
922     }
923   }
924 
925   return getInstructionMapping(DefaultMappingID, Cost,
926                                getOperandsMapping(OpdsMapping), NumOperands);
927 }
928