//===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// This pass does some optimizations for *W instructions at the MI level.
//
// First it removes unneeded sext.w instructions. Either because the sign
// extended bits aren't consumed or because the input was already sign extended
// by an earlier instruction.
//
// Then it removes the -w suffix from opw instructions whenever all users are
// dependent only on the lower word of the result of the instruction.
// The cases handled are:
// * addw because c.add has a larger register encoding than c.addw.
// * addiw because it helps reduce test differences between RV32 and RV64
//   w/o being a pessimization.
// * mulw because c.mulw doesn't exist but c.mul does (w/ zcb)
// * slliw because c.slliw doesn't exist and c.slli does
//
//===---------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVMachineFunctionInfo.h"
#include "RISCVSubtarget.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/TargetInstrInfo.h"

using namespace llvm;

#define DEBUG_TYPE "riscv-opt-w-instrs"
#define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"

STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
STATISTIC(NumTransformedToWInstrs,
          "Number of instructions transformed to W-ops");

// Hidden command-line switches that disable each half of the pass
// independently (checked at the top of removeSExtWInstrs / stripWSuffixes).
static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
                                         cl::desc("Disable removal of sext.w"),
                                         cl::init(false), cl::Hidden);
static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
                                         cl::desc("Disable strip W suffix"),
                                         cl::init(false), cl::Hidden);

namespace {

/// Machine-function pass that removes redundant sext.w instructions and strips
/// the W suffix from selected instructions whose users only read the low word
/// of the result.
class RISCVOptWInstrs : public MachineFunctionPass {
public:
  static char ID;

  RISCVOptWInstrs() : MachineFunctionPass(ID) {}

  /// Runs removeSExtWInstrs followed by stripWSuffixes. Does nothing (and
  /// returns false) when the function is skipped or the subtarget is not
  /// 64-bit.
  bool runOnMachineFunction(MachineFunction &MF) override;

  /// Erases sext.w instructions whose result is only consumed through its low
  /// 32 bits, or whose input is already known to be sign extended. Returns
  /// true if any instruction was erased.
  bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
                         const RISCVSubtarget &ST, MachineRegisterInfo &MRI);

  /// Rewrites ADDW/ADDIW/MULW/SLLIW to ADD/ADDI/MUL/SLLI when all users only
  /// read the low 32 bits of the result. Returns true if any instruction was
  /// rewritten.
  bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
                      const RISCVSubtarget &ST, MachineRegisterInfo &MRI);

  /// Declares that this pass preserves the CFG.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; }
};

} // end anonymous namespace

char RISCVOptWInstrs::ID = 0;
INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false,
                false)

/// Factory used to create an instance of this pass.
FunctionPass *llvm::createRISCVOptWInstrsPass() {
  return new RISCVOptWInstrs();
}

// Returns true if the instruction using \p UserOp has an associated RVV MC
// opcode and SEW operand, \p UserOp is not its VL operand, and
// RISCV::getVectorLowDemandedScalarBits reports a demanded-bit count that is
// no larger than \p Bits. Returns false in every other case.
static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
                                        unsigned Bits) {
  const MachineInstr &MI = *UserOp.getParent();
  unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode());

  // No RVV MC opcode is associated with this instruction.
  if (!MCOpcode)
    return false;

  const MCInstrDesc &MCID = MI.getDesc();
  const uint64_t TSFlags = MCID.TSFlags;
  if (!RISCVII::hasSEWOp(TSFlags))
    return false;
  assert(RISCVII::hasVLOp(TSFlags));
  const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();

  // A use as the VL operand is never treated as an N-bit use.
  if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
    return false;

  // NumDemandedBits converts to false when the helper has no answer for this
  // opcode; that case is conservatively rejected.
  auto NumDemandedBits =
      RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW);
  return NumDemandedBits && Bits >= *NumDemandedBits;
}

