1 //===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 //
9 // This pass does some optimizations for *W instructions at the MI level.
10 //
// First, it removes unneeded sext.w instructions, either because the
// sign-extended bits aren't consumed or because the input was already sign
// extended by an earlier instruction.
14 //
15 // Then it removes the -w suffix from opw instructions whenever all users are
16 // dependent only on the lower word of the result of the instruction.
17 // The cases handled are:
18 // * addw because c.add has a larger register encoding than c.addw.
19 // * addiw because it helps reduce test differences between RV32 and RV64
20 //   w/o being a pessimization.
21 // * mulw because c.mulw doesn't exist but c.mul does (w/ zcb)
22 // * slliw because c.slliw doesn't exist and c.slli does
23 //
24 //===---------------------------------------------------------------------===//
25 
26 #include "RISCV.h"
27 #include "RISCVMachineFunctionInfo.h"
28 #include "RISCVSubtarget.h"
29 #include "llvm/ADT/SmallSet.h"
30 #include "llvm/ADT/Statistic.h"
31 #include "llvm/CodeGen/MachineFunctionPass.h"
32 #include "llvm/CodeGen/TargetInstrInfo.h"
33 
34 using namespace llvm;
35 
// Debug type for LLVM_DEBUG output; also passed as the pass argument name to
// INITIALIZE_PASS below.
#define DEBUG_TYPE "riscv-opt-w-instrs"
// Human-readable pass name, shared by INITIALIZE_PASS and getPassName().
#define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"

// Pass statistics: sext.w instructions erased, and defs rewritten into their
// W form so that a sext.w could be erased.
STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
STATISTIC(NumTransformedToWInstrs,
          "Number of instructions transformed to W-ops");

// Hidden options that individually disable the two transformations performed
// by this pass (checked at the top of removeSExtWInstrs / stripWSuffixes).
static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
                                         cl::desc("Disable removal of sext.w"),
                                         cl::init(false), cl::Hidden);
static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
                                         cl::desc("Disable strip W suffix"),
                                         cl::init(false), cl::Hidden);
49 
50 namespace {
51 
52 class RISCVOptWInstrs : public MachineFunctionPass {
53 public:
54   static char ID;
55 
RISCVOptWInstrs()56   RISCVOptWInstrs() : MachineFunctionPass(ID) {}
57 
58   bool runOnMachineFunction(MachineFunction &MF) override;
59   bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
60                          const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
61   bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
62                       const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
63 
getAnalysisUsage(AnalysisUsage & AU) const64   void getAnalysisUsage(AnalysisUsage &AU) const override {
65     AU.setPreservesCFG();
66     MachineFunctionPass::getAnalysisUsage(AU);
67   }
68 
getPassName() const69   StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; }
70 };
71 
72 } // end anonymous namespace
73 
// Storage for the pass identifier declared in the class.
char RISCVOptWInstrs::ID = 0;
// Register the pass under DEBUG_TYPE with RISCV_OPT_W_INSTRS_NAME as its
// description; the trailing flags are INITIALIZE_PASS's "CFG only" and
// "is analysis" arguments, both false here.
INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false,
                false)
