1 //===-- RISCVInstrInfo.cpp - RISCV Instruction Information ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the RISCV implementation of the TargetInstrInfo class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "RISCVInstrInfo.h"
14 #include "MCTargetDesc/RISCVMatInt.h"
15 #include "RISCV.h"
16 #include "RISCVMachineFunctionInfo.h"
17 #include "RISCVSubtarget.h"
18 #include "RISCVTargetMachine.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Analysis/MemoryLocation.h"
22 #include "llvm/CodeGen/LiveVariables.h"
23 #include "llvm/CodeGen/MachineFunctionPass.h"
24 #include "llvm/CodeGen/MachineInstrBuilder.h"
25 #include "llvm/CodeGen/MachineRegisterInfo.h"
26 #include "llvm/CodeGen/RegisterScavenging.h"
27 #include "llvm/MC/MCInstBuilder.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include "llvm/Support/TargetRegistry.h"
30 
31 using namespace llvm;
32 
33 #define GEN_CHECK_COMPRESS_INSTR
34 #include "RISCVGenCompressInstEmitter.inc"
35 
36 #define GET_INSTRINFO_CTOR_DTOR
37 #include "RISCVGenInstrInfo.inc"
38 
39 namespace llvm {
40 namespace RISCVVPseudosTable {
41 
42 using namespace RISCV;
43 
44 #define GET_RISCVVPseudosTable_IMPL
45 #include "RISCVGenSearchableTables.inc"
46 
47 } // namespace RISCVVPseudosTable
48 } // namespace llvm
49 
50 RISCVInstrInfo::RISCVInstrInfo(RISCVSubtarget &STI)
51     : RISCVGenInstrInfo(RISCV::ADJCALLSTACKDOWN, RISCV::ADJCALLSTACKUP),
52       STI(STI) {}
53 
54 MCInst RISCVInstrInfo::getNop() const {
55   if (STI.getFeatureBits()[RISCV::FeatureStdExtC])
56     return MCInstBuilder(RISCV::C_NOP);
57   return MCInstBuilder(RISCV::ADDI)
58       .addReg(RISCV::X0)
59       .addReg(RISCV::X0)
60       .addImm(0);
61 }
62 
63 unsigned RISCVInstrInfo::isLoadFromStackSlot(const MachineInstr &MI,
64                                              int &FrameIndex) const {
65   switch (MI.getOpcode()) {
66   default:
67     return 0;
68   case RISCV::LB:
69   case RISCV::LBU:
70   case RISCV::LH:
71   case RISCV::LHU:
72   case RISCV::FLH:
73   case RISCV::LW:
74   case RISCV::FLW:
75   case RISCV::LWU:
76   case RISCV::LD:
77   case RISCV::FLD:
78     break;
79   }
80 
81   if (MI.getOperand(1).isFI() && MI.getOperand(2).isImm() &&
82       MI.getOperand(2).getImm() == 0) {
83     FrameIndex = MI.getOperand(1).getIndex();
84     return MI.getOperand(0).getReg();
85   }
86 
87   return 0;
88 }
89 
90 unsigned RISCVInstrInfo::isStoreToStackSlot(const MachineInstr &MI,
91                                             int &FrameIndex) const {
92   switch (MI.getOpcode()) {
93   default:
94     return 0;
95   case RISCV::SB:
96   case RISCV::SH:
97   case RISCV::SW:
98   case RISCV::FSH:
99   case RISCV::FSW:
100   case RISCV::SD:
101   case RISCV::FSD:
102     break;
103   }
104 
105   if (MI.getOperand(1).isFI() && MI.getOperand(2).isImm() &&
106       MI.getOperand(2).getImm() == 0) {
107     FrameIndex = MI.getOperand(1).getIndex();
108     return MI.getOperand(0).getReg();
109   }
110 
111   return 0;
112 }
113 
114 static bool forwardCopyWillClobberTuple(unsigned DstReg, unsigned SrcReg,
115                                         unsigned NumRegs) {
116   // We really want the positive remainder mod 32 here, which happens to be
117   // easily obtainable with a mask.
118   return ((DstReg - SrcReg) & 0x1f) < NumRegs;
119 }
120 
121 void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
122                                  MachineBasicBlock::iterator MBBI,
123                                  const DebugLoc &DL, MCRegister DstReg,
124                                  MCRegister SrcReg, bool KillSrc) const {
125   if (RISCV::GPRRegClass.contains(DstReg, SrcReg)) {
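    // A GPR-to-GPR copy is emitted as the canonical move "addi DstReg, SrcReg, 0".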
126     BuildMI(MBB, MBBI, DL, get(RISCV::ADDI), DstReg)
127         .addReg(SrcReg, getKillRegState(KillSrc))
128         .addImm(0);
129     return;
130   }
131 
132   // FPR->FPR copies and VR->VR copies.
133   unsigned Opc;
134   bool IsScalableVector = true;
135   unsigned NF = 1;
136   unsigned LMul = 1;
137   unsigned SubRegIdx = RISCV::sub_vrm1_0;
138   if (RISCV::FPR16RegClass.contains(DstReg, SrcReg)) {
139     Opc = RISCV::FSGNJ_H;
140     IsScalableVector = false;
141   } else if (RISCV::FPR32RegClass.contains(DstReg, SrcReg)) {
142     Opc = RISCV::FSGNJ_S;
143     IsScalableVector = false;
144   } else if (RISCV::FPR64RegClass.contains(DstReg, SrcReg)) {
145     Opc = RISCV::FSGNJ_D;
146     IsScalableVector = false;
147   } else if (RISCV::VRRegClass.contains(DstReg, SrcReg)) {
148     Opc = RISCV::PseudoVMV1R_V;
149   } else if (RISCV::VRM2RegClass.contains(DstReg, SrcReg)) {
150     Opc = RISCV::PseudoVMV2R_V;
151   } else if (RISCV::VRM4RegClass.contains(DstReg, SrcReg)) {
152     Opc = RISCV::PseudoVMV4R_V;
153   } else if (RISCV::VRM8RegClass.contains(DstReg, SrcReg)) {
154     Opc = RISCV::PseudoVMV8R_V;
155   } else if (RISCV::VRN2M1RegClass.contains(DstReg, SrcReg)) {
156     Opc = RISCV::PseudoVMV1R_V;
157     SubRegIdx = RISCV::sub_vrm1_0;
158     NF = 2;
159     LMul = 1;
160   } else if (RISCV::VRN2M2RegClass.contains(DstReg, SrcReg)) {
161     Opc = RISCV::PseudoVMV2R_V;
162     SubRegIdx = RISCV::sub_vrm2_0;
163     NF = 2;
164     LMul = 2;
165   } else if (RISCV::VRN2M4RegClass.contains(DstReg, SrcReg)) {
166     Opc = RISCV::PseudoVMV4R_V;
167     SubRegIdx = RISCV::sub_vrm4_0;
168     NF = 2;
169     LMul = 4;
170   } else if (RISCV::VRN3M1RegClass.contains(DstReg, SrcReg)) {
171     Opc = RISCV::PseudoVMV1R_V;
172     SubRegIdx = RISCV::sub_vrm1_0;
173     NF = 3;
174     LMul = 1;
175   } else if (RISCV::VRN3M2RegClass.contains(DstReg, SrcReg)) {
176     Opc = RISCV::PseudoVMV2R_V;
177     SubRegIdx = RISCV::sub_vrm2_0;
178     NF = 3;
179     LMul = 2;
180   } else if (RISCV::VRN4M1RegClass.contains(DstReg, SrcReg)) {
181     Opc = RISCV::PseudoVMV1R_V;
182     SubRegIdx = RISCV::sub_vrm1_0;
183     NF = 4;
184     LMul = 1;
185   } else if (RISCV::VRN4M2RegClass.contains(DstReg, SrcReg)) {
186     Opc = RISCV::PseudoVMV2R_V;
187     SubRegIdx = RISCV::sub_vrm2_0;
188     NF = 4;
189     LMul = 2;
190   } else if (RISCV::VRN5M1RegClass.contains(DstReg, SrcReg)) {
191     Opc = RISCV::PseudoVMV1R_V;
192     SubRegIdx = RISCV::sub_vrm1_0;
193     NF = 5;
194     LMul = 1;
195   } else if (RISCV::VRN6M1RegClass.contains(DstReg, SrcReg)) {
196     Opc = RISCV::PseudoVMV1R_V;
197     SubRegIdx = RISCV::sub_vrm1_0;
198     NF = 6;
199     LMul = 1;
200   } else if (RISCV::VRN7M1RegClass.contains(DstReg, SrcReg)) {
201     Opc = RISCV::PseudoVMV1R_V;
202     SubRegIdx = RISCV::sub_vrm1_0;
203     NF = 7;
204     LMul = 1;
205   } else if (RISCV::VRN8M1RegClass.contains(DstReg, SrcReg)) {
206     Opc = RISCV::PseudoVMV1R_V;
207     SubRegIdx = RISCV::sub_vrm1_0;
208     NF = 8;
209     LMul = 1;
210   } else {
211     llvm_unreachable("Impossible reg-to-reg copy");
212   }
213 
214   if (IsScalableVector) {
215     if (NF == 1) {
216       BuildMI(MBB, MBBI, DL, get(Opc), DstReg)
217           .addReg(SrcReg, getKillRegState(KillSrc));
218     } else {
219       const TargetRegisterInfo *TRI = STI.getRegisterInfo();
220 
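      // Copy the tuple one subregister at a time. If the destination overlaps
      // a not-yet-copied part of the source tuple, walk the fields in reverse
      // so no source register is clobbered before it has been read.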
221       int I = 0, End = NF, Incr = 1;
222       unsigned SrcEncoding = TRI->getEncodingValue(SrcReg);
223       unsigned DstEncoding = TRI->getEncodingValue(DstReg);
224       if (forwardCopyWillClobberTuple(DstEncoding, SrcEncoding, NF * LMul)) {
225         I = NF - 1;
226         End = -1;
227         Incr = -1;
228       }
229 
230       for (; I != End; I += Incr) {
231         BuildMI(MBB, MBBI, DL, get(Opc), TRI->getSubReg(DstReg, SubRegIdx + I))
232             .addReg(TRI->getSubReg(SrcReg, SubRegIdx + I),
233                     getKillRegState(KillSrc));
234       }
235     }
236   } else {
237     BuildMI(MBB, MBBI, DL, get(Opc), DstReg)
238         .addReg(SrcReg, getKillRegState(KillSrc))
239         .addReg(SrcReg, getKillRegState(KillSrc));
240   }
241 }
242 
243 void RISCVInstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
244                                          MachineBasicBlock::iterator I,
245                                          Register SrcReg, bool IsKill, int FI,
246                                          const TargetRegisterClass *RC,
247                                          const TargetRegisterInfo *TRI) const {
248   DebugLoc DL;
249   if (I != MBB.end())
250     DL = I->getDebugLoc();
251 
252   MachineFunction *MF = MBB.getParent();
253   MachineFrameInfo &MFI = MF->getFrameInfo();
254 
255   unsigned Opcode;
256   bool IsScalableVector = true;
257   bool IsZvlsseg = true;
258   if (RISCV::GPRRegClass.hasSubClassEq(RC)) {
259     Opcode = TRI->getRegSizeInBits(RISCV::GPRRegClass) == 32 ?
260              RISCV::SW : RISCV::SD;
261     IsScalableVector = false;
262   } else if (RISCV::FPR16RegClass.hasSubClassEq(RC)) {
263     Opcode = RISCV::FSH;
264     IsScalableVector = false;
265   } else if (RISCV::FPR32RegClass.hasSubClassEq(RC)) {
266     Opcode = RISCV::FSW;
267     IsScalableVector = false;
268   } else if (RISCV::FPR64RegClass.hasSubClassEq(RC)) {
269     Opcode = RISCV::FSD;
270     IsScalableVector = false;
271   } else if (RISCV::VRRegClass.hasSubClassEq(RC)) {
272     Opcode = RISCV::PseudoVSPILL_M1;
273     IsZvlsseg = false;
274   } else if (RISCV::VRM2RegClass.hasSubClassEq(RC)) {
275     Opcode = RISCV::PseudoVSPILL_M2;
276     IsZvlsseg = false;
277   } else if (RISCV::VRM4RegClass.hasSubClassEq(RC)) {
278     Opcode = RISCV::PseudoVSPILL_M4;
279     IsZvlsseg = false;
280   } else if (RISCV::VRM8RegClass.hasSubClassEq(RC)) {
281     Opcode = RISCV::PseudoVSPILL_M8;
282     IsZvlsseg = false;
283   } else if (RISCV::VRN2M1RegClass.hasSubClassEq(RC))
284     Opcode = RISCV::PseudoVSPILL2_M1;
285   else if (RISCV::VRN2M2RegClass.hasSubClassEq(RC))
286     Opcode = RISCV::PseudoVSPILL2_M2;
287   else if (RISCV::VRN2M4RegClass.hasSubClassEq(RC))
288     Opcode = RISCV::PseudoVSPILL2_M4;
289   else if (RISCV::VRN3M1RegClass.hasSubClassEq(RC))
290     Opcode = RISCV::PseudoVSPILL3_M1;
291   else if (RISCV::VRN3M2RegClass.hasSubClassEq(RC))
292     Opcode = RISCV::PseudoVSPILL3_M2;
293   else if (RISCV::VRN4M1RegClass.hasSubClassEq(RC))
294     Opcode = RISCV::PseudoVSPILL4_M1;
295   else if (RISCV::VRN4M2RegClass.hasSubClassEq(RC))
296     Opcode = RISCV::PseudoVSPILL4_M2;
297   else if (RISCV::VRN5M1RegClass.hasSubClassEq(RC))
298     Opcode = RISCV::PseudoVSPILL5_M1;
299   else if (RISCV::VRN6M1RegClass.hasSubClassEq(RC))
300     Opcode = RISCV::PseudoVSPILL6_M1;
301   else if (RISCV::VRN7M1RegClass.hasSubClassEq(RC))
302     Opcode = RISCV::PseudoVSPILL7_M1;
303   else if (RISCV::VRN8M1RegClass.hasSubClassEq(RC))
304     Opcode = RISCV::PseudoVSPILL8_M1;
305   else
306     llvm_unreachable("Can't store this register to stack slot");
307 
308   if (IsScalableVector) {
309     MachineMemOperand *MMO = MF->getMachineMemOperand(
310         MachinePointerInfo::getFixedStack(*MF, FI), MachineMemOperand::MOStore,
311         MemoryLocation::UnknownSize, MFI.getObjectAlign(FI));
312 
313     MFI.setStackID(FI, TargetStackID::ScalableVector);
314     auto MIB = BuildMI(MBB, I, DL, get(Opcode))
315                    .addReg(SrcReg, getKillRegState(IsKill))
316                    .addFrameIndex(FI)
317                    .addMemOperand(MMO);
318     if (IsZvlsseg) {
319       // For spilling/reloading Zvlsseg registers, append the dummy field for
320       // the scaled vector length. The argument will be used when expanding
321       // these pseudo instructions.
322       MIB.addReg(RISCV::X0);
323     }
324   } else {
325     MachineMemOperand *MMO = MF->getMachineMemOperand(
326         MachinePointerInfo::getFixedStack(*MF, FI), MachineMemOperand::MOStore,
327         MFI.getObjectSize(FI), MFI.getObjectAlign(FI));
328 
329     BuildMI(MBB, I, DL, get(Opcode))
330         .addReg(SrcReg, getKillRegState(IsKill))
331         .addFrameIndex(FI)
332         .addImm(0)
333         .addMemOperand(MMO);
334   }
335 }
336 
337 void RISCVInstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
338                                           MachineBasicBlock::iterator I,
339                                           Register DstReg, int FI,
340                                           const TargetRegisterClass *RC,
341                                           const TargetRegisterInfo *TRI) const {
342   DebugLoc DL;
343   if (I != MBB.end())
344     DL = I->getDebugLoc();
345 
346   MachineFunction *MF = MBB.getParent();
347   MachineFrameInfo &MFI = MF->getFrameInfo();
348 
349   unsigned Opcode;
350   bool IsScalableVector = true;
351   bool IsZvlsseg = true;
352   if (RISCV::GPRRegClass.hasSubClassEq(RC)) {
353     Opcode = TRI->getRegSizeInBits(RISCV::GPRRegClass) == 32 ?
354              RISCV::LW : RISCV::LD;
355     IsScalableVector = false;
356   } else if (RISCV::FPR16RegClass.hasSubClassEq(RC)) {
357     Opcode = RISCV::FLH;
358     IsScalableVector = false;
359   } else if (RISCV::FPR32RegClass.hasSubClassEq(RC)) {
360     Opcode = RISCV::FLW;
361     IsScalableVector = false;
362   } else if (RISCV::FPR64RegClass.hasSubClassEq(RC)) {
363     Opcode = RISCV::FLD;
364     IsScalableVector = false;
365   } else if (RISCV::VRRegClass.hasSubClassEq(RC)) {
366     Opcode = RISCV::PseudoVRELOAD_M1;
367     IsZvlsseg = false;
368   } else if (RISCV::VRM2RegClass.hasSubClassEq(RC)) {
369     Opcode = RISCV::PseudoVRELOAD_M2;
370     IsZvlsseg = false;
371   } else if (RISCV::VRM4RegClass.hasSubClassEq(RC)) {
372     Opcode = RISCV::PseudoVRELOAD_M4;
373     IsZvlsseg = false;
374   } else if (RISCV::VRM8RegClass.hasSubClassEq(RC)) {
375     Opcode = RISCV::PseudoVRELOAD_M8;
376     IsZvlsseg = false;
377   } else if (RISCV::VRN2M1RegClass.hasSubClassEq(RC))
378     Opcode = RISCV::PseudoVRELOAD2_M1;
379   else if (RISCV::VRN2M2RegClass.hasSubClassEq(RC))
380     Opcode = RISCV::PseudoVRELOAD2_M2;
381   else if (RISCV::VRN2M4RegClass.hasSubClassEq(RC))
382     Opcode = RISCV::PseudoVRELOAD2_M4;
383   else if (RISCV::VRN3M1RegClass.hasSubClassEq(RC))
384     Opcode = RISCV::PseudoVRELOAD3_M1;
385   else if (RISCV::VRN3M2RegClass.hasSubClassEq(RC))
386     Opcode = RISCV::PseudoVRELOAD3_M2;
387   else if (RISCV::VRN4M1RegClass.hasSubClassEq(RC))
388     Opcode = RISCV::PseudoVRELOAD4_M1;
389   else if (RISCV::VRN4M2RegClass.hasSubClassEq(RC))
390     Opcode = RISCV::PseudoVRELOAD4_M2;
391   else if (RISCV::VRN5M1RegClass.hasSubClassEq(RC))
392     Opcode = RISCV::PseudoVRELOAD5_M1;
393   else if (RISCV::VRN6M1RegClass.hasSubClassEq(RC))
394     Opcode = RISCV::PseudoVRELOAD6_M1;
395   else if (RISCV::VRN7M1RegClass.hasSubClassEq(RC))
396     Opcode = RISCV::PseudoVRELOAD7_M1;
397   else if (RISCV::VRN8M1RegClass.hasSubClassEq(RC))
398     Opcode = RISCV::PseudoVRELOAD8_M1;
399   else
400     llvm_unreachable("Can't load this register from stack slot");
401 
402   if (IsScalableVector) {
403     MachineMemOperand *MMO = MF->getMachineMemOperand(
404         MachinePointerInfo::getFixedStack(*MF, FI), MachineMemOperand::MOLoad,
405         MemoryLocation::UnknownSize, MFI.getObjectAlign(FI));
406 
407     MFI.setStackID(FI, TargetStackID::ScalableVector);
408     auto MIB = BuildMI(MBB, I, DL, get(Opcode), DstReg)
409                    .addFrameIndex(FI)
410                    .addMemOperand(MMO);
411     if (IsZvlsseg) {
412       // For spilling/reloading Zvlsseg registers, append the dummy field for
413       // the scaled vector length. The argument will be used when expanding
414       // these pseudo instructions.
415       MIB.addReg(RISCV::X0);
416     }
417   } else {
418     MachineMemOperand *MMO = MF->getMachineMemOperand(
419         MachinePointerInfo::getFixedStack(*MF, FI), MachineMemOperand::MOLoad,
420         MFI.getObjectSize(FI), MFI.getObjectAlign(FI));
421 
422     BuildMI(MBB, I, DL, get(Opcode), DstReg)
423         .addFrameIndex(FI)
424         .addImm(0)
425         .addMemOperand(MMO);
426   }
427 }
428 
429 void RISCVInstrInfo::movImm(MachineBasicBlock &MBB,
430                             MachineBasicBlock::iterator MBBI,
431                             const DebugLoc &DL, Register DstReg, uint64_t Val,
432                             MachineInstr::MIFlag Flag) const {
433   MachineFunction *MF = MBB.getParent();
434   MachineRegisterInfo &MRI = MF->getRegInfo();
435   Register SrcReg = RISCV::X0;
436   Register Result = MRI.createVirtualRegister(&RISCV::GPRRegClass);
437   unsigned Num = 0;
438 
439   if (!STI.is64Bit() && !isInt<32>(Val))
440     report_fatal_error("Should only materialize 32-bit constants for RV32");
441 
442   RISCVMatInt::InstSeq Seq =
443       RISCVMatInt::generateInstSeq(Val, STI.getFeatureBits());
444   assert(!Seq.empty());
445 
446   for (RISCVMatInt::Inst &Inst : Seq) {
447     // Write the final result to DstReg if it's the last instruction in the Seq.
448     // Otherwise, write the result to the temp register.
449     if (++Num == Seq.size())
450       Result = DstReg;
451 
452     if (Inst.Opc == RISCV::LUI) {
453       BuildMI(MBB, MBBI, DL, get(RISCV::LUI), Result)
454           .addImm(Inst.Imm)
455           .setMIFlag(Flag);
456     } else if (Inst.Opc == RISCV::ADDUW) {
457       BuildMI(MBB, MBBI, DL, get(RISCV::ADDUW), Result)
458           .addReg(SrcReg, RegState::Kill)
459           .addReg(RISCV::X0)
460           .setMIFlag(Flag);
461     } else {
462       BuildMI(MBB, MBBI, DL, get(Inst.Opc), Result)
463           .addReg(SrcReg, RegState::Kill)
464           .addImm(Inst.Imm)
465           .setMIFlag(Flag);
466     }
467     // Only the first instruction has X0 as its source.
468     SrcReg = Result;
469   }
470 }
471 
472 // The contents of values added to Cond are not examined outside of
473 // RISCVInstrInfo, giving us flexibility in what to push to it. For RISCV, we
474 // push BranchOpcode, Reg1, Reg2.
475 static void parseCondBranch(MachineInstr &LastInst, MachineBasicBlock *&Target,
476                             SmallVectorImpl<MachineOperand> &Cond) {
477   // Block ends with fall-through condbranch.
478   assert(LastInst.getDesc().isConditionalBranch() &&
479          "Unknown conditional branch");
480   Target = LastInst.getOperand(2).getMBB();
481   Cond.push_back(MachineOperand::CreateImm(LastInst.getOpcode()));
482   Cond.push_back(LastInst.getOperand(0));
483   Cond.push_back(LastInst.getOperand(1));
484 }
485 
486 static unsigned getOppositeBranchOpcode(int Opc) {
487   switch (Opc) {
488   default:
489     llvm_unreachable("Unrecognized conditional branch");
490   case RISCV::BEQ:
491     return RISCV::BNE;
492   case RISCV::BNE:
493     return RISCV::BEQ;
494   case RISCV::BLT:
495     return RISCV::BGE;
496   case RISCV::BGE:
497     return RISCV::BLT;
498   case RISCV::BLTU:
499     return RISCV::BGEU;
500   case RISCV::BGEU:
501     return RISCV::BLTU;
502   }
503 }
504 
505 bool RISCVInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
506                                    MachineBasicBlock *&TBB,
507                                    MachineBasicBlock *&FBB,
508                                    SmallVectorImpl<MachineOperand> &Cond,
509                                    bool AllowModify) const {
510   TBB = FBB = nullptr;
511   Cond.clear();
512 
513   // If the block has no terminators, it just falls into the block after it.
514   MachineBasicBlock::iterator I = MBB.getLastNonDebugInstr();
515   if (I == MBB.end() || !isUnpredicatedTerminator(*I))
516     return false;
517 
518   // Count the number of terminators and find the first unconditional or
519   // indirect branch.
520   MachineBasicBlock::iterator FirstUncondOrIndirectBr = MBB.end();
521   int NumTerminators = 0;
522   for (auto J = I.getReverse(); J != MBB.rend() && isUnpredicatedTerminator(*J);
523        J++) {
524     NumTerminators++;
525     if (J->getDesc().isUnconditionalBranch() ||
526         J->getDesc().isIndirectBranch()) {
527       FirstUncondOrIndirectBr = J.getReverse();
528     }
529   }
530 
531   // If AllowModify is true, we can erase any terminators after
532   // FirstUncondOrIndirectBr.
533   if (AllowModify && FirstUncondOrIndirectBr != MBB.end()) {
534     while (std::next(FirstUncondOrIndirectBr) != MBB.end()) {
535       std::next(FirstUncondOrIndirectBr)->eraseFromParent();
536       NumTerminators--;
537     }
538     I = FirstUncondOrIndirectBr;
539   }
540 
541   // We can't handle blocks that end in an indirect branch.
542   if (I->getDesc().isIndirectBranch())
543     return true;
544 
545   // We can't handle blocks with more than 2 terminators.
546   if (NumTerminators > 2)
547     return true;
548 
549   // Handle a single unconditional branch.
550   if (NumTerminators == 1 && I->getDesc().isUnconditionalBranch()) {
551     TBB = getBranchDestBlock(*I);
552     return false;
553   }
554 
555   // Handle a single conditional branch.
556   if (NumTerminators == 1 && I->getDesc().isConditionalBranch()) {
557     parseCondBranch(*I, TBB, Cond);
558     return false;
559   }
560 
561   // Handle a conditional branch followed by an unconditional branch.
562   if (NumTerminators == 2 && std::prev(I)->getDesc().isConditionalBranch() &&
563       I->getDesc().isUnconditionalBranch()) {
564     parseCondBranch(*std::prev(I), TBB, Cond);
565     FBB = getBranchDestBlock(*I);
566     return false;
567   }
568 
569   // Otherwise, we can't handle this.
570   return true;
571 }
572 
573 unsigned RISCVInstrInfo::removeBranch(MachineBasicBlock &MBB,
574                                       int *BytesRemoved) const {
575   if (BytesRemoved)
576     *BytesRemoved = 0;
577   MachineBasicBlock::iterator I = MBB.getLastNonDebugInstr();
578   if (I == MBB.end())
579     return 0;
580 
581   if (!I->getDesc().isUnconditionalBranch() &&
582       !I->getDesc().isConditionalBranch())
583     return 0;
584 
585   // Remove the branch.
586   if (BytesRemoved)
587     *BytesRemoved += getInstSizeInBytes(*I);
588   I->eraseFromParent();
589 
590   I = MBB.end();
591 
592   if (I == MBB.begin())
593     return 1;
594   --I;
595   if (!I->getDesc().isConditionalBranch())
596     return 1;
597 
598   // Remove the branch.
599   if (BytesRemoved)
600     *BytesRemoved += getInstSizeInBytes(*I);
601   I->eraseFromParent();
602   return 2;
603 }
604 
605 // Inserts a branch into the end of the given MachineBasicBlock, returning
606 // the number of instructions inserted.
607 unsigned RISCVInstrInfo::insertBranch(
608     MachineBasicBlock &MBB, MachineBasicBlock *TBB, MachineBasicBlock *FBB,
609     ArrayRef<MachineOperand> Cond, const DebugLoc &DL, int *BytesAdded) const {
610   if (BytesAdded)
611     *BytesAdded = 0;
612 
613   // Shouldn't be a fall through.
614   assert(TBB && "insertBranch must not be told to insert a fallthrough");
615   assert((Cond.size() == 3 || Cond.size() == 0) &&
616          "RISCV branch conditions have two components!");
617 
618   // Unconditional branch.
619   if (Cond.empty()) {
620     MachineInstr &MI = *BuildMI(&MBB, DL, get(RISCV::PseudoBR)).addMBB(TBB);
621     if (BytesAdded)
622       *BytesAdded += getInstSizeInBytes(MI);
623     return 1;
624   }
625 
626   // Either a one or two-way conditional branch.
627   unsigned Opc = Cond[0].getImm();
628   MachineInstr &CondMI =
629       *BuildMI(&MBB, DL, get(Opc)).add(Cond[1]).add(Cond[2]).addMBB(TBB);
630   if (BytesAdded)
631     *BytesAdded += getInstSizeInBytes(CondMI);
632 
633   // One-way conditional branch.
634   if (!FBB)
635     return 1;
636 
637   // Two-way conditional branch.
638   MachineInstr &MI = *BuildMI(&MBB, DL, get(RISCV::PseudoBR)).addMBB(FBB);
639   if (BytesAdded)
640     *BytesAdded += getInstSizeInBytes(MI);
641   return 2;
642 }
643 
644 unsigned RISCVInstrInfo::insertIndirectBranch(MachineBasicBlock &MBB,
645                                               MachineBasicBlock &DestBB,
646                                               const DebugLoc &DL,
647                                               int64_t BrOffset,
648                                               RegScavenger *RS) const {
649   assert(RS && "RegScavenger required for long branching");
650   assert(MBB.empty() &&
651          "new block should be inserted for expanding unconditional branch");
652   assert(MBB.pred_size() == 1);
653 
654   MachineFunction *MF = MBB.getParent();
655   MachineRegisterInfo &MRI = MF->getRegInfo();
656 
657   if (!isInt<32>(BrOffset))
658     report_fatal_error(
659         "Branch offsets outside of the signed 32-bit range not supported");
660 
661   // FIXME: A virtual register must be used initially, as the register
662   // scavenger won't work with empty blocks (SIInstrInfo::insertIndirectBranch
663   // uses the same workaround).
664   Register ScratchReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
665   auto II = MBB.end();
666 
667   MachineInstr &MI = *BuildMI(MBB, II, DL, get(RISCV::PseudoJump))
668                           .addReg(ScratchReg, RegState::Define | RegState::Dead)
669                           .addMBB(&DestBB, RISCVII::MO_CALL);
670 
671   RS->enterBasicBlockEnd(MBB);
672   unsigned Scav = RS->scavengeRegisterBackwards(RISCV::GPRRegClass,
673                                                 MI.getIterator(), false, 0);
674   MRI.replaceRegWith(ScratchReg, Scav);
675   MRI.clearVirtRegs();
676   RS->setRegUsed(Scav);
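  // PseudoJump is expanded to an AUIPC+JALR pair, so the inserted branch is
  // 8 bytes long.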
677   return 8;
678 }
679 
680 bool RISCVInstrInfo::reverseBranchCondition(
681     SmallVectorImpl<MachineOperand> &Cond) const {
682   assert((Cond.size() == 3) && "Invalid branch condition!");
683   Cond[0].setImm(getOppositeBranchOpcode(Cond[0].getImm()));
684   return false;
685 }
686 
687 MachineBasicBlock *
688 RISCVInstrInfo::getBranchDestBlock(const MachineInstr &MI) const {
689   assert(MI.getDesc().isBranch() && "Unexpected opcode!");
690   // The branch target is always the last operand.
691   int NumOp = MI.getNumExplicitOperands();
692   return MI.getOperand(NumOp - 1).getMBB();
693 }
694 
695 bool RISCVInstrInfo::isBranchOffsetInRange(unsigned BranchOp,
696                                            int64_t BrOffset) const {
697   unsigned XLen = STI.getXLen();
698   // Ideally we could determine the supported branch offset from the
699   // RISCVII::FormMask, but this can't be used for Pseudo instructions like
700   // PseudoBR.
701   switch (BranchOp) {
702   default:
703     llvm_unreachable("Unexpected opcode!");
704   case RISCV::BEQ:
705   case RISCV::BNE:
706   case RISCV::BLT:
707   case RISCV::BGE:
708   case RISCV::BLTU:
709   case RISCV::BGEU:
710     return isIntN(13, BrOffset);
711   case RISCV::JAL:
712   case RISCV::PseudoBR:
713     return isIntN(21, BrOffset);
714   case RISCV::PseudoJump:
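    // PseudoJump materializes the target with an AUIPC+JALR pair; the +0x800
    // accounts for the sign extension of the low 12 bits when the offset is
    // split into hi20 and lo12 parts.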
715     return isIntN(32, SignExtend64(BrOffset + 0x800, XLen));
716   }
717 }
718 
719 unsigned RISCVInstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
720   unsigned Opcode = MI.getOpcode();
721 
722   switch (Opcode) {
723   default: {
724     if (MI.getParent() && MI.getParent()->getParent()) {
725       const auto MF = MI.getMF();
726       const auto &TM = static_cast<const RISCVTargetMachine &>(MF->getTarget());
727       const MCRegisterInfo &MRI = *TM.getMCRegisterInfo();
728       const MCSubtargetInfo &STI = *TM.getMCSubtargetInfo();
729       const RISCVSubtarget &ST = MF->getSubtarget<RISCVSubtarget>();
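      // Instructions that can be compressed are expected to be emitted in the
      // 16-bit RVC encoding.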
730       if (isCompressibleInst(MI, &ST, MRI, STI))
731         return 2;
732     }
733     return get(Opcode).getSize();
734   }
735   case TargetOpcode::EH_LABEL:
736   case TargetOpcode::IMPLICIT_DEF:
737   case TargetOpcode::KILL:
738   case TargetOpcode::DBG_VALUE:
739     return 0;
740   // These values are determined based on RISCVExpandAtomicPseudoInsts,
741   // RISCVExpandPseudoInsts and RISCVMCCodeEmitter, depending on where the
742   // pseudos are expanded.
743   case RISCV::PseudoCALLReg:
744   case RISCV::PseudoCALL:
745   case RISCV::PseudoJump:
746   case RISCV::PseudoTAIL:
747   case RISCV::PseudoLLA:
748   case RISCV::PseudoLA:
749   case RISCV::PseudoLA_TLS_IE:
750   case RISCV::PseudoLA_TLS_GD:
751     return 8;
752   case RISCV::PseudoAtomicLoadNand32:
753   case RISCV::PseudoAtomicLoadNand64:
754     return 20;
755   case RISCV::PseudoMaskedAtomicSwap32:
756   case RISCV::PseudoMaskedAtomicLoadAdd32:
757   case RISCV::PseudoMaskedAtomicLoadSub32:
758     return 28;
759   case RISCV::PseudoMaskedAtomicLoadNand32:
760     return 32;
761   case RISCV::PseudoMaskedAtomicLoadMax32:
762   case RISCV::PseudoMaskedAtomicLoadMin32:
763     return 44;
764   case RISCV::PseudoMaskedAtomicLoadUMax32:
765   case RISCV::PseudoMaskedAtomicLoadUMin32:
766     return 36;
767   case RISCV::PseudoCmpXchg32:
768   case RISCV::PseudoCmpXchg64:
769     return 16;
770   case RISCV::PseudoMaskedCmpXchg32:
771     return 32;
772   case TargetOpcode::INLINEASM:
773   case TargetOpcode::INLINEASM_BR: {
774     const MachineFunction &MF = *MI.getParent()->getParent();
775     const auto &TM = static_cast<const RISCVTargetMachine &>(MF.getTarget());
776     return getInlineAsmLength(MI.getOperand(0).getSymbolName(),
777                               *TM.getMCAsmInfo());
778   }
779   case RISCV::PseudoVSPILL2_M1:
780   case RISCV::PseudoVSPILL2_M2:
781   case RISCV::PseudoVSPILL2_M4:
782   case RISCV::PseudoVSPILL3_M1:
783   case RISCV::PseudoVSPILL3_M2:
784   case RISCV::PseudoVSPILL4_M1:
785   case RISCV::PseudoVSPILL4_M2:
786   case RISCV::PseudoVSPILL5_M1:
787   case RISCV::PseudoVSPILL6_M1:
788   case RISCV::PseudoVSPILL7_M1:
789   case RISCV::PseudoVSPILL8_M1:
790   case RISCV::PseudoVRELOAD2_M1:
791   case RISCV::PseudoVRELOAD2_M2:
792   case RISCV::PseudoVRELOAD2_M4:
793   case RISCV::PseudoVRELOAD3_M1:
794   case RISCV::PseudoVRELOAD3_M2:
795   case RISCV::PseudoVRELOAD4_M1:
796   case RISCV::PseudoVRELOAD4_M2:
797   case RISCV::PseudoVRELOAD5_M1:
798   case RISCV::PseudoVRELOAD6_M1:
799   case RISCV::PseudoVRELOAD7_M1:
800   case RISCV::PseudoVRELOAD8_M1: {
801     // The values are determined based on expandVSPILL and expandVRELOAD that
802     // expand the pseudos depending on NF.
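    // Each field takes one 4-byte store/load, with an address increment
    // between consecutive fields, giving 2 * NF - 1 instructions in total.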
803     unsigned NF = isRVVSpillForZvlsseg(Opcode)->first;
804     return 4 * (2 * NF - 1);
805   }
806   }
807 }
808 
809 bool RISCVInstrInfo::isAsCheapAsAMove(const MachineInstr &MI) const {
810   const unsigned Opcode = MI.getOpcode();
811   switch (Opcode) {
812   default:
813     break;
814   case RISCV::FSGNJ_D:
815   case RISCV::FSGNJ_S:
816     // The canonical floating-point move is fsgnj rd, rs, rs.
817     return MI.getOperand(1).isReg() && MI.getOperand(2).isReg() &&
818            MI.getOperand(1).getReg() == MI.getOperand(2).getReg();
819   case RISCV::ADDI:
820   case RISCV::ORI:
821   case RISCV::XORI:
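    // addi/ori/xori rd, x0, imm materializes a constant and addi/ori/xori
    // rd, rs, 0 is a plain register move; both are as cheap as a move.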
822     return (MI.getOperand(1).isReg() &&
823             MI.getOperand(1).getReg() == RISCV::X0) ||
824            (MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0);
825   }
826   return MI.isAsCheapAsAMove();
827 }
828 
829 Optional<DestSourcePair>
830 RISCVInstrInfo::isCopyInstrImpl(const MachineInstr &MI) const {
831   if (MI.isMoveReg())
832     return DestSourcePair{MI.getOperand(0), MI.getOperand(1)};
833   switch (MI.getOpcode()) {
834   default:
835     break;
836   case RISCV::ADDI:
837     // Operand 1 can be a frameindex but callers expect registers
838     if (MI.getOperand(1).isReg() && MI.getOperand(2).isImm() &&
839         MI.getOperand(2).getImm() == 0)
840       return DestSourcePair{MI.getOperand(0), MI.getOperand(1)};
841     break;
842   case RISCV::FSGNJ_D:
843   case RISCV::FSGNJ_S:
844     // The canonical floating-point move is fsgnj rd, rs, rs.
845     if (MI.getOperand(1).isReg() && MI.getOperand(2).isReg() &&
846         MI.getOperand(1).getReg() == MI.getOperand(2).getReg())
847       return DestSourcePair{MI.getOperand(0), MI.getOperand(1)};
848     break;
849   }
850   return None;
851 }
852 
853 bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
854                                        StringRef &ErrInfo) const {
855   const MCInstrInfo *MCII = STI.getInstrInfo();
856   MCInstrDesc const &Desc = MCII->get(MI.getOpcode());
857 
858   for (auto &OI : enumerate(Desc.operands())) {
859     unsigned OpType = OI.value().OperandType;
860     if (OpType >= RISCVOp::OPERAND_FIRST_RISCV_IMM &&
861         OpType <= RISCVOp::OPERAND_LAST_RISCV_IMM) {
862       const MachineOperand &MO = MI.getOperand(OI.index());
863       if (MO.isImm()) {
864         int64_t Imm = MO.getImm();
865         bool Ok;
866         switch (OpType) {
867         default:
868           llvm_unreachable("Unexpected operand type");
869         case RISCVOp::OPERAND_UIMM4:
870           Ok = isUInt<4>(Imm);
871           break;
872         case RISCVOp::OPERAND_UIMM5:
873           Ok = isUInt<5>(Imm);
874           break;
875         case RISCVOp::OPERAND_UIMM12:
876           Ok = isUInt<12>(Imm);
877           break;
878         case RISCVOp::OPERAND_SIMM12:
879           Ok = isInt<12>(Imm);
880           break;
881         case RISCVOp::OPERAND_UIMM20:
882           Ok = isUInt<20>(Imm);
883           break;
884         case RISCVOp::OPERAND_UIMMLOG2XLEN:
885           if (STI.getTargetTriple().isArch64Bit())
886             Ok = isUInt<6>(Imm);
887           else
888             Ok = isUInt<5>(Imm);
889           break;
890         }
891         if (!Ok) {
892           ErrInfo = "Invalid immediate";
893           return false;
894         }
895       }
896     }
897   }
898 
899   return true;
900 }
901 
902 // Return true if we can determine the base operand and byte offset of an
903 // instruction, plus the memory width (the size of memory loaded/stored).
904 bool RISCVInstrInfo::getMemOperandWithOffsetWidth(
905     const MachineInstr &LdSt, const MachineOperand *&BaseReg, int64_t &Offset,
906     unsigned &Width, const TargetRegisterInfo *TRI) const {
907   if (!LdSt.mayLoadOrStore())
908     return false;
909 
910   // Here we assume the standard RISC-V ISA, which uses a base+offset
911   // addressing mode. You'll need to relax these conditions to support custom
912   // load/store instructions.
913   if (LdSt.getNumExplicitOperands() != 3)
914     return false;
915   if (!LdSt.getOperand(1).isReg() || !LdSt.getOperand(2).isImm())
916     return false;
917 
918   if (!LdSt.hasOneMemOperand())
919     return false;
920 
921   Width = (*LdSt.memoperands_begin())->getSize();
922   BaseReg = &LdSt.getOperand(1);
923   Offset = LdSt.getOperand(2).getImm();
924   return true;
925 }
926 
927 bool RISCVInstrInfo::areMemAccessesTriviallyDisjoint(
928     const MachineInstr &MIa, const MachineInstr &MIb) const {
929   assert(MIa.mayLoadOrStore() && "MIa must be a load or store.");
930   assert(MIb.mayLoadOrStore() && "MIb must be a load or store.");
931 
932   if (MIa.hasUnmodeledSideEffects() || MIb.hasUnmodeledSideEffects() ||
933       MIa.hasOrderedMemoryRef() || MIb.hasOrderedMemoryRef())
934     return false;
935 
936   // Retrieve the base register, offset from the base register and width. Width
937   // is the size of memory that is being loaded/stored (e.g. 1, 2, 4).  If
938   // base registers are identical, and the offset of a lower memory access +
939   // the width doesn't overlap the offset of a higher memory access,
940   // then the memory accesses are different.
941   const TargetRegisterInfo *TRI = STI.getRegisterInfo();
942   const MachineOperand *BaseOpA = nullptr, *BaseOpB = nullptr;
943   int64_t OffsetA = 0, OffsetB = 0;
944   unsigned int WidthA = 0, WidthB = 0;
945   if (getMemOperandWithOffsetWidth(MIa, BaseOpA, OffsetA, WidthA, TRI) &&
946       getMemOperandWithOffsetWidth(MIb, BaseOpB, OffsetB, WidthB, TRI)) {
947     if (BaseOpA->isIdenticalTo(*BaseOpB)) {
948       int LowOffset = std::min(OffsetA, OffsetB);
949       int HighOffset = std::max(OffsetA, OffsetB);
950       int LowWidth = (LowOffset == OffsetA) ? WidthA : WidthB;
951       if (LowOffset + LowWidth <= HighOffset)
952         return true;
953     }
954   }
955   return false;
956 }
957 
958 std::pair<unsigned, unsigned>
959 RISCVInstrInfo::decomposeMachineOperandsTargetFlags(unsigned TF) const {
960   const unsigned Mask = RISCVII::MO_DIRECT_FLAG_MASK;
961   return std::make_pair(TF & Mask, TF & ~Mask);
962 }
963 
964 ArrayRef<std::pair<unsigned, const char *>>
965 RISCVInstrInfo::getSerializableDirectMachineOperandTargetFlags() const {
966   using namespace RISCVII;
967   static const std::pair<unsigned, const char *> TargetFlags[] = {
968       {MO_CALL, "riscv-call"},
969       {MO_PLT, "riscv-plt"},
970       {MO_LO, "riscv-lo"},
971       {MO_HI, "riscv-hi"},
972       {MO_PCREL_LO, "riscv-pcrel-lo"},
973       {MO_PCREL_HI, "riscv-pcrel-hi"},
974       {MO_GOT_HI, "riscv-got-hi"},
975       {MO_TPREL_LO, "riscv-tprel-lo"},
976       {MO_TPREL_HI, "riscv-tprel-hi"},
977       {MO_TPREL_ADD, "riscv-tprel-add"},
978       {MO_TLS_GOT_HI, "riscv-tls-got-hi"},
979       {MO_TLS_GD_HI, "riscv-tls-gd-hi"}};
980   return makeArrayRef(TargetFlags);
981 }
982 bool RISCVInstrInfo::isFunctionSafeToOutlineFrom(
983     MachineFunction &MF, bool OutlineFromLinkOnceODRs) const {
984   const Function &F = MF.getFunction();
985 
986   // Can F be deduplicated by the linker? If it can, don't outline from it.
987   if (!OutlineFromLinkOnceODRs && F.hasLinkOnceODRLinkage())
988     return false;
989 
990   // Don't outline from functions with section markings; the program could
991   // expect that all the code is in the named section.
992   if (F.hasSection())
993     return false;
994 
995   // It's safe to outline from MF.
996   return true;
997 }
998 
999 bool RISCVInstrInfo::isMBBSafeToOutlineFrom(MachineBasicBlock &MBB,
1000                                             unsigned &Flags) const {
1001   // More accurate safety checking is done in getOutliningCandidateInfo.
1002   return TargetInstrInfo::isMBBSafeToOutlineFrom(MBB, Flags);
1003 }
1004 
1005 // Enum values indicating how an outlined call should be constructed.
1006 enum MachineOutlinerConstructionID {
1007   MachineOutlinerDefault
1008 };
1009 
1010 outliner::OutlinedFunction RISCVInstrInfo::getOutliningCandidateInfo(
1011     std::vector<outliner::Candidate> &RepeatedSequenceLocs) const {
1012 
1013   // First we need to filter out candidates where the X5 register (i.e. t0) can't
1014   // be used to setup the function call.
1015   auto CannotInsertCall = [](outliner::Candidate &C) {
1016     const TargetRegisterInfo *TRI = C.getMF()->getSubtarget().getRegisterInfo();
1017 
1018     C.initLRU(*TRI);
1019     LiveRegUnits LRU = C.LRU;
1020     return !LRU.available(RISCV::X5);
1021   };
1022 
1023   llvm::erase_if(RepeatedSequenceLocs, CannotInsertCall);
1024 
1025   // If the sequence doesn't have enough candidates left, then we're done.
1026   if (RepeatedSequenceLocs.size() < 2)
1027     return outliner::OutlinedFunction();
1028 
1029   unsigned SequenceSize = 0;
1030 
1031   auto I = RepeatedSequenceLocs[0].front();
1032   auto E = std::next(RepeatedSequenceLocs[0].back());
1033   for (; I != E; ++I)
1034     SequenceSize += getInstSizeInBytes(*I);
1035 
1036   // call t0, function = 8 bytes.
1037   unsigned CallOverhead = 8;
1038   for (auto &C : RepeatedSequenceLocs)
1039     C.setCallInfo(MachineOutlinerDefault, CallOverhead);
1040 
1041   // jr t0 = 4 bytes, 2 bytes if compressed instructions are enabled.
1042   unsigned FrameOverhead = 4;
1043   if (RepeatedSequenceLocs[0].getMF()->getSubtarget()
1044           .getFeatureBits()[RISCV::FeatureStdExtC])
1045     FrameOverhead = 2;
1046 
1047   return outliner::OutlinedFunction(RepeatedSequenceLocs, SequenceSize,
1048                                     FrameOverhead, MachineOutlinerDefault);
1049 }
1050 
1051 outliner::InstrType
1052 RISCVInstrInfo::getOutliningType(MachineBasicBlock::iterator &MBBI,
1053                                  unsigned Flags) const {
1054   MachineInstr &MI = *MBBI;
1055   MachineBasicBlock *MBB = MI.getParent();
1056   const TargetRegisterInfo *TRI =
1057       MBB->getParent()->getSubtarget().getRegisterInfo();
1058 
1059   // Positions generally can't safely be outlined.
1060   if (MI.isPosition()) {
1061     // We can manually strip out CFI instructions later.
1062     if (MI.isCFIInstruction())
1063       return outliner::InstrType::Invisible;
1064 
1065     return outliner::InstrType::Illegal;
1066   }
1067 
1068   // Don't trust the user to write safe inline assembly.
1069   if (MI.isInlineAsm())
1070     return outliner::InstrType::Illegal;
1071 
1072   // We can't outline branches to other basic blocks.
1073   if (MI.isTerminator() && !MBB->succ_empty())
1074     return outliner::InstrType::Illegal;
1075 
1076   // We need support for tail calls to outlined functions before return
1077   // statements can be allowed.
1078   if (MI.isReturn())
1079     return outliner::InstrType::Illegal;
1080 
1081   // Don't allow modifying the X5 register, which we use for return addresses
1082   // in these outlined functions.
1083   if (MI.modifiesRegister(RISCV::X5, TRI) ||
1084       MI.getDesc().hasImplicitDefOfPhysReg(RISCV::X5))
1085     return outliner::InstrType::Illegal;
1086 
1087   // Make sure the operands don't reference something unsafe.
1088   for (const auto &MO : MI.operands())
1089     if (MO.isMBB() || MO.isBlockAddress() || MO.isCPI())
1090       return outliner::InstrType::Illegal;
1091 
1092   // Don't allow instructions which won't be materialized to impact outlining
1093   // analysis.
1094   if (MI.isMetaInstruction())
1095     return outliner::InstrType::Invisible;
1096 
1097   return outliner::InstrType::Legal;
1098 }
1099 
1100 void RISCVInstrInfo::buildOutlinedFrame(
1101     MachineBasicBlock &MBB, MachineFunction &MF,
1102     const outliner::OutlinedFunction &OF) const {
1103 
1104   // Strip out any CFI instructions
1105   bool Changed = true;
1106   while (Changed) {
1107     Changed = false;
1108     auto I = MBB.begin();
1109     auto E = MBB.end();
1110     for (; I != E; ++I) {
1111       if (I->isCFIInstruction()) {
1112         I->removeFromParent();
1113         Changed = true;
1114         break;
1115       }
1116     }
1117   }
1118 
1119   MBB.addLiveIn(RISCV::X5);
1120 
1121   // Add in a return instruction to the end of the outlined frame.
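  // "jalr x0, 0(x5)" is the "jr t0" that returns to the call site.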
1122   MBB.insert(MBB.end(), BuildMI(MF, DebugLoc(), get(RISCV::JALR))
1123       .addReg(RISCV::X0, RegState::Define)
1124       .addReg(RISCV::X5)
1125       .addImm(0));
1126 }
1127 
1128 MachineBasicBlock::iterator RISCVInstrInfo::insertOutlinedCall(
1129     Module &M, MachineBasicBlock &MBB, MachineBasicBlock::iterator &It,
1130     MachineFunction &MF, const outliner::Candidate &C) const {
1131 
1132   // Add in a call instruction to the outlined function at the given location.
1133   It = MBB.insert(It,
1134                   BuildMI(MF, DebugLoc(), get(RISCV::PseudoCALLReg), RISCV::X5)
1135                       .addGlobalAddress(M.getNamedValue(MF.getName()), 0,
1136                                         RISCVII::MO_CALL));
1137   return It;
1138 }
1139 
1140 // clang-format off
1141 #define CASE_VFMA_OPCODE_COMMON(OP, TYPE, LMUL)                                \
1142   RISCV::PseudoV##OP##_##TYPE##_##LMUL##_COMMUTABLE
1143 
1144 #define CASE_VFMA_OPCODE_LMULS(OP, TYPE)                                       \
1145   CASE_VFMA_OPCODE_COMMON(OP, TYPE, MF8):                                      \
1146   case CASE_VFMA_OPCODE_COMMON(OP, TYPE, MF4):                                 \
1147   case CASE_VFMA_OPCODE_COMMON(OP, TYPE, MF2):                                 \
1148   case CASE_VFMA_OPCODE_COMMON(OP, TYPE, M1):                                  \
1149   case CASE_VFMA_OPCODE_COMMON(OP, TYPE, M2):                                  \
1150   case CASE_VFMA_OPCODE_COMMON(OP, TYPE, M4):                                  \
1151   case CASE_VFMA_OPCODE_COMMON(OP, TYPE, M8)
1152 
1153 #define CASE_VFMA_SPLATS(OP)                                                   \
1154   CASE_VFMA_OPCODE_LMULS(OP, VF16):                                            \
1155   case CASE_VFMA_OPCODE_LMULS(OP, VF32):                                       \
1156   case CASE_VFMA_OPCODE_LMULS(OP, VF64)
1157 // clang-format on
1158 
1159 bool RISCVInstrInfo::findCommutedOpIndices(const MachineInstr &MI,
1160                                            unsigned &SrcOpIdx1,
1161                                            unsigned &SrcOpIdx2) const {
1162   const MCInstrDesc &Desc = MI.getDesc();
1163   if (!Desc.isCommutable())
1164     return false;
1165 
1166   switch (MI.getOpcode()) {
1167   case CASE_VFMA_SPLATS(FMADD):
1168   case CASE_VFMA_SPLATS(FMSUB):
1169   case CASE_VFMA_SPLATS(FMACC):
1170   case CASE_VFMA_SPLATS(FMSAC):
1171   case CASE_VFMA_SPLATS(FNMADD):
1172   case CASE_VFMA_SPLATS(FNMSUB):
1173   case CASE_VFMA_SPLATS(FNMACC):
1174   case CASE_VFMA_SPLATS(FNMSAC):
1175   case CASE_VFMA_OPCODE_LMULS(FMACC, VV):
1176   case CASE_VFMA_OPCODE_LMULS(FMSAC, VV):
1177   case CASE_VFMA_OPCODE_LMULS(FNMACC, VV):
1178   case CASE_VFMA_OPCODE_LMULS(FNMSAC, VV):
1179   case CASE_VFMA_OPCODE_LMULS(MADD, VX):
1180   case CASE_VFMA_OPCODE_LMULS(NMSUB, VX):
1181   case CASE_VFMA_OPCODE_LMULS(MACC, VX):
1182   case CASE_VFMA_OPCODE_LMULS(NMSAC, VX):
1183   case CASE_VFMA_OPCODE_LMULS(MACC, VV):
1184   case CASE_VFMA_OPCODE_LMULS(NMSAC, VV): {
1185     // For these instructions we can only swap operand 1 and operand 3 by
1186     // changing the opcode.
1187     unsigned CommutableOpIdx1 = 1;
1188     unsigned CommutableOpIdx2 = 3;
1189     if (!fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, CommutableOpIdx1,
1190                               CommutableOpIdx2))
1191       return false;
1192     return true;
1193   }
1194   case CASE_VFMA_OPCODE_LMULS(FMADD, VV):
1195   case CASE_VFMA_OPCODE_LMULS(FMSUB, VV):
1196   case CASE_VFMA_OPCODE_LMULS(FNMADD, VV):
1197   case CASE_VFMA_OPCODE_LMULS(FNMSUB, VV):
1198   case CASE_VFMA_OPCODE_LMULS(MADD, VV):
1199   case CASE_VFMA_OPCODE_LMULS(NMSUB, VV): {
1200     // For these instructions we have more freedom. We can commute with the
1201     // other multiplicand or with the addend/subtrahend/minuend.
1202 
1203     // Any fixed operand must be from source 1, 2 or 3.
1204     if (SrcOpIdx1 != CommuteAnyOperandIndex && SrcOpIdx1 > 3)
1205       return false;
1206     if (SrcOpIdx2 != CommuteAnyOperandIndex && SrcOpIdx2 > 3)
1207       return false;
1208 
1209     // If both ops are fixed, one must be the tied source.
1210     if (SrcOpIdx1 != CommuteAnyOperandIndex &&
1211         SrcOpIdx2 != CommuteAnyOperandIndex && SrcOpIdx1 != 1 && SrcOpIdx2 != 1)
1212       return false;
1213 
1214     // Look for two different register operands assumed to be commutable
1215     // regardless of the FMA opcode. The FMA opcode is adjusted later if
1216     // needed.
1217     if (SrcOpIdx1 == CommuteAnyOperandIndex ||
1218         SrcOpIdx2 == CommuteAnyOperandIndex) {
1219       // At least one of operands to be commuted is not specified and
1220       // this method is free to choose appropriate commutable operands.
1221       unsigned CommutableOpIdx1 = SrcOpIdx1;
1222       if (SrcOpIdx1 == SrcOpIdx2) {
1223         // Both of operands are not fixed. Set one of commutable
1224         // operands to the tied source.
1225         CommutableOpIdx1 = 1;
1226       } else if (SrcOpIdx1 == CommuteAnyOperandIndex) {
1227         // Only one of the operands is not fixed.
1228         CommutableOpIdx1 = SrcOpIdx2;
1229       }
1230 
1231       // CommutableOpIdx1 is well defined now. Let's choose another commutable
1232       // operand and assign its index to CommutableOpIdx2.
1233       unsigned CommutableOpIdx2;
1234       if (CommutableOpIdx1 != 1) {
1235         // If we haven't already used the tied source, we must use it now.
1236         CommutableOpIdx2 = 1;
1237       } else {
1238         Register Op1Reg = MI.getOperand(CommutableOpIdx1).getReg();
1239 
1240         // The commuted operands should have different registers.
1241         // Otherwise, the commute transformation does not change anything and
1242         // is useless. We use this as a hint to make our decision.
1243         if (Op1Reg != MI.getOperand(2).getReg())
1244           CommutableOpIdx2 = 2;
1245         else
1246           CommutableOpIdx2 = 3;
1247       }
1248 
1249       // Assign the found pair of commutable indices to SrcOpIdx1 and
1250       // SrcOpIdx2 to return those values.
1251       if (!fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, CommutableOpIdx1,
1252                                 CommutableOpIdx2))
1253         return false;
1254     }
1255 
1256     return true;
1257   }
1258   }
1259 
1260   return TargetInstrInfo::findCommutedOpIndices(MI, SrcOpIdx1, SrcOpIdx2);
1261 }
1262 
1263 #define CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, LMUL)               \
1264   case RISCV::PseudoV##OLDOP##_##TYPE##_##LMUL##_COMMUTABLE:                   \
1265     Opc = RISCV::PseudoV##NEWOP##_##TYPE##_##LMUL##_COMMUTABLE;                \
1266     break;
1267 
1268 #define CASE_VFMA_CHANGE_OPCODE_LMULS(OLDOP, NEWOP, TYPE)                      \
1269   CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, MF8)                      \
1270   CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, MF4)                      \
1271   CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, MF2)                      \
1272   CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, M1)                       \
1273   CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, M2)                       \
1274   CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, M4)                       \
1275   CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, M8)
1276 
1277 #define CASE_VFMA_CHANGE_OPCODE_SPLATS(OLDOP, NEWOP)                           \
1278   CASE_VFMA_CHANGE_OPCODE_LMULS(OLDOP, NEWOP, VF16)                            \
1279   CASE_VFMA_CHANGE_OPCODE_LMULS(OLDOP, NEWOP, VF32)                            \
1280   CASE_VFMA_CHANGE_OPCODE_LMULS(OLDOP, NEWOP, VF64)
1281 
1282 MachineInstr *RISCVInstrInfo::commuteInstructionImpl(MachineInstr &MI,
1283                                                      bool NewMI,
1284                                                      unsigned OpIdx1,
1285                                                      unsigned OpIdx2) const {
1286   auto cloneIfNew = [NewMI](MachineInstr &MI) -> MachineInstr & {
1287     if (NewMI)
1288       return *MI.getParent()->getParent()->CloneMachineInstr(&MI);
1289     return MI;
1290   };
1291 
1292   switch (MI.getOpcode()) {
1293   case CASE_VFMA_SPLATS(FMACC):
1294   case CASE_VFMA_SPLATS(FMADD):
1295   case CASE_VFMA_SPLATS(FMSAC):
1296   case CASE_VFMA_SPLATS(FMSUB):
1297   case CASE_VFMA_SPLATS(FNMACC):
1298   case CASE_VFMA_SPLATS(FNMADD):
1299   case CASE_VFMA_SPLATS(FNMSAC):
1300   case CASE_VFMA_SPLATS(FNMSUB):
1301   case CASE_VFMA_OPCODE_LMULS(FMACC, VV):
1302   case CASE_VFMA_OPCODE_LMULS(FMSAC, VV):
1303   case CASE_VFMA_OPCODE_LMULS(FNMACC, VV):
1304   case CASE_VFMA_OPCODE_LMULS(FNMSAC, VV):
1305   case CASE_VFMA_OPCODE_LMULS(MADD, VX):
1306   case CASE_VFMA_OPCODE_LMULS(NMSUB, VX):
1307   case CASE_VFMA_OPCODE_LMULS(MACC, VX):
1308   case CASE_VFMA_OPCODE_LMULS(NMSAC, VX):
1309   case CASE_VFMA_OPCODE_LMULS(MACC, VV):
1310   case CASE_VFMA_OPCODE_LMULS(NMSAC, VV): {
1311     // It only makes sense to toggle these between clobbering the
1312     // addend/subtrahend/minuend and one of the multiplicands.
1313     assert((OpIdx1 == 1 || OpIdx2 == 1) && "Unexpected opcode index");
1314     assert((OpIdx1 == 3 || OpIdx2 == 3) && "Unexpected opcode index");
1315     unsigned Opc;
1316     switch (MI.getOpcode()) {
1317       default:
1318         llvm_unreachable("Unexpected opcode");
1319       CASE_VFMA_CHANGE_OPCODE_SPLATS(FMACC, FMADD)
1320       CASE_VFMA_CHANGE_OPCODE_SPLATS(FMADD, FMACC)
1321       CASE_VFMA_CHANGE_OPCODE_SPLATS(FMSAC, FMSUB)
1322       CASE_VFMA_CHANGE_OPCODE_SPLATS(FMSUB, FMSAC)
1323       CASE_VFMA_CHANGE_OPCODE_SPLATS(FNMACC, FNMADD)
1324       CASE_VFMA_CHANGE_OPCODE_SPLATS(FNMADD, FNMACC)
1325       CASE_VFMA_CHANGE_OPCODE_SPLATS(FNMSAC, FNMSUB)
1326       CASE_VFMA_CHANGE_OPCODE_SPLATS(FNMSUB, FNMSAC)
1327       CASE_VFMA_CHANGE_OPCODE_LMULS(FMACC, FMADD, VV)
1328       CASE_VFMA_CHANGE_OPCODE_LMULS(FMSAC, FMSUB, VV)
1329       CASE_VFMA_CHANGE_OPCODE_LMULS(FNMACC, FNMADD, VV)
1330       CASE_VFMA_CHANGE_OPCODE_LMULS(FNMSAC, FNMSUB, VV)
1331       CASE_VFMA_CHANGE_OPCODE_LMULS(MACC, MADD, VX)
1332       CASE_VFMA_CHANGE_OPCODE_LMULS(MADD, MACC, VX)
1333       CASE_VFMA_CHANGE_OPCODE_LMULS(NMSAC, NMSUB, VX)
1334       CASE_VFMA_CHANGE_OPCODE_LMULS(NMSUB, NMSAC, VX)
1335       CASE_VFMA_CHANGE_OPCODE_LMULS(MACC, MADD, VV)
1336       CASE_VFMA_CHANGE_OPCODE_LMULS(NMSAC, NMSUB, VV)
1337     }
1338 
1339     auto &WorkingMI = cloneIfNew(MI);
1340     WorkingMI.setDesc(get(Opc));
1341     return TargetInstrInfo::commuteInstructionImpl(WorkingMI, /*NewMI=*/false,
1342                                                    OpIdx1, OpIdx2);
1343   }
1344   case CASE_VFMA_OPCODE_LMULS(FMADD, VV):
1345   case CASE_VFMA_OPCODE_LMULS(FMSUB, VV):
1346   case CASE_VFMA_OPCODE_LMULS(FNMADD, VV):
1347   case CASE_VFMA_OPCODE_LMULS(FNMSUB, VV):
1348   case CASE_VFMA_OPCODE_LMULS(MADD, VV):
1349   case CASE_VFMA_OPCODE_LMULS(NMSUB, VV): {
1350     assert((OpIdx1 == 1 || OpIdx2 == 1) && "Unexpected opcode index");
1351     // If one of the operands is the addend, we need to change the opcode.
1352     // Otherwise we're just swapping two of the multiplicands.
    if (OpIdx1 == 3 || OpIdx2 == 3) {
      unsigned Opc;
      switch (MI.getOpcode()) {
        default:
          llvm_unreachable("Unexpected opcode");
        CASE_VFMA_CHANGE_OPCODE_LMULS(FMADD, FMACC, VV)
        CASE_VFMA_CHANGE_OPCODE_LMULS(FMSUB, FMSAC, VV)
        CASE_VFMA_CHANGE_OPCODE_LMULS(FNMADD, FNMACC, VV)
        CASE_VFMA_CHANGE_OPCODE_LMULS(FNMSUB, FNMSAC, VV)
        CASE_VFMA_CHANGE_OPCODE_LMULS(MADD, MACC, VV)
        CASE_VFMA_CHANGE_OPCODE_LMULS(NMSUB, NMSAC, VV)
      }

      auto &WorkingMI = cloneIfNew(MI);
      WorkingMI.setDesc(get(Opc));
      return TargetInstrInfo::commuteInstructionImpl(WorkingMI, /*NewMI=*/false,
                                                     OpIdx1, OpIdx2);
    }
    // Let the default code handle it.
    break;
  }
  }

  return TargetInstrInfo::commuteInstructionImpl(MI, NewMI, OpIdx1, OpIdx2);
}

#undef CASE_VFMA_CHANGE_OPCODE_SPLATS
#undef CASE_VFMA_CHANGE_OPCODE_LMULS
#undef CASE_VFMA_CHANGE_OPCODE_COMMON
#undef CASE_VFMA_SPLATS
#undef CASE_VFMA_OPCODE_LMULS
#undef CASE_VFMA_OPCODE_COMMON

// clang-format off
#define CASE_WIDEOP_OPCODE_COMMON(OP, LMUL)                                    \
  RISCV::PseudoV##OP##_##LMUL##_TIED

#define CASE_WIDEOP_OPCODE_LMULS(OP)                                           \
  CASE_WIDEOP_OPCODE_COMMON(OP, MF8):                                          \
  case CASE_WIDEOP_OPCODE_COMMON(OP, MF4):                                     \
  case CASE_WIDEOP_OPCODE_COMMON(OP, MF2):                                     \
  case CASE_WIDEOP_OPCODE_COMMON(OP, M1):                                      \
  case CASE_WIDEOP_OPCODE_COMMON(OP, M2):                                      \
  case CASE_WIDEOP_OPCODE_COMMON(OP, M4)
// clang-format on

#define CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, LMUL)                             \
  case RISCV::PseudoV##OP##_##LMUL##_TIED:                                     \
    NewOpc = RISCV::PseudoV##OP##_##LMUL;                                      \
    break;

#define CASE_WIDEOP_CHANGE_OPCODE_LMULS(OP)                                    \
  CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF8)                                    \
  CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF4)                                    \
  CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF2)                                    \
  CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, M1)                                     \
  CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, M2)                                     \
  CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, M4)
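
// For example, 'case CASE_WIDEOP_OPCODE_LMULS(WADD_WV):' expands to case
// labels for RISCV::PseudoVWADD_WV_MF8_TIED through
// RISCV::PseudoVWADD_WV_M4_TIED, and CASE_WIDEOP_CHANGE_OPCODE_LMULS(WADD_WV)
// rewrites each of those opcodes to its untied RISCV::PseudoVWADD_WV_<LMUL>
// counterpart.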

MachineInstr *RISCVInstrInfo::convertToThreeAddress(
    MachineFunction::iterator &MBB, MachineInstr &MI, LiveVariables *LV) const {
  switch (MI.getOpcode()) {
  default:
    break;
  case CASE_WIDEOP_OPCODE_LMULS(FWADD_WV):
  case CASE_WIDEOP_OPCODE_LMULS(FWSUB_WV):
  case CASE_WIDEOP_OPCODE_LMULS(WADD_WV):
  case CASE_WIDEOP_OPCODE_LMULS(WADDU_WV):
  case CASE_WIDEOP_OPCODE_LMULS(WSUB_WV):
  case CASE_WIDEOP_OPCODE_LMULS(WSUBU_WV): {
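    // The _TIED pseudos tie the wide destination to the wide source operand
    // (a two-address constraint). Rewriting to the untied pseudo of the same
    // name drops that constraint so the destination can be allocated
    // separately from the sources.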
    // clang-format off
    unsigned NewOpc;
    switch (MI.getOpcode()) {
    default:
      llvm_unreachable("Unexpected opcode");
    CASE_WIDEOP_CHANGE_OPCODE_LMULS(FWADD_WV)
    CASE_WIDEOP_CHANGE_OPCODE_LMULS(FWSUB_WV)
    CASE_WIDEOP_CHANGE_OPCODE_LMULS(WADD_WV)
    CASE_WIDEOP_CHANGE_OPCODE_LMULS(WADDU_WV)
    CASE_WIDEOP_CHANGE_OPCODE_LMULS(WSUB_WV)
    CASE_WIDEOP_CHANGE_OPCODE_LMULS(WSUBU_WV)
    }
    // clang-format on

    MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI.getDebugLoc(), get(NewOpc))
                                  .add(MI.getOperand(0))
                                  .add(MI.getOperand(1))
                                  .add(MI.getOperand(2))
                                  .add(MI.getOperand(3))
                                  .add(MI.getOperand(4));
    MIB.copyImplicitOps(MI);

    if (LV) {
      unsigned NumOps = MI.getNumOperands();
      for (unsigned I = 1; I < NumOps; ++I) {
        MachineOperand &Op = MI.getOperand(I);
        if (Op.isReg() && Op.isKill())
          LV->replaceKillInstruction(Op.getReg(), MI, *MIB);
      }
    }

    return MIB;
  }
  }

  return nullptr;
}

#undef CASE_WIDEOP_CHANGE_OPCODE_LMULS
#undef CASE_WIDEOP_CHANGE_OPCODE_COMMON
#undef CASE_WIDEOP_OPCODE_LMULS
#undef CASE_WIDEOP_OPCODE_COMMON

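// Return a virtual register containing VLENB * (Amount / 8), i.e. the number
// of bytes occupied by Amount / 8 vector registers. PseudoReadVLENB provides
// the per-register byte count; the multiplication by NumOfVReg is then
// strength-reduced to a shift when NumOfVReg is a power of two, to a shift
// plus an ADD/SUB when it is one away from a power of two (e.g. NumOfVReg == 3
// becomes an SLLI by 1 followed by an ADD), and otherwise falls back to a MUL,
// which requires the M extension.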
Register RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
                                               MachineBasicBlock &MBB,
                                               MachineBasicBlock::iterator II,
                                               const DebugLoc &DL,
                                               int64_t Amount,
                                               MachineInstr::MIFlag Flag) const {
  assert(Amount > 0 && "There is no need to get VLEN scaled value.");
  assert(Amount % 8 == 0 &&
         "Reserve the stack by the multiple of one vector size.");

  MachineRegisterInfo &MRI = MF.getRegInfo();
  const RISCVInstrInfo *TII = MF.getSubtarget<RISCVSubtarget>().getInstrInfo();
  int64_t NumOfVReg = Amount / 8;

  Register VL = MRI.createVirtualRegister(&RISCV::GPRRegClass);
  BuildMI(MBB, II, DL, TII->get(RISCV::PseudoReadVLENB), VL)
    .setMIFlag(Flag);
  assert(isInt<32>(NumOfVReg) &&
         "Expect the number of vector registers within 32-bits.");
  if (isPowerOf2_32(NumOfVReg)) {
    uint32_t ShiftAmount = Log2_32(NumOfVReg);
    if (ShiftAmount == 0)
      return VL;
    BuildMI(MBB, II, DL, TII->get(RISCV::SLLI), VL)
        .addReg(VL, RegState::Kill)
        .addImm(ShiftAmount)
        .setMIFlag(Flag);
  } else if (isPowerOf2_32(NumOfVReg - 1)) {
    Register ScaledRegister = MRI.createVirtualRegister(&RISCV::GPRRegClass);
    uint32_t ShiftAmount = Log2_32(NumOfVReg - 1);
    BuildMI(MBB, II, DL, TII->get(RISCV::SLLI), ScaledRegister)
        .addReg(VL)
        .addImm(ShiftAmount)
        .setMIFlag(Flag);
    BuildMI(MBB, II, DL, TII->get(RISCV::ADD), VL)
        .addReg(ScaledRegister, RegState::Kill)
        .addReg(VL, RegState::Kill)
        .setMIFlag(Flag);
  } else if (isPowerOf2_32(NumOfVReg + 1)) {
    Register ScaledRegister = MRI.createVirtualRegister(&RISCV::GPRRegClass);
    uint32_t ShiftAmount = Log2_32(NumOfVReg + 1);
    BuildMI(MBB, II, DL, TII->get(RISCV::SLLI), ScaledRegister)
        .addReg(VL)
        .addImm(ShiftAmount)
        .setMIFlag(Flag);
    BuildMI(MBB, II, DL, TII->get(RISCV::SUB), VL)
        .addReg(ScaledRegister, RegState::Kill)
        .addReg(VL, RegState::Kill)
        .setMIFlag(Flag);
  } else {
    Register N = MRI.createVirtualRegister(&RISCV::GPRRegClass);
    if (!isInt<12>(NumOfVReg))
      movImm(MBB, II, DL, N, NumOfVReg);
    else {
      BuildMI(MBB, II, DL, TII->get(RISCV::ADDI), N)
          .addReg(RISCV::X0)
          .addImm(NumOfVReg)
          .setMIFlag(Flag);
    }
    if (!MF.getSubtarget<RISCVSubtarget>().hasStdExtM())
      MF.getFunction().getContext().diagnose(DiagnosticInfoUnsupported{
          MF.getFunction(),
          "M-extension must be enabled to calculate the vscaled size/offset."});
    BuildMI(MBB, II, DL, TII->get(RISCV::MUL), VL)
        .addReg(VL, RegState::Kill)
        .addReg(N, RegState::Kill)
        .setMIFlag(Flag);
  }

  return VL;
}

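// Whole vector register moves to/from memory (vl<N>re<EEW>.v loads and
// vs<N>r.v stores) are real instructions rather than VL-based pseudos, so
// isRVVSpill below has to check for them explicitly.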
static bool isRVVWholeLoadStore(unsigned Opcode) {
  switch (Opcode) {
  default:
    return false;
  case RISCV::VS1R_V:
  case RISCV::VS2R_V:
  case RISCV::VS4R_V:
  case RISCV::VS8R_V:
  case RISCV::VL1RE8_V:
  case RISCV::VL2RE8_V:
  case RISCV::VL4RE8_V:
  case RISCV::VL8RE8_V:
  case RISCV::VL1RE16_V:
  case RISCV::VL2RE16_V:
  case RISCV::VL4RE16_V:
  case RISCV::VL8RE16_V:
  case RISCV::VL1RE32_V:
  case RISCV::VL2RE32_V:
  case RISCV::VL4RE32_V:
  case RISCV::VL8RE32_V:
  case RISCV::VL1RE64_V:
  case RISCV::VL2RE64_V:
  case RISCV::VL4RE64_V:
  case RISCV::VL8RE64_V:
    return true;
  }
}

bool RISCVInstrInfo::isRVVSpill(const MachineInstr &MI, bool CheckFIs) const {
  // RVV lacks any support for immediate addressing for stack addresses, so be
  // conservative.
  unsigned Opcode = MI.getOpcode();
  if (!RISCVVPseudosTable::getPseudoInfo(Opcode) &&
      !isRVVWholeLoadStore(Opcode) && !isRVVSpillForZvlsseg(Opcode))
    return false;
  return !CheckFIs || any_of(MI.operands(), [](const MachineOperand &MO) {
    return MO.isFI();
  });
}

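// For the Zvlsseg segment spill/reload pseudos, return the number of register
// fields (NF) and the LMUL of each field; e.g. PseudoVSPILL3_M2 spills three
// LMUL=2 register groups.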
Optional<std::pair<unsigned, unsigned>>
RISCVInstrInfo::isRVVSpillForZvlsseg(unsigned Opcode) const {
  switch (Opcode) {
  default:
    return None;
  case RISCV::PseudoVSPILL2_M1:
  case RISCV::PseudoVRELOAD2_M1:
    return std::make_pair(2u, 1u);
  case RISCV::PseudoVSPILL2_M2:
  case RISCV::PseudoVRELOAD2_M2:
    return std::make_pair(2u, 2u);
  case RISCV::PseudoVSPILL2_M4:
  case RISCV::PseudoVRELOAD2_M4:
    return std::make_pair(2u, 4u);
  case RISCV::PseudoVSPILL3_M1:
  case RISCV::PseudoVRELOAD3_M1:
    return std::make_pair(3u, 1u);
  case RISCV::PseudoVSPILL3_M2:
  case RISCV::PseudoVRELOAD3_M2:
    return std::make_pair(3u, 2u);
  case RISCV::PseudoVSPILL4_M1:
  case RISCV::PseudoVRELOAD4_M1:
    return std::make_pair(4u, 1u);
  case RISCV::PseudoVSPILL4_M2:
  case RISCV::PseudoVRELOAD4_M2:
    return std::make_pair(4u, 2u);
  case RISCV::PseudoVSPILL5_M1:
  case RISCV::PseudoVRELOAD5_M1:
    return std::make_pair(5u, 1u);
  case RISCV::PseudoVSPILL6_M1:
  case RISCV::PseudoVRELOAD6_M1:
    return std::make_pair(6u, 1u);
  case RISCV::PseudoVSPILL7_M1:
  case RISCV::PseudoVRELOAD7_M1:
    return std::make_pair(7u, 1u);
  case RISCV::PseudoVSPILL8_M1:
  case RISCV::PseudoVRELOAD8_M1:
    return std::make_pair(8u, 1u);
  }
}