// Checks if all users only demand the lower \p OrigBits of the original
// instruction's result.
//
// Users whose low result bits depend only on the low bits of this input are
// followed transitively: they are pushed on a worklist together with the
// number of low bits that must then be sufficient for *their* users.
// Any user that is not recognized makes the whole query fail.
// TODO: handle multiple interdependent transformations
static bool hasAllNBitUsers(const MachineInstr &OrigMI,
                            const RISCVSubtarget &ST,
                            const MachineRegisterInfo &MRI, unsigned OrigBits) {

  // Both containers hold (instruction, bit count) pairs. Visited is keyed on
  // the pair, so an instruction may be examined again with a different bit
  // count.
  SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
  SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;

  Worklist.push_back(std::make_pair(&OrigMI, OrigBits));

  while (!Worklist.empty()) {
    auto P = Worklist.pop_back_val();
    const MachineInstr *MI = P.first;
    unsigned Bits = P.second;

    if (!Visited.insert(P).second)
      continue;

    // Only handle instructions with one def.
    if (MI->getNumExplicitDefs() != 1)
      return false;

    // Inside the switch below, `break` means "this use is fine, look at the
    // next one" and `return false` rejects the whole query.
    for (auto &UserOp : MRI.use_nodbg_operands(MI->getOperand(0).getReg())) {
      const MachineInstr *UserMI = UserOp.getParent();
      unsigned OpIdx = UserOp.getOperandNo();

      switch (UserMI->getOpcode()) {
      default:
        // Not a scalar opcode handled below; the only other accepted users
        // are the RVV pseudos recognized by vectorPseudoHasAllNBitUsers.
        if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
          break;
        return false;

      // These are accepted whenever at least 32 low bits are available.
      case RISCV::ADDIW:
      case RISCV::ADDW:
      case RISCV::DIVUW:
      case RISCV::DIVW:
      case RISCV::MULW:
      case RISCV::REMUW:
      case RISCV::REMW:
      case RISCV::SLLIW:
      case RISCV::SLLW:
      case RISCV::SRAIW:
      case RISCV::SRAW:
      case RISCV::SRLIW:
      case RISCV::SRLW:
      case RISCV::SUBW:
      case RISCV::ROLW:
      case RISCV::RORW:
      case RISCV::RORIW:
      case RISCV::CLZW:
      case RISCV::CTZW:
      case RISCV::CPOPW:
      case RISCV::SLLI_UW:
      case RISCV::FMV_W_X:
      case RISCV::FCVT_H_W:
      case RISCV::FCVT_H_WU:
      case RISCV::FCVT_S_W:
      case RISCV::FCVT_S_WU:
      case RISCV::FCVT_D_W:
      case RISCV::FCVT_D_WU:
        if (Bits >= 32)
          break;
        return false;
      // These are accepted whenever at least 8 low bits are available.
      case RISCV::SEXT_B:
      case RISCV::PACKH:
        if (Bits >= 8)
          break;
        return false;
      // These are accepted whenever at least 16 low bits are available.
      case RISCV::SEXT_H:
      case RISCV::FMV_H_X:
      case RISCV::ZEXT_H_RV32:
      case RISCV::ZEXT_H_RV64:
      case RISCV::PACKW:
        if (Bits >= 16)
          break;
        return false;

      // Accepted whenever at least XLEN/2 low bits are available.
      case RISCV::PACK:
        if (Bits >= (ST.getXLen() / 2))
          break;
        return false;

      case RISCV::SRLI: {
        // If we are shifting right by less than Bits, and users don't demand
        // any bits that were shifted into [Bits-1:0], then we can consider this
        // as an N-Bit user.
        unsigned ShAmt = UserMI->getOperand(2).getImm();
        if (Bits > ShAmt) {
          Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
          break;
        }
        return false;
      }

      // SLLI/ANDI/ORI can discard or overwrite the upper input bits. If that
      // already leaves only the low Bits of the input observable, the use is
      // accepted outright. Otherwise the low bits of the output depend only on
      // the low bits of the input, so we check that their users in turn only
      // read the low Bits.
      case RISCV::SLLI:
        // Input bits at or above (XLEN - shamt) are shifted out.
        if (Bits >= (ST.getXLen() - UserMI->getOperand(2).getImm()))
          break;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;
      case RISCV::ANDI: {
        // Input bits at or above bit_width(Imm) are masked to zero.
        uint64_t Imm = UserMI->getOperand(2).getImm();
        if (Bits >= (unsigned)llvm::bit_width(Imm))
          break;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;
      }
      case RISCV::ORI: {
        // Input bits at or above bit_width(~Imm) are forced to one.
        uint64_t Imm = UserMI->getOperand(2).getImm();
        if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
          break;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;
      }

      case RISCV::SLL:
      case RISCV::BSET:
      case RISCV::BCLR:
      case RISCV::BINV:
        // Operand 2 is the shift amount which uses log2(xlen) bits.
        if (OpIdx == 2) {
          if (Bits >= Log2_32(ST.getXLen()))
            break;
          return false;
        }
        // Used as the value operand: follow the users of the result.
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;

      case RISCV::SRA:
      case RISCV::SRL:
      case RISCV::ROL:
      case RISCV::ROR:
        // Operand 2 is the shift amount which uses log2(xlen) bits (6 bits on
        // RV64, the only configuration this pass runs on). A use as any other
        // operand is rejected.
        if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
          break;
        return false;

      case RISCV::ADD_UW:
      case RISCV::SH1ADD_UW:
      case RISCV::SH2ADD_UW:
      case RISCV::SH3ADD_UW:
        // Operand 1 is implicitly zero extended.
        if (OpIdx == 1 && Bits >= 32)
          break;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;

      case RISCV::BEXTI:
        // Accepted only when the immediate (operand 2) is below Bits.
        if (UserMI->getOperand(2).getImm() >= Bits)
          return false;
        break;

      case RISCV::SB:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 8)
          break;
        return false;
      case RISCV::SH:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 16)
          break;
        return false;
      case RISCV::SW:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 32)
          break;
        return false;

      // For these operations, the lower word of the output depends only on
      // the lower word of the input. So, we check that all of their uses only
      // read the lower word.
      case RISCV::COPY:
      case RISCV::PHI:

      case RISCV::ADD:
      case RISCV::ADDI:
      case RISCV::AND:
      case RISCV::MUL:
      case RISCV::OR:
      case RISCV::SUB:
      case RISCV::XOR:
      case RISCV::XORI:

      case RISCV::ANDN:
      case RISCV::BREV8:
      case RISCV::CLMUL:
      case RISCV::ORC_B:
      case RISCV::ORN:
      case RISCV::SH1ADD:
      case RISCV::SH2ADD:
      case RISCV::SH3ADD:
      case RISCV::XNOR:
      case RISCV::BSETI:
      case RISCV::BCLRI:
      case RISCV::BINVI:
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;

      case RISCV::PseudoCCMOVGPR:
        // Either operand 4 or operand 5 is returned by this instruction. If
        // only the lower word of the result is used, then only the lower word
        // of operand 4 and 5 is used.
        if (OpIdx != 4 && OpIdx != 5)
          return false;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;

      case RISCV::CZERO_EQZ:
      case RISCV::CZERO_NEZ:
      case RISCV::VT_MASKC:
      case RISCV::VT_MASKCN:
        // Only a use as operand 1 is followed; a use as any other operand is
        // rejected.
        if (OpIdx != 1)
          return false;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;
      }
    }
  }

  return true;
}