77 
78 FunctionPass *llvm::createRISCVOptWInstrsPass() {
79   return new RISCVOptWInstrs();
80 }
81 
vectorPseudoHasAllNBitUsers(const MachineOperand & UserOp,unsigned Bits)82 static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
83                                         unsigned Bits) {
84   const MachineInstr &MI = *UserOp.getParent();
85   unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode());
86 
87   if (!MCOpcode)
88     return false;
89 
90   const MCInstrDesc &MCID = MI.getDesc();
91   const uint64_t TSFlags = MCID.TSFlags;
92   if (!RISCVII::hasSEWOp(TSFlags))
93     return false;
94   assert(RISCVII::hasVLOp(TSFlags));
95   const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
96 
97   if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
98     return false;
99 
100   auto NumDemandedBits =
101       RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW);
102   return NumDemandedBits && Bits >= *NumDemandedBits;
103 }
104 
105 // Checks if all users only demand the lower \p OrigBits of the original
106 // instruction's result.
107 // TODO: handle multiple interdependent transformations
hasAllNBitUsers(const MachineInstr & OrigMI,const RISCVSubtarget & ST,const MachineRegisterInfo & MRI,unsigned OrigBits)108 static bool hasAllNBitUsers(const MachineInstr &OrigMI,
109                             const RISCVSubtarget &ST,
110                             const MachineRegisterInfo &MRI, unsigned OrigBits) {
111 
112   SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
113   SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;
114 
115   Worklist.push_back(std::make_pair(&OrigMI, OrigBits));
116 
117   while (!Worklist.empty()) {
118     auto P = Worklist.pop_back_val();
119     const MachineInstr *MI = P.first;
120     unsigned Bits = P.second;
121 
122     if (!Visited.insert(P).second)
123       continue;
124 
125     // Only handle instructions with one def.
126     if (MI->getNumExplicitDefs() != 1)
127       return false;
128 
129     Register DestReg = MI->getOperand(0).getReg();
130     if (!DestReg.isVirtual())
131       return false;
132 
133     for (auto &UserOp : MRI.use_nodbg_operands(DestReg)) {
134       const MachineInstr *UserMI = UserOp.getParent();
135       unsigned OpIdx = UserOp.getOperandNo();
136 
137       switch (UserMI->getOpcode()) {
138       default:
139         if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
140           break;
141         return false;
142 
143       case RISCV::ADDIW:
144       case RISCV::ADDW:
145       case RISCV::DIVUW:
146       case RISCV::DIVW:
147       case RISCV::MULW:
148       case RISCV::REMUW:
149       case RISCV::REMW:
150       case RISCV::SLLIW:
151       case RISCV::SLLW:
152       case RISCV::SRAIW:
153       case RISCV::SRAW:
154       case RISCV::SRLIW:
155       case RISCV::SRLW:
156       case RISCV::SUBW:
157       case RISCV::ROLW:
158       case RISCV::RORW:
159       case RISCV::RORIW:
160       case RISCV::CLZW:
161       case RISCV::CTZW:
162       case RISCV::CPOPW:
163       case RISCV::SLLI_UW:
164       case RISCV::FMV_W_X:
165       case RISCV::FCVT_H_W:
166       case RISCV::FCVT_H_WU:
167       case RISCV::FCVT_S_W:
168       case RISCV::FCVT_S_WU:
169       case RISCV::FCVT_D_W:
170       case RISCV::FCVT_D_WU:
171         if (Bits >= 32)
172           break;
173         return false;
174       case RISCV::SEXT_B:
175       case RISCV::PACKH:
176         if (Bits >= 8)
177           break;
178         return false;
179       case RISCV::SEXT_H:
180       case RISCV::FMV_H_X:
181       case RISCV::ZEXT_H_RV32:
182       case RISCV::ZEXT_H_RV64:
183       case RISCV::PACKW:
184         if (Bits >= 16)
185           break;
186         return false;
187 
188       case RISCV::PACK:
189         if (Bits >= (ST.getXLen() / 2))
190           break;
191         return false;
192 
193       case RISCV::SRLI: {
194         // If we are shifting right by less than Bits, and users don't demand
195         // any bits that were shifted into [Bits-1:0], then we can consider this
196         // as an N-Bit user.
197         unsigned ShAmt = UserMI->getOperand(2).getImm();
198         if (Bits > ShAmt) {
199           Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
200           break;
201         }
202         return false;
203       }
204 
205       // these overwrite higher input bits, otherwise the lower word of output
206       // depends only on the lower word of input. So check their uses read W.
207       case RISCV::SLLI:
208         if (Bits >= (ST.getXLen() - UserMI->getOperand(2).getImm()))
209           break;
210         Worklist.push_back(std::make_pair(UserMI, Bits));
211         break;
212       case RISCV::ANDI: {
213         uint64_t Imm = UserMI->getOperand(2).getImm();
214         if (Bits >= (unsigned)llvm::bit_width(Imm))
215           break;
216         Worklist.push_back(std::make_pair(UserMI, Bits));
217         break;
218       }
219       case RISCV::ORI: {
220         uint64_t Imm = UserMI->getOperand(2).getImm();
221         if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
222           break;
223         Worklist.push_back(std::make_pair(UserMI, Bits));
224         break;
225       }
226 
227       case RISCV::SLL:
228       case RISCV::BSET:
229       case RISCV::BCLR:
230       case RISCV::BINV:
231         // Operand 2 is the shift amount which uses log2(xlen) bits.
232         if (OpIdx == 2) {
233           if (Bits >= Log2_32(ST.getXLen()))
234             break;
235           return false;
236         }
237         Worklist.push_back(std::make_pair(UserMI, Bits));
238         break;
239 
240       case RISCV::SRA:
241       case RISCV::SRL:
242       case RISCV::ROL:
243       case RISCV::ROR:
244         // Operand 2 is the shift amount which uses 6 bits.
245         if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
246           break;
247         return false;
248 
249       case RISCV::ADD_UW:
250       case RISCV::SH1ADD_UW:
251       case RISCV::SH2ADD_UW:
252       case RISCV::SH3ADD_UW:
253         // Operand 1 is implicitly zero extended.
254         if (OpIdx == 1 && Bits >= 32)
255           break;
256         Worklist.push_back(std::make_pair(UserMI, Bits));
257         break;
258 
259       case RISCV::BEXTI:
260         if (UserMI->getOperand(2).getImm() >= Bits)
261           return false;
262         break;
263 
264       case RISCV::SB:
265         // The first argument is the value to store.
266         if (OpIdx == 0 && Bits >= 8)
267           break;
268         return false;
269       case RISCV::SH:
270         // The first argument is the value to store.
271         if (OpIdx == 0 && Bits >= 16)
272           break;
273         return false;
274       case RISCV::SW:
275         // The first argument is the value to store.
276         if (OpIdx == 0 && Bits >= 32)
277           break;
278         return false;
279 
280       // For these, lower word of output in these operations, depends only on
281       // the lower word of input. So, we check all uses only read lower word.
282       case RISCV::COPY:
283       case RISCV::PHI:
284 
285       case RISCV::ADD:
286       case RISCV::ADDI:
287       case RISCV::AND:
288       case RISCV::MUL:
289       case RISCV::OR:
290       case RISCV::SUB:
291       case RISCV::XOR:
292       case RISCV::XORI:
293 
294       case RISCV::ANDN:
295       case RISCV::BREV8:
296       case RISCV::CLMUL:
297       case RISCV::ORC_B:
298       case RISCV::ORN:
299       case RISCV::SH1ADD:
300       case RISCV::SH2ADD:
301       case RISCV::SH3ADD:
302       case RISCV::XNOR:
303       case RISCV::BSETI:
304       case RISCV::BCLRI:
305       case RISCV::BINVI:
306         Worklist.push_back(std::make_pair(UserMI, Bits));
307         break;
308 
309       case RISCV::PseudoCCMOVGPR:
310         // Either operand 4 or operand 5 is returned by this instruction. If
311         // only the lower word of the result is used, then only the lower word
312         // of operand 4 and 5 is used.
313         if (OpIdx != 4 && OpIdx != 5)
314           return false;
315         Worklist.push_back(std::make_pair(UserMI, Bits));
316         break;
317 
318       case RISCV::CZERO_EQZ:
319       case RISCV::CZERO_NEZ:
320       case RISCV::VT_MASKC:
321       case RISCV::VT_MASKCN:
322         if (OpIdx != 1)
323           return false;
324         Worklist.push_back(std::make_pair(UserMI, Bits));
325         break;
326       }
327     }
328   }
329 
330   return true;
331 }
332 
hasAllWUsers(const MachineInstr & OrigMI,const RISCVSubtarget & ST,const MachineRegisterInfo & MRI)333 static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
334                          const MachineRegisterInfo &MRI) {
335   return hasAllNBitUsers(OrigMI, ST, MRI, 32);
336 }
337 
338 // This function returns true if the machine instruction always outputs a value
339 // where bits 63:32 match bit 31.
isSignExtendingOpW(const MachineInstr & MI,const MachineRegisterInfo & MRI)340 static bool isSignExtendingOpW(const MachineInstr &MI,
341                                const MachineRegisterInfo &MRI) {
342   uint64_t TSFlags = MI.getDesc().TSFlags;
343 
344   // Instructions that can be determined from opcode are marked in tablegen.
345   if (TSFlags & RISCVII::IsSignExtendingOpWMask)
346     return true;
347 
348   // Special cases that require checking operands.
349   switch (MI.getOpcode()) {
350   // shifting right sufficiently makes the value 32-bit sign-extended
351   case RISCV::SRAI:
352     return MI.getOperand(2).getImm() >= 32;
353   case RISCV::SRLI:
354     return MI.getOperand(2).getImm() > 32;
355   // The LI pattern ADDI rd, X0, imm is sign extended.
356   case RISCV::ADDI:
357     return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0;
358   // An ANDI with an 11 bit immediate will zero bits 63:11.
359   case RISCV::ANDI:
360     return isUInt<11>(MI.getOperand(2).getImm());
361   // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
362   case RISCV::ORI:
363     return !isUInt<11>(MI.getOperand(2).getImm());
364   // A bseti with X0 is sign extended if the immediate is less than 31.
365   case RISCV::BSETI:
366     return MI.getOperand(2).getImm() < 31 &&
367            MI.getOperand(1).getReg() == RISCV::X0;
368   // Copying from X0 produces zero.
369   case RISCV::COPY:
370     return MI.getOperand(1).getReg() == RISCV::X0;
371   case RISCV::PseudoAtomicLoadNand32:
372     return true;
373   case RISCV::PseudoVMV_X_S: {
374     // vmv.x.s has at least 33 sign bits if log2(sew) <= 5.
375     int64_t Log2SEW = MI.getOperand(2).getImm();
376     assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW");
377     return Log2SEW <= 5;
378   }
379   }
380 
381   return false;
382 }
383 
isSignExtendedW(Register SrcReg,const RISCVSubtarget & ST,const MachineRegisterInfo & MRI,SmallPtrSetImpl<MachineInstr * > & FixableDef)384 static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
385                             const MachineRegisterInfo &MRI,
386                             SmallPtrSetImpl<MachineInstr *> &FixableDef) {
387 
388   SmallPtrSet<const MachineInstr *, 4> Visited;
389   SmallVector<MachineInstr *, 4> Worklist;
390 
391   auto AddRegDefToWorkList = [&](Register SrcReg) {
392     if (!SrcReg.isVirtual())
393       return false;
394     MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
395     if (!SrcMI)
396       return false;
397     // Code assumes the register is operand 0.
398     // TODO: Maybe the worklist should store register?
399     if (!SrcMI->getOperand(0).isReg() ||
400         SrcMI->getOperand(0).getReg() != SrcReg)
401       return false;
402     // Add SrcMI to the worklist.
403     Worklist.push_back(SrcMI);
404     return true;
405   };
406 
407   if (!AddRegDefToWorkList(SrcReg))
408     return false;
409 
410   while (!Worklist.empty()) {
411     MachineInstr *MI = Worklist.pop_back_val();
412 
413     // If we already visited this instruction, we don't need to check it again.
414     if (!Visited.insert(MI).second)
415       continue;
416 
417     // If this is a sign extending operation we don't need to look any further.
418     if (isSignExtendingOpW(*MI, MRI))
419       continue;
420 
421     // Is this an instruction that propagates sign extend?
422     switch (MI->getOpcode()) {
423     default:
424       // Unknown opcode, give up.
425       return false;
426     case RISCV::COPY: {
427       const MachineFunction *MF = MI->getMF();
428       const RISCVMachineFunctionInfo *RVFI =
429           MF->getInfo<RISCVMachineFunctionInfo>();
430 
431       // If this is the entry block and the register is livein, see if we know
432       // it is sign extended.
433       if (MI->getParent() == &MF->front()) {
434         Register VReg = MI->getOperand(0).getReg();
435         if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg))
436           continue;
437       }
438 
439       Register CopySrcReg = MI->getOperand(1).getReg();
440       if (CopySrcReg == RISCV::X10) {
441         // For a method return value, we check the ZExt/SExt flags in attribute.
442         // We assume the following code sequence for method call.
443         // PseudoCALL @bar, ...
444         // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
445         // %0:gpr = COPY $x10
446         //
447         // We use the PseudoCall to look up the IR function being called to find
448         // its return attributes.
449         const MachineBasicBlock *MBB = MI->getParent();
450         auto II = MI->getIterator();
451         if (II == MBB->instr_begin() ||
452             (--II)->getOpcode() != RISCV::ADJCALLSTACKUP)
453           return false;
454 
455         const MachineInstr &CallMI = *(--II);
456         if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal())
457           return false;
458 
459         auto *CalleeFn =
460             dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal());
461         if (!CalleeFn)
462           return false;
463 
464         auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType());
465         if (!IntTy)
466           return false;
467 
468         const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs();
469         unsigned BitWidth = IntTy->getBitWidth();
470         if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) ||
471             (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt)))
472           continue;
473       }
474 
475       if (!AddRegDefToWorkList(CopySrcReg))
476         return false;
477 
478       break;
479     }
480 
481     // For these, we just need to check if the 1st operand is sign extended.
482     case RISCV::BCLRI:
483     case RISCV::BINVI:
484     case RISCV::BSETI:
485       if (MI->getOperand(2).getImm() >= 31)
486         return false;
487       [[fallthrough]];
488     case RISCV::REM:
489     case RISCV::ANDI:
490     case RISCV::ORI:
491     case RISCV::XORI:
492       // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
493       // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
494       // Logical operations use a sign extended 12-bit immediate.
495       if (!AddRegDefToWorkList(MI->getOperand(1).getReg()))
496         return false;
497 
498       break;
499     case RISCV::PseudoCCADDW:
500     case RISCV::PseudoCCADDIW:
501     case RISCV::PseudoCCSUBW:
502     case RISCV::PseudoCCSLLW:
503     case RISCV::PseudoCCSRLW:
504     case RISCV::PseudoCCSRAW:
505     case RISCV::PseudoCCSLLIW:
506     case RISCV::PseudoCCSRLIW:
507     case RISCV::PseudoCCSRAIW:
508       // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only
509       // need to check if operand 4 is sign extended.
510       if (!AddRegDefToWorkList(MI->getOperand(4).getReg()))
511         return false;
512       break;
513     case RISCV::REMU:
514     case RISCV::AND:
515     case RISCV::OR:
516     case RISCV::XOR:
517     case RISCV::ANDN:
518     case RISCV::ORN:
519     case RISCV::XNOR:
520     case RISCV::MAX:
521     case RISCV::MAXU:
522     case RISCV::MIN:
523     case RISCV::MINU:
524     case RISCV::PseudoCCMOVGPR:
525     case RISCV::PseudoCCAND:
526     case RISCV::PseudoCCOR:
527     case RISCV::PseudoCCXOR:
528     case RISCV::PHI: {
529       // If all incoming values are sign-extended, the output of AND, OR, XOR,
530       // MIN, MAX, or PHI is also sign-extended.
531 
532       // The input registers for PHI are operand 1, 3, ...
533       // The input registers for PseudoCCMOVGPR are 4 and 5.
534       // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6.
535       // The input registers for others are operand 1 and 2.
536       unsigned B = 1, E = 3, D = 1;
537       switch (MI->getOpcode()) {
538       case RISCV::PHI:
539         E = MI->getNumOperands();
540         D = 2;
541         break;
542       case RISCV::PseudoCCMOVGPR:
543         B = 4;
544         E = 6;
545         break;
546       case RISCV::PseudoCCAND:
547       case RISCV::PseudoCCOR:
548       case RISCV::PseudoCCXOR:
549         B = 4;
550         E = 7;
551         break;
552        }
553 
554       for (unsigned I = B; I != E; I += D) {
555         if (!MI->getOperand(I).isReg())
556           return false;
557 
558         if (!AddRegDefToWorkList(MI->getOperand(I).getReg()))
559           return false;
560       }
561 
562       break;
563     }
564 
565     case RISCV::CZERO_EQZ:
566     case RISCV::CZERO_NEZ:
567     case RISCV::VT_MASKC:
568     case RISCV::VT_MASKCN:
569       // Instructions return zero or operand 1. Result is sign extended if
570       // operand 1 is sign extended.
571       if (!AddRegDefToWorkList(MI->getOperand(1).getReg()))
572         return false;
573       break;
574 
575     // With these opcode, we can "fix" them with the W-version
576     // if we know all users of the result only rely on bits 31:0
577     case RISCV::SLLI:
578       // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
579       if (MI->getOperand(2).getImm() >= 32)
580         return false;
581       [[fallthrough]];
582     case RISCV::ADDI:
583     case RISCV::ADD:
584     case RISCV::LD:
585     case RISCV::LWU:
586     case RISCV::MUL:
587     case RISCV::SUB:
588       if (hasAllWUsers(*MI, ST, MRI)) {
589         FixableDef.insert(MI);
590         break;
591       }
592       return false;
593     }
594   }
595 
596   // If we get here, then every node we visited produces a sign extended value
597   // or propagated sign extended values. So the result must be sign extended.
598   return true;
599 }
600 
getWOp(unsigned Opcode)601 static unsigned getWOp(unsigned Opcode) {
602   switch (Opcode) {
603   case RISCV::ADDI:
604     return RISCV::ADDIW;
605   case RISCV::ADD:
606     return RISCV::ADDW;
607   case RISCV::LD:
608   case RISCV::LWU:
609     return RISCV::LW;
610   case RISCV::MUL:
611     return RISCV::MULW;
612   case RISCV::SLLI:
613     return RISCV::SLLIW;
614   case RISCV::SUB:
615     return RISCV::SUBW;
616   default:
617     llvm_unreachable("Unexpected opcode for replacement with W variant");
618   }
619 }
620 
removeSExtWInstrs(MachineFunction & MF,const RISCVInstrInfo & TII,const RISCVSubtarget & ST,MachineRegisterInfo & MRI)621 bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
622                                         const RISCVInstrInfo &TII,
623                                         const RISCVSubtarget &ST,
624                                         MachineRegisterInfo &MRI) {
625   if (DisableSExtWRemoval)
626     return false;
627 
628   bool MadeChange = false;
629   for (MachineBasicBlock &MBB : MF) {
630     for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
631       // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
632       if (!RISCV::isSEXT_W(MI))
633         continue;
634 
635       Register SrcReg = MI.getOperand(1).getReg();
636 
637       SmallPtrSet<MachineInstr *, 4> FixableDefs;
638 
639       // If all users only use the lower bits, this sext.w is redundant.
640       // Or if all definitions reaching MI sign-extend their output,
641       // then sext.w is redundant.
642       if (!hasAllWUsers(MI, ST, MRI) &&
643           !isSignExtendedW(SrcReg, ST, MRI, FixableDefs))
644         continue;
645 
646       Register DstReg = MI.getOperand(0).getReg();
647       if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
648         continue;
649 
650       // Convert Fixable instructions to their W versions.
651       for (MachineInstr *Fixable : FixableDefs) {
652         LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
653         Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode())));
654         Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap);
655         Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap);
656         Fixable->clearFlag(MachineInstr::MIFlag::IsExact);
657         LLVM_DEBUG(dbgs() << "     with " << *Fixable);
658         ++NumTransformedToWInstrs;
659       }
660 
661       LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
662       MRI.replaceRegWith(DstReg, SrcReg);
663       MRI.clearKillFlags(SrcReg);
664       MI.eraseFromParent();
665       ++NumRemovedSExtW;
666       MadeChange = true;
667     }
668   }
669 
670   return MadeChange;
671 }
672 
stripWSuffixes(MachineFunction & MF,const RISCVInstrInfo & TII,const RISCVSubtarget & ST,MachineRegisterInfo & MRI)673 bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
674                                      const RISCVInstrInfo &TII,
675                                      const RISCVSubtarget &ST,
676                                      MachineRegisterInfo &MRI) {
677   if (DisableStripWSuffix)
678     return false;
679 
680   bool MadeChange = false;
681   for (MachineBasicBlock &MBB : MF) {
682     for (MachineInstr &MI : MBB) {
683       unsigned Opc;
684       switch (MI.getOpcode()) {
685       default:
686         continue;
687       case RISCV::ADDW:  Opc = RISCV::ADD;  break;
688       case RISCV::ADDIW: Opc = RISCV::ADDI; break;
689       case RISCV::MULW:  Opc = RISCV::MUL;  break;
690       case RISCV::SLLIW: Opc = RISCV::SLLI; break;
691       }
692 
693       if (hasAllWUsers(MI, ST, MRI)) {
694         MI.setDesc(TII.get(Opc));
695         MadeChange = true;
696       }
697     }
698   }
699 
700   return MadeChange;
701 }
702 
runOnMachineFunction(MachineFunction & MF)703 bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) {
704   if (skipFunction(MF.getFunction()))
705     return false;
706 
707   MachineRegisterInfo &MRI = MF.getRegInfo();
708   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
709   const RISCVInstrInfo &TII = *ST.getInstrInfo();
710 
711   if (!ST.is64Bit())
712     return false;
713 
714   bool MadeChange = false;
715   MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
716   MadeChange |= stripWSuffixes(MF, TII, ST, MRI);
717 
718   return MadeChange;
719 }
720