//===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// This pass does some optimizations for *W instructions at the MI level.
//
// First it removes unneeded sext.w instructions, either because the sign
// extended bits aren't consumed or because the input was already sign extended
// by an earlier instruction.
//
// Then it removes the -w suffix from opw instructions whenever all users are
// dependent only on the lower word of the result of the instruction.
// The cases handled are:
// * addw because c.add has a larger register encoding than c.addw.
// * addiw because it helps reduce test differences between RV32 and RV64
//   w/o being a pessimization.
// * mulw because c.mulw doesn't exist but c.mul does (w/ zcb)
// * slliw because c.slliw doesn't exist and c.slli does
//
//===---------------------------------------------------------------------===//
25
26 #include "RISCV.h"
27 #include "RISCVMachineFunctionInfo.h"
28 #include "RISCVSubtarget.h"
29 #include "llvm/ADT/SmallSet.h"
30 #include "llvm/ADT/Statistic.h"
31 #include "llvm/CodeGen/MachineFunctionPass.h"
32 #include "llvm/CodeGen/TargetInstrInfo.h"
33
34 using namespace llvm;
35
36 #define DEBUG_TYPE "riscv-opt-w-instrs"
37 #define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"
38
39 STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
40 STATISTIC(NumTransformedToWInstrs,
41 "Number of instructions transformed to W-ops");
42
43 static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
44 cl::desc("Disable removal of sext.w"),
45 cl::init(false), cl::Hidden);
46 static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
47 cl::desc("Disable strip W suffix"),
48 cl::init(false), cl::Hidden);
49
50 namespace {
51
52 class RISCVOptWInstrs : public MachineFunctionPass {
53 public:
54 static char ID;
55
RISCVOptWInstrs()56 RISCVOptWInstrs() : MachineFunctionPass(ID) {}
57
58 bool runOnMachineFunction(MachineFunction &MF) override;
59 bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
60 const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
61 bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
62 const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
63
getAnalysisUsage(AnalysisUsage & AU) const64 void getAnalysisUsage(AnalysisUsage &AU) const override {
65 AU.setPreservesCFG();
66 MachineFunctionPass::getAnalysisUsage(AU);
67 }
68
getPassName() const69 StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; }
70 };
71
72 } // end anonymous namespace
73
74 char RISCVOptWInstrs::ID = 0;
INITIALIZE_PASS(RISCVOptWInstrs,DEBUG_TYPE,RISCV_OPT_W_INSTRS_NAME,false,false)75 INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false,
76 false)
77
78 FunctionPass *llvm::createRISCVOptWInstrsPass() {
79 return new RISCVOptWInstrs();
80 }
81
vectorPseudoHasAllNBitUsers(const MachineOperand & UserOp,unsigned Bits)82 static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
83 unsigned Bits) {
84 const MachineInstr &MI = *UserOp.getParent();
85 unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode());
86
87 if (!MCOpcode)
88 return false;
89
90 const MCInstrDesc &MCID = MI.getDesc();
91 const uint64_t TSFlags = MCID.TSFlags;
92 if (!RISCVII::hasSEWOp(TSFlags))
93 return false;
94 assert(RISCVII::hasVLOp(TSFlags));
95 const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
96
97 if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
98 return false;
99
100 auto NumDemandedBits =
101 RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW);
102 return NumDemandedBits && Bits >= *NumDemandedBits;
103 }
104
105 // Checks if all users only demand the lower \p OrigBits of the original
106 // instruction's result.
107 // TODO: handle multiple interdependent transformations
hasAllNBitUsers(const MachineInstr & OrigMI,const RISCVSubtarget & ST,const MachineRegisterInfo & MRI,unsigned OrigBits)108 static bool hasAllNBitUsers(const MachineInstr &OrigMI,
109 const RISCVSubtarget &ST,
110 const MachineRegisterInfo &MRI, unsigned OrigBits) {
111
112 SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
113 SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;
114
115 Worklist.push_back(std::make_pair(&OrigMI, OrigBits));
116
117 while (!Worklist.empty()) {
118 auto P = Worklist.pop_back_val();
119 const MachineInstr *MI = P.first;
120 unsigned Bits = P.second;
121
122 if (!Visited.insert(P).second)
123 continue;
124
125 // Only handle instructions with one def.
126 if (MI->getNumExplicitDefs() != 1)
127 return false;
128
129 Register DestReg = MI->getOperand(0).getReg();
130 if (!DestReg.isVirtual())
131 return false;
132
133 for (auto &UserOp : MRI.use_nodbg_operands(DestReg)) {
134 const MachineInstr *UserMI = UserOp.getParent();
135 unsigned OpIdx = UserOp.getOperandNo();
136
137 switch (UserMI->getOpcode()) {
138 default:
139 if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
140 break;
141 return false;
142
143 case RISCV::ADDIW:
144 case RISCV::ADDW:
145 case RISCV::DIVUW:
146 case RISCV::DIVW:
147 case RISCV::MULW:
148 case RISCV::REMUW:
149 case RISCV::REMW:
150 case RISCV::SLLIW:
151 case RISCV::SLLW:
152 case RISCV::SRAIW:
153 case RISCV::SRAW:
154 case RISCV::SRLIW:
155 case RISCV::SRLW:
156 case RISCV::SUBW:
157 case RISCV::ROLW:
158 case RISCV::RORW:
159 case RISCV::RORIW:
160 case RISCV::CLZW:
161 case RISCV::CTZW:
162 case RISCV::CPOPW:
163 case RISCV::SLLI_UW:
164 case RISCV::FMV_W_X:
165 case RISCV::FCVT_H_W:
166 case RISCV::FCVT_H_WU:
167 case RISCV::FCVT_S_W:
168 case RISCV::FCVT_S_WU:
169 case RISCV::FCVT_D_W:
170 case RISCV::FCVT_D_WU:
171 if (Bits >= 32)
172 break;
173 return false;
174 case RISCV::SEXT_B:
175 case RISCV::PACKH:
176 if (Bits >= 8)
177 break;
178 return false;
179 case RISCV::SEXT_H:
180 case RISCV::FMV_H_X:
181 case RISCV::ZEXT_H_RV32:
182 case RISCV::ZEXT_H_RV64:
183 case RISCV::PACKW:
184 if (Bits >= 16)
185 break;
186 return false;
187
188 case RISCV::PACK:
189 if (Bits >= (ST.getXLen() / 2))
190 break;
191 return false;
192
193 case RISCV::SRLI: {
194 // If we are shifting right by less than Bits, and users don't demand
195 // any bits that were shifted into [Bits-1:0], then we can consider this
196 // as an N-Bit user.
197 unsigned ShAmt = UserMI->getOperand(2).getImm();
198 if (Bits > ShAmt) {
199 Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
200 break;
201 }
202 return false;
203 }
204
205 // these overwrite higher input bits, otherwise the lower word of output
206 // depends only on the lower word of input. So check their uses read W.
207 case RISCV::SLLI:
208 if (Bits >= (ST.getXLen() - UserMI->getOperand(2).getImm()))
209 break;
210 Worklist.push_back(std::make_pair(UserMI, Bits));
211 break;
212 case RISCV::ANDI: {
213 uint64_t Imm = UserMI->getOperand(2).getImm();
214 if (Bits >= (unsigned)llvm::bit_width(Imm))
215 break;
216 Worklist.push_back(std::make_pair(UserMI, Bits));
217 break;
218 }
219 case RISCV::ORI: {
220 uint64_t Imm = UserMI->getOperand(2).getImm();
221 if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
222 break;
223 Worklist.push_back(std::make_pair(UserMI, Bits));
224 break;
225 }
226
227 case RISCV::SLL:
228 case RISCV::BSET:
229 case RISCV::BCLR:
230 case RISCV::BINV:
231 // Operand 2 is the shift amount which uses log2(xlen) bits.
232 if (OpIdx == 2) {
233 if (Bits >= Log2_32(ST.getXLen()))
234 break;
235 return false;
236 }
237 Worklist.push_back(std::make_pair(UserMI, Bits));
238 break;
239
240 case RISCV::SRA:
241 case RISCV::SRL:
242 case RISCV::ROL:
243 case RISCV::ROR:
244 // Operand 2 is the shift amount which uses 6 bits.
245 if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
246 break;
247 return false;
248
249 case RISCV::ADD_UW:
250 case RISCV::SH1ADD_UW:
251 case RISCV::SH2ADD_UW:
252 case RISCV::SH3ADD_UW:
253 // Operand 1 is implicitly zero extended.
254 if (OpIdx == 1 && Bits >= 32)
255 break;
256 Worklist.push_back(std::make_pair(UserMI, Bits));
257 break;
258
259 case RISCV::BEXTI:
260 if (UserMI->getOperand(2).getImm() >= Bits)
261 return false;
262 break;
263
264 case RISCV::SB:
265 // The first argument is the value to store.
266 if (OpIdx == 0 && Bits >= 8)
267 break;
268 return false;
269 case RISCV::SH:
270 // The first argument is the value to store.
271 if (OpIdx == 0 && Bits >= 16)
272 break;
273 return false;
274 case RISCV::SW:
275 // The first argument is the value to store.
276 if (OpIdx == 0 && Bits >= 32)
277 break;
278 return false;
279
280 // For these, lower word of output in these operations, depends only on
281 // the lower word of input. So, we check all uses only read lower word.
282 case RISCV::COPY:
283 case RISCV::PHI:
284
285 case RISCV::ADD:
286 case RISCV::ADDI:
287 case RISCV::AND:
288 case RISCV::MUL:
289 case RISCV::OR:
290 case RISCV::SUB:
291 case RISCV::XOR:
292 case RISCV::XORI:
293
294 case RISCV::ANDN:
295 case RISCV::BREV8:
296 case RISCV::CLMUL:
297 case RISCV::ORC_B:
298 case RISCV::ORN:
299 case RISCV::SH1ADD:
300 case RISCV::SH2ADD:
301 case RISCV::SH3ADD:
302 case RISCV::XNOR:
303 case RISCV::BSETI:
304 case RISCV::BCLRI:
305 case RISCV::BINVI:
306 Worklist.push_back(std::make_pair(UserMI, Bits));
307 break;
308
309 case RISCV::PseudoCCMOVGPR:
310 // Either operand 4 or operand 5 is returned by this instruction. If
311 // only the lower word of the result is used, then only the lower word
312 // of operand 4 and 5 is used.
313 if (OpIdx != 4 && OpIdx != 5)
314 return false;
315 Worklist.push_back(std::make_pair(UserMI, Bits));
316 break;
317
318 case RISCV::CZERO_EQZ:
319 case RISCV::CZERO_NEZ:
320 case RISCV::VT_MASKC:
321 case RISCV::VT_MASKCN:
322 if (OpIdx != 1)
323 return false;
324 Worklist.push_back(std::make_pair(UserMI, Bits));
325 break;
326 }
327 }
328 }
329
330 return true;
331 }
332
hasAllWUsers(const MachineInstr & OrigMI,const RISCVSubtarget & ST,const MachineRegisterInfo & MRI)333 static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
334 const MachineRegisterInfo &MRI) {
335 return hasAllNBitUsers(OrigMI, ST, MRI, 32);
336 }
337
338 // This function returns true if the machine instruction always outputs a value
339 // where bits 63:32 match bit 31.
isSignExtendingOpW(const MachineInstr & MI,const MachineRegisterInfo & MRI)340 static bool isSignExtendingOpW(const MachineInstr &MI,
341 const MachineRegisterInfo &MRI) {
342 uint64_t TSFlags = MI.getDesc().TSFlags;
343
344 // Instructions that can be determined from opcode are marked in tablegen.
345 if (TSFlags & RISCVII::IsSignExtendingOpWMask)
346 return true;
347
348 // Special cases that require checking operands.
349 switch (MI.getOpcode()) {
350 // shifting right sufficiently makes the value 32-bit sign-extended
351 case RISCV::SRAI:
352 return MI.getOperand(2).getImm() >= 32;
353 case RISCV::SRLI:
354 return MI.getOperand(2).getImm() > 32;
355 // The LI pattern ADDI rd, X0, imm is sign extended.
356 case RISCV::ADDI:
357 return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0;
358 // An ANDI with an 11 bit immediate will zero bits 63:11.
359 case RISCV::ANDI:
360 return isUInt<11>(MI.getOperand(2).getImm());
361 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
362 case RISCV::ORI:
363 return !isUInt<11>(MI.getOperand(2).getImm());
364 // A bseti with X0 is sign extended if the immediate is less than 31.
365 case RISCV::BSETI:
366 return MI.getOperand(2).getImm() < 31 &&
367 MI.getOperand(1).getReg() == RISCV::X0;
368 // Copying from X0 produces zero.
369 case RISCV::COPY:
370 return MI.getOperand(1).getReg() == RISCV::X0;
371 case RISCV::PseudoAtomicLoadNand32:
372 return true;
373 case RISCV::PseudoVMV_X_S: {
374 // vmv.x.s has at least 33 sign bits if log2(sew) <= 5.
375 int64_t Log2SEW = MI.getOperand(2).getImm();
376 assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW");
377 return Log2SEW <= 5;
378 }
379 }
380
381 return false;
382 }
383
isSignExtendedW(Register SrcReg,const RISCVSubtarget & ST,const MachineRegisterInfo & MRI,SmallPtrSetImpl<MachineInstr * > & FixableDef)384 static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
385 const MachineRegisterInfo &MRI,
386 SmallPtrSetImpl<MachineInstr *> &FixableDef) {
387
388 SmallPtrSet<const MachineInstr *, 4> Visited;
389 SmallVector<MachineInstr *, 4> Worklist;
390
391 auto AddRegDefToWorkList = [&](Register SrcReg) {
392 if (!SrcReg.isVirtual())
393 return false;
394 MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
395 if (!SrcMI)
396 return false;
397 // Code assumes the register is operand 0.
398 // TODO: Maybe the worklist should store register?
399 if (!SrcMI->getOperand(0).isReg() ||
400 SrcMI->getOperand(0).getReg() != SrcReg)
401 return false;
402 // Add SrcMI to the worklist.
403 Worklist.push_back(SrcMI);
404 return true;
405 };
406
407 if (!AddRegDefToWorkList(SrcReg))
408 return false;
409
410 while (!Worklist.empty()) {
411 MachineInstr *MI = Worklist.pop_back_val();
412
413 // If we already visited this instruction, we don't need to check it again.
414 if (!Visited.insert(MI).second)
415 continue;
416
417 // If this is a sign extending operation we don't need to look any further.
418 if (isSignExtendingOpW(*MI, MRI))
419 continue;
420
421 // Is this an instruction that propagates sign extend?
422 switch (MI->getOpcode()) {
423 default:
424 // Unknown opcode, give up.
425 return false;
426 case RISCV::COPY: {
427 const MachineFunction *MF = MI->getMF();
428 const RISCVMachineFunctionInfo *RVFI =
429 MF->getInfo<RISCVMachineFunctionInfo>();
430
431 // If this is the entry block and the register is livein, see if we know
432 // it is sign extended.
433 if (MI->getParent() == &MF->front()) {
434 Register VReg = MI->getOperand(0).getReg();
435 if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg))
436 continue;
437 }
438
439 Register CopySrcReg = MI->getOperand(1).getReg();
440 if (CopySrcReg == RISCV::X10) {
441 // For a method return value, we check the ZExt/SExt flags in attribute.
442 // We assume the following code sequence for method call.
443 // PseudoCALL @bar, ...
444 // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
445 // %0:gpr = COPY $x10
446 //
447 // We use the PseudoCall to look up the IR function being called to find
448 // its return attributes.
449 const MachineBasicBlock *MBB = MI->getParent();
450 auto II = MI->getIterator();
451 if (II == MBB->instr_begin() ||
452 (--II)->getOpcode() != RISCV::ADJCALLSTACKUP)
453 return false;
454
455 const MachineInstr &CallMI = *(--II);
456 if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal())
457 return false;
458
459 auto *CalleeFn =
460 dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal());
461 if (!CalleeFn)
462 return false;
463
464 auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType());
465 if (!IntTy)
466 return false;
467
468 const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs();
469 unsigned BitWidth = IntTy->getBitWidth();
470 if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) ||
471 (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt)))
472 continue;
473 }
474
475 if (!AddRegDefToWorkList(CopySrcReg))
476 return false;
477
478 break;
479 }
480
481 // For these, we just need to check if the 1st operand is sign extended.
482 case RISCV::BCLRI:
483 case RISCV::BINVI:
484 case RISCV::BSETI:
485 if (MI->getOperand(2).getImm() >= 31)
486 return false;
487 [[fallthrough]];
488 case RISCV::REM:
489 case RISCV::ANDI:
490 case RISCV::ORI:
491 case RISCV::XORI:
492 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
493 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
494 // Logical operations use a sign extended 12-bit immediate.
495 if (!AddRegDefToWorkList(MI->getOperand(1).getReg()))
496 return false;
497
498 break;
499 case RISCV::PseudoCCADDW:
500 case RISCV::PseudoCCADDIW:
501 case RISCV::PseudoCCSUBW:
502 case RISCV::PseudoCCSLLW:
503 case RISCV::PseudoCCSRLW:
504 case RISCV::PseudoCCSRAW:
505 case RISCV::PseudoCCSLLIW:
506 case RISCV::PseudoCCSRLIW:
507 case RISCV::PseudoCCSRAIW:
508 // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only
509 // need to check if operand 4 is sign extended.
510 if (!AddRegDefToWorkList(MI->getOperand(4).getReg()))
511 return false;
512 break;
513 case RISCV::REMU:
514 case RISCV::AND:
515 case RISCV::OR:
516 case RISCV::XOR:
517 case RISCV::ANDN:
518 case RISCV::ORN:
519 case RISCV::XNOR:
520 case RISCV::MAX:
521 case RISCV::MAXU:
522 case RISCV::MIN:
523 case RISCV::MINU:
524 case RISCV::PseudoCCMOVGPR:
525 case RISCV::PseudoCCAND:
526 case RISCV::PseudoCCOR:
527 case RISCV::PseudoCCXOR:
528 case RISCV::PHI: {
529 // If all incoming values are sign-extended, the output of AND, OR, XOR,
530 // MIN, MAX, or PHI is also sign-extended.
531
532 // The input registers for PHI are operand 1, 3, ...
533 // The input registers for PseudoCCMOVGPR are 4 and 5.
534 // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6.
535 // The input registers for others are operand 1 and 2.
536 unsigned B = 1, E = 3, D = 1;
537 switch (MI->getOpcode()) {
538 case RISCV::PHI:
539 E = MI->getNumOperands();
540 D = 2;
541 break;
542 case RISCV::PseudoCCMOVGPR:
543 B = 4;
544 E = 6;
545 break;
546 case RISCV::PseudoCCAND:
547 case RISCV::PseudoCCOR:
548 case RISCV::PseudoCCXOR:
549 B = 4;
550 E = 7;
551 break;
552 }
553
554 for (unsigned I = B; I != E; I += D) {
555 if (!MI->getOperand(I).isReg())
556 return false;
557
558 if (!AddRegDefToWorkList(MI->getOperand(I).getReg()))
559 return false;
560 }
561
562 break;
563 }
564
565 case RISCV::CZERO_EQZ:
566 case RISCV::CZERO_NEZ:
567 case RISCV::VT_MASKC:
568 case RISCV::VT_MASKCN:
569 // Instructions return zero or operand 1. Result is sign extended if
570 // operand 1 is sign extended.
571 if (!AddRegDefToWorkList(MI->getOperand(1).getReg()))
572 return false;
573 break;
574
575 // With these opcode, we can "fix" them with the W-version
576 // if we know all users of the result only rely on bits 31:0
577 case RISCV::SLLI:
578 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
579 if (MI->getOperand(2).getImm() >= 32)
580 return false;
581 [[fallthrough]];
582 case RISCV::ADDI:
583 case RISCV::ADD:
584 case RISCV::LD:
585 case RISCV::LWU:
586 case RISCV::MUL:
587 case RISCV::SUB:
588 if (hasAllWUsers(*MI, ST, MRI)) {
589 FixableDef.insert(MI);
590 break;
591 }
592 return false;
593 }
594 }
595
596 // If we get here, then every node we visited produces a sign extended value
597 // or propagated sign extended values. So the result must be sign extended.
598 return true;
599 }
600
getWOp(unsigned Opcode)601 static unsigned getWOp(unsigned Opcode) {
602 switch (Opcode) {
603 case RISCV::ADDI:
604 return RISCV::ADDIW;
605 case RISCV::ADD:
606 return RISCV::ADDW;
607 case RISCV::LD:
608 case RISCV::LWU:
609 return RISCV::LW;
610 case RISCV::MUL:
611 return RISCV::MULW;
612 case RISCV::SLLI:
613 return RISCV::SLLIW;
614 case RISCV::SUB:
615 return RISCV::SUBW;
616 default:
617 llvm_unreachable("Unexpected opcode for replacement with W variant");
618 }
619 }
620
removeSExtWInstrs(MachineFunction & MF,const RISCVInstrInfo & TII,const RISCVSubtarget & ST,MachineRegisterInfo & MRI)621 bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
622 const RISCVInstrInfo &TII,
623 const RISCVSubtarget &ST,
624 MachineRegisterInfo &MRI) {
625 if (DisableSExtWRemoval)
626 return false;
627
628 bool MadeChange = false;
629 for (MachineBasicBlock &MBB : MF) {
630 for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
631 // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
632 if (!RISCV::isSEXT_W(MI))
633 continue;
634
635 Register SrcReg = MI.getOperand(1).getReg();
636
637 SmallPtrSet<MachineInstr *, 4> FixableDefs;
638
639 // If all users only use the lower bits, this sext.w is redundant.
640 // Or if all definitions reaching MI sign-extend their output,
641 // then sext.w is redundant.
642 if (!hasAllWUsers(MI, ST, MRI) &&
643 !isSignExtendedW(SrcReg, ST, MRI, FixableDefs))
644 continue;
645
646 Register DstReg = MI.getOperand(0).getReg();
647 if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
648 continue;
649
650 // Convert Fixable instructions to their W versions.
651 for (MachineInstr *Fixable : FixableDefs) {
652 LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
653 Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode())));
654 Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap);
655 Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap);
656 Fixable->clearFlag(MachineInstr::MIFlag::IsExact);
657 LLVM_DEBUG(dbgs() << " with " << *Fixable);
658 ++NumTransformedToWInstrs;
659 }
660
661 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
662 MRI.replaceRegWith(DstReg, SrcReg);
663 MRI.clearKillFlags(SrcReg);
664 MI.eraseFromParent();
665 ++NumRemovedSExtW;
666 MadeChange = true;
667 }
668 }
669
670 return MadeChange;
671 }
672
stripWSuffixes(MachineFunction & MF,const RISCVInstrInfo & TII,const RISCVSubtarget & ST,MachineRegisterInfo & MRI)673 bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
674 const RISCVInstrInfo &TII,
675 const RISCVSubtarget &ST,
676 MachineRegisterInfo &MRI) {
677 if (DisableStripWSuffix)
678 return false;
679
680 bool MadeChange = false;
681 for (MachineBasicBlock &MBB : MF) {
682 for (MachineInstr &MI : MBB) {
683 unsigned Opc;
684 switch (MI.getOpcode()) {
685 default:
686 continue;
687 case RISCV::ADDW: Opc = RISCV::ADD; break;
688 case RISCV::ADDIW: Opc = RISCV::ADDI; break;
689 case RISCV::MULW: Opc = RISCV::MUL; break;
690 case RISCV::SLLIW: Opc = RISCV::SLLI; break;
691 }
692
693 if (hasAllWUsers(MI, ST, MRI)) {
694 MI.setDesc(TII.get(Opc));
695 MadeChange = true;
696 }
697 }
698 }
699
700 return MadeChange;
701 }
702
runOnMachineFunction(MachineFunction & MF)703 bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) {
704 if (skipFunction(MF.getFunction()))
705 return false;
706
707 MachineRegisterInfo &MRI = MF.getRegInfo();
708 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
709 const RISCVInstrInfo &TII = *ST.getInstrInfo();
710
711 if (!ST.is64Bit())
712 return false;
713
714 bool MadeChange = false;
715 MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
716 MadeChange |= stripWSuffixes(MF, TII, ST, MRI);
717
718 return MadeChange;
719 }
720