// Convenience wrapper: returns true if all users of \p OrigMI only demand the
// lower 32 bits of its result.
static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
                         const MachineRegisterInfo &MRI) {
  return hasAllNBitUsers(OrigMI, ST, MRI, 32);
}

// This function returns true if the machine instruction always outputs a value
// where bits 63:32 match bit 31.
static bool isSignExtendingOpW(const MachineInstr &MI,
                               const MachineRegisterInfo &MRI) {
  uint64_t TSFlags = MI.getDesc().TSFlags;

  // Instructions that can be determined from opcode are marked in tablegen.
  if (TSFlags & RISCVII::IsSignExtendingOpWMask)
    return true;

  // Special cases that require checking operands.
  switch (MI.getOpcode()) {
  // shifting right sufficiently makes the value 32-bit sign-extended
  case RISCV::SRAI:
    return MI.getOperand(2).getImm() >= 32;
  // A logical shift by exactly 32 can still leave bit 31 set with zeros above
  // it, so strictly more than 32 is required here.
  case RISCV::SRLI:
    return MI.getOperand(2).getImm() > 32;
  // The LI pattern ADDI rd, X0, imm is sign extended.
  case RISCV::ADDI:
    return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0;
  // An ANDI with an 11 bit immediate will zero bits 63:11.
  case RISCV::ANDI:
    return isUInt<11>(MI.getOperand(2).getImm());
  // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
  case RISCV::ORI:
    return !isUInt<11>(MI.getOperand(2).getImm());
  // A bseti with X0 is sign extended if the immediate is less than 31.
  case RISCV::BSETI:
    return MI.getOperand(2).getImm() < 31 &&
           MI.getOperand(1).getReg() == RISCV::X0;
  // Copying from X0 produces zero.
  case RISCV::COPY:
    return MI.getOperand(1).getReg() == RISCV::X0;
  case RISCV::PseudoAtomicLoadNand32:
    return true;
  case RISCV::PseudoVMV_X_S_MF8:
  case RISCV::PseudoVMV_X_S_MF4:
  case RISCV::PseudoVMV_X_S_MF2:
  case RISCV::PseudoVMV_X_S_M1:
  case RISCV::PseudoVMV_X_S_M2:
  case RISCV::PseudoVMV_X_S_M4:
  case RISCV::PseudoVMV_X_S_M8: {
    // vmv.x.s has at least 33 sign bits if log2(sew) <= 5.
    int64_t Log2SEW = MI.getOperand(2).getImm();
    assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW");
    return Log2SEW <= 5;
  }
  }

  return false;
}

// Returns true if the value in \p SrcReg is known to be sign extended from 32
// bits. The defining instruction is examined, and for instructions that merely
// propagate sign extension the definitions of the relevant inputs are examined
// recursively via a worklist.
//
// Instructions that are not sign extending as written, but whose users all
// read only bits 31:0, are recorded in \p FixableDef: they become sign
// extending once the caller replaces them with their W form (see getWOp).
// \p FixableDef may already contain entries when this function returns false,
// so its contents are only meaningful when the result is true.
static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
                            const MachineRegisterInfo &MRI,
                            SmallPtrSetImpl<MachineInstr *> &FixableDef) {

  SmallPtrSet<const MachineInstr *, 4> Visited;
  SmallVector<MachineInstr *, 4> Worklist;

  // Queues the instruction defining SrcReg. Returns false (making the caller
  // give up) for non-virtual registers, registers without a def, and defs
  // whose operand 0 is not SrcReg.
  auto AddRegDefToWorkList = [&](Register SrcReg) {
    if (!SrcReg.isVirtual())
      return false;
    MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
    if (!SrcMI)
      return false;
    // Code assumes the register is operand 0.
    // TODO: Maybe the worklist should store register?
    if (!SrcMI->getOperand(0).isReg() ||
        SrcMI->getOperand(0).getReg() != SrcReg)
      return false;
    // Add SrcMI to the worklist.
    Worklist.push_back(SrcMI);
    return true;
  };

  if (!AddRegDefToWorkList(SrcReg))
    return false;

  while (!Worklist.empty()) {
    MachineInstr *MI = Worklist.pop_back_val();

    // If we already visited this instruction, we don't need to check it again.
    if (!Visited.insert(MI).second)
      continue;

    // If this is a sign extending operation we don't need to look any further.
    if (isSignExtendingOpW(*MI, MRI))
      continue;

    // Is this an instruction that propagates sign extend?
    switch (MI->getOpcode()) {
    default:
      // Unknown opcode, give up.
      return false;
    case RISCV::COPY: {
      const MachineFunction *MF = MI->getMF();
      const RISCVMachineFunctionInfo *RVFI =
          MF->getInfo<RISCVMachineFunctionInfo>();

      // If this is the entry block and the register is livein, see if we know
      // it is sign extended.
      if (MI->getParent() == &MF->front()) {
        Register VReg = MI->getOperand(0).getReg();
        if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg))
          continue;
      }

      Register CopySrcReg = MI->getOperand(1).getReg();
      if (CopySrcReg == RISCV::X10) {
        // For a method return value, we check the ZExt/SExt flags in attribute.
        // We assume the following code sequence for method call.
        // PseudoCALL @bar, ...
        // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
        // %0:gpr = COPY $x10
        //
        // We use the PseudoCall to look up the IR function being called to find
        // its return attributes.
        const MachineBasicBlock *MBB = MI->getParent();
        auto II = MI->getIterator();
        if (II == MBB->instr_begin() ||
            (--II)->getOpcode() != RISCV::ADJCALLSTACKUP)
          return false;

        // NOTE(review): this second decrement assumes ADJCALLSTACKUP is never
        // the first instruction of its block - confirm.
        const MachineInstr &CallMI = *(--II);
        if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal())
          return false;

        auto *CalleeFn =
            dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal());
        if (!CalleeFn)
          return false;

        auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType());
        if (!IntTy)
          return false;

        // A signext return of at most 32 bits, or a zeroext return of strictly
        // fewer than 32 bits, is accepted as sign extended. Any other
        // combination falls through to AddRegDefToWorkList below, which
        // rejects the non-virtual X10.
        const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs();
        unsigned BitWidth = IntTy->getBitWidth();
        if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) ||
            (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt)))
          continue;
      }

      if (!AddRegDefToWorkList(CopySrcReg))
        return false;

      break;
    }

    // For these, we just need to check if the 1st operand is sign extended.
    case RISCV::BCLRI:
    case RISCV::BINVI:
    case RISCV::BSETI:
      // Bit indices of 31 or above are rejected.
      if (MI->getOperand(2).getImm() >= 31)
        return false;
      [[fallthrough]];
    case RISCV::REM:
    case RISCV::ANDI:
    case RISCV::ORI:
    case RISCV::XORI:
      // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
      // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
      // Logical operations use a sign extended 12-bit immediate.
      if (!AddRegDefToWorkList(MI->getOperand(1).getReg()))
        return false;

      break;
    case RISCV::PseudoCCADDW:
    case RISCV::PseudoCCADDIW:
    case RISCV::PseudoCCSUBW:
    case RISCV::PseudoCCSLLW:
    case RISCV::PseudoCCSRLW:
    case RISCV::PseudoCCSRAW:
    case RISCV::PseudoCCSLLIW:
    case RISCV::PseudoCCSRLIW:
    case RISCV::PseudoCCSRAIW:
      // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only
      // need to check if operand 4 is sign extended.
      if (!AddRegDefToWorkList(MI->getOperand(4).getReg()))
        return false;
      break;
    case RISCV::REMU:
    case RISCV::AND:
    case RISCV::OR:
    case RISCV::XOR:
    case RISCV::ANDN:
    case RISCV::ORN:
    case RISCV::XNOR:
    case RISCV::MAX:
    case RISCV::MAXU:
    case RISCV::MIN:
    case RISCV::MINU:
    case RISCV::PseudoCCMOVGPR:
    case RISCV::PseudoCCAND:
    case RISCV::PseudoCCOR:
    case RISCV::PseudoCCXOR:
    case RISCV::PHI: {
      // If all incoming values are sign-extended, the output of AND, OR, XOR,
      // MIN, MAX, or PHI is also sign-extended.

      // The input registers for PHI are operand 1, 3, ...
      // The input registers for PseudoCCMOVGPR are 4 and 5.
      // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6.
      // The input registers for others are operand 1 and 2.
      // B is the first operand index to check, E is one past the last, and D
      // is the stride between checked operands.
      unsigned B = 1, E = 3, D = 1;
      switch (MI->getOpcode()) {
      case RISCV::PHI:
        E = MI->getNumOperands();
        D = 2;
        break;
      case RISCV::PseudoCCMOVGPR:
        B = 4;
        E = 6;
        break;
      case RISCV::PseudoCCAND:
      case RISCV::PseudoCCOR:
      case RISCV::PseudoCCXOR:
        B = 4;
        E = 7;
        break;
      }

      for (unsigned I = B; I != E; I += D) {
        if (!MI->getOperand(I).isReg())
          return false;

        if (!AddRegDefToWorkList(MI->getOperand(I).getReg()))
          return false;
      }

      break;
    }

    case RISCV::CZERO_EQZ:
    case RISCV::CZERO_NEZ:
    case RISCV::VT_MASKC:
    case RISCV::VT_MASKCN:
      // Instructions return zero or operand 1. Result is sign extended if
      // operand 1 is sign extended.
      if (!AddRegDefToWorkList(MI->getOperand(1).getReg()))
        return false;
      break;

    // With these opcodes, we can "fix" them with the W-version
    // if we know all users of the result only rely on bits 31:0
    case RISCV::SLLI:
      // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
      if (MI->getOperand(2).getImm() >= 32)
        return false;
      [[fallthrough]];
    case RISCV::ADDI:
    case RISCV::ADD:
    case RISCV::LD:
    case RISCV::LWU:
    case RISCV::MUL:
    case RISCV::SUB:
      if (hasAllWUsers(*MI, ST, MRI)) {
        FixableDef.insert(MI);
        break;
      }
      return false;
    }
  }

  // If we get here, then every node we visited produces a sign extended value
  // or propagated sign extended values. So the result must be sign extended.
  return true;
}

// Maps an opcode recorded as "fixable" by isSignExtendedW to the W-form opcode
// that replaces it. Both LD and LWU map to LW. Any other opcode is a
// programming error.
static unsigned getWOp(unsigned Opcode) {
  switch (Opcode) {
  case RISCV::ADDI:
    return RISCV::ADDIW;
  case RISCV::ADD:
    return RISCV::ADDW;
  case RISCV::LD:
  case RISCV::LWU:
    return RISCV::LW;
  case RISCV::MUL:
    return RISCV::MULW;
  case RISCV::SLLI:
    return RISCV::SLLIW;
  case RISCV::SUB:
    return RISCV::SUBW;
  default:
    llvm_unreachable("Unexpected opcode for replacement with W variant");
  }
}

// Removes every sext.w whose result is only read through its low 32 bits or
// whose source is already sign extended, replacing all uses of its result with
// its source register. Definitions that isSignExtendedW reported as fixable
// are first converted to their W form. Returns true if anything was removed;
// does nothing when -riscv-disable-sextw-removal is set.
bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
                                        const RISCVInstrInfo &TII,
                                        const RISCVSubtarget &ST,
                                        MachineRegisterInfo &MRI) {
  if (DisableSExtWRemoval)
    return false;

  bool MadeChange = false;
  for (MachineBasicBlock &MBB : MF) {
    // MI may be erased below, hence the early-increment range.
    for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
      // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
      if (!RISCV::isSEXT_W(MI))
        continue;

      Register SrcReg = MI.getOperand(1).getReg();

      SmallPtrSet<MachineInstr *, 4> FixableDefs;

      // If all users only use the lower bits, this sext.w is redundant.
      // Or if all definitions reaching MI sign-extend their output,
      // then sext.w is redundant.
      // Note that isSignExtendedW is only evaluated (and FixableDefs only
      // populated) when hasAllWUsers returned false.
      if (!hasAllWUsers(MI, ST, MRI) &&
          !isSignExtendedW(SrcReg, ST, MRI, FixableDefs))
        continue;

      // SrcReg is about to take over DstReg's uses, so it must be
      // constrainable to DstReg's register class.
      Register DstReg = MI.getOperand(0).getReg();
      if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
        continue;

      // Convert Fixable instructions to their W versions.
      for (MachineInstr *Fixable : FixableDefs) {
        LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
        Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode())));
        // NOTE(review): the wrap/exact flags are presumably dropped because
        // they were established for the non-W operation - confirm.
        Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap);
        Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap);
        Fixable->clearFlag(MachineInstr::MIFlag::IsExact);
        LLVM_DEBUG(dbgs() << " with " << *Fixable);
        ++NumTransformedToWInstrs;
      }

      LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
      MRI.replaceRegWith(DstReg, SrcReg);
      // SrcReg gained the former uses of DstReg, so existing kill flags on it
      // are cleared.
      MRI.clearKillFlags(SrcReg);
      MI.eraseFromParent();
      ++NumRemovedSExtW;
      MadeChange = true;
    }
  }

  return MadeChange;
}

// Rewrites ADDW, ADDIW, MULW and SLLIW to their non-W counterparts when every
// user only reads the low 32 bits of the result (see the file header for why
// each of these is handled). Returns true if anything was rewritten; does
// nothing when -riscv-disable-strip-w-suffix is set.
bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
                                     const RISCVInstrInfo &TII,
                                     const RISCVSubtarget &ST,
                                     MachineRegisterInfo &MRI) {
  if (DisableStripWSuffix)
    return false;

  bool MadeChange = false;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      // Opc is the non-W replacement opcode; instructions that are not one of
      // the four handled W forms are skipped.
      unsigned Opc;
      switch (MI.getOpcode()) {
      default:
        continue;
      case RISCV::ADDW: Opc = RISCV::ADD; break;
      case RISCV::ADDIW: Opc = RISCV::ADDI; break;
      case RISCV::MULW: Opc = RISCV::MUL; break;
      case RISCV::SLLIW: Opc = RISCV::SLLI; break;
      }

      if (hasAllWUsers(MI, ST, MRI)) {
        MI.setDesc(TII.get(Opc));
        MadeChange = true;
      }
    }
  }

  return MadeChange;
}

// Pass entry point. Skipped functions and non-64-bit subtargets are left
// untouched; otherwise sext.w removal runs first, followed by W-suffix
// stripping. Returns true if either step changed the function.
bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()))
    return false;

  MachineRegisterInfo &MRI = MF.getRegInfo();
  const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
  const RISCVInstrInfo &TII = *ST.getInstrInfo();

  if (!ST.is64Bit())
    return false;

  bool MadeChange = false;
  MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
  MadeChange |= stripWSuffixes(MF, TII, ST, MRI);

  return MadeChange;
}