1 //===- RISCVFoldMasks.cpp - MI Vector Pseudo Mask Peepholes ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 // 9 // This pass performs various peephole optimisations that fold masks into vector 10 // pseudo instructions after instruction selection. 11 // 12 // Currently it converts 13 // PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew 14 // -> 15 // PseudoVMV_V_V %false, %true, %vl, %sew 16 // 17 //===---------------------------------------------------------------------===// 18 19 #include "RISCV.h" 20 #include "RISCVISelDAGToDAG.h" 21 #include "RISCVSubtarget.h" 22 #include "llvm/CodeGen/MachineFunctionPass.h" 23 #include "llvm/CodeGen/MachineRegisterInfo.h" 24 #include "llvm/CodeGen/TargetInstrInfo.h" 25 #include "llvm/CodeGen/TargetRegisterInfo.h" 26 27 using namespace llvm; 28 29 #define DEBUG_TYPE "riscv-fold-masks" 30 31 namespace { 32 33 class RISCVFoldMasks : public MachineFunctionPass { 34 public: 35 static char ID; 36 const TargetInstrInfo *TII; 37 MachineRegisterInfo *MRI; 38 const TargetRegisterInfo *TRI; 39 RISCVFoldMasks() : MachineFunctionPass(ID) {} 40 41 bool runOnMachineFunction(MachineFunction &MF) override; 42 MachineFunctionProperties getRequiredProperties() const override { 43 return MachineFunctionProperties().set( 44 MachineFunctionProperties::Property::IsSSA); 45 } 46 47 StringRef getPassName() const override { return "RISC-V Fold Masks"; } 48 49 private: 50 bool convertToUnmasked(MachineInstr &MI, MachineInstr *MaskDef); 51 bool convertVMergeToVMv(MachineInstr &MI, MachineInstr *MaskDef); 52 53 bool isAllOnesMask(MachineInstr *MaskDef); 54 }; 55 56 } // namespace 57 58 char RISCVFoldMasks::ID = 0; 59 60 INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false) 61 62 bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) { 63 if (!MaskDef) 64 return false; 65 assert(MaskDef->isCopy() && MaskDef->getOperand(0).getReg() == RISCV::V0); 66 Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI); 67 if (!SrcReg.isVirtual()) 68 return false; 69 MaskDef = MRI->getVRegDef(SrcReg); 70 if (!MaskDef) 71 return false; 72 73 // TODO: Check that the VMSET is the expected bitwidth? The pseudo has 74 // undefined behaviour if it's the wrong bitwidth, so we could choose to 75 // assume that it's all-ones? Same applies to its VL. 76 switch (MaskDef->getOpcode()) { 77 case RISCV::PseudoVMSET_M_B1: 78 case RISCV::PseudoVMSET_M_B2: 79 case RISCV::PseudoVMSET_M_B4: 80 case RISCV::PseudoVMSET_M_B8: 81 case RISCV::PseudoVMSET_M_B16: 82 case RISCV::PseudoVMSET_M_B32: 83 case RISCV::PseudoVMSET_M_B64: 84 return true; 85 default: 86 return false; 87 } 88 } 89 90 // Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to 91 // (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET. 92 bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI, MachineInstr *V0Def) { 93 #define CASE_VMERGE_TO_VMV(lmul) \ 94 case RISCV::PseudoVMERGE_VVM_##lmul: \ 95 NewOpc = RISCV::PseudoVMV_V_V_##lmul; \ 96 break; 97 unsigned NewOpc; 98 switch (MI.getOpcode()) { 99 default: 100 return false; 101 CASE_VMERGE_TO_VMV(MF8) 102 CASE_VMERGE_TO_VMV(MF4) 103 CASE_VMERGE_TO_VMV(MF2) 104 CASE_VMERGE_TO_VMV(M1) 105 CASE_VMERGE_TO_VMV(M2) 106 CASE_VMERGE_TO_VMV(M4) 107 CASE_VMERGE_TO_VMV(M8) 108 } 109 110 Register MergeReg = MI.getOperand(1).getReg(); 111 Register FalseReg = MI.getOperand(2).getReg(); 112 // Check merge == false (or merge == undef) 113 if (MergeReg != RISCV::NoRegister && TRI->lookThruCopyLike(MergeReg, MRI) != 114 TRI->lookThruCopyLike(FalseReg, MRI)) 115 return false; 116 117 assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0); 118 if (!isAllOnesMask(V0Def)) 119 return false; 120 121 MI.setDesc(TII->get(NewOpc)); 122 MI.removeOperand(1); // Merge operand 123 MI.tieOperands(0, 1); // Tie false to dest 124 MI.removeOperand(3); // Mask operand 125 MI.addOperand( 126 MachineOperand::CreateImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED)); 127 128 // vmv.v.v doesn't have a mask operand, so we may be able to inflate the 129 // register class for the destination and merge operands e.g. VRNoV0 -> VR 130 MRI->recomputeRegClass(MI.getOperand(0).getReg()); 131 MRI->recomputeRegClass(MI.getOperand(1).getReg()); 132 return true; 133 } 134 135 bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI, 136 MachineInstr *MaskDef) { 137 const RISCV::RISCVMaskedPseudoInfo *I = 138 RISCV::getMaskedPseudoInfo(MI.getOpcode()); 139 if (!I) 140 return false; 141 142 if (!isAllOnesMask(MaskDef)) 143 return false; 144 145 // There are two classes of pseudos in the table - compares and 146 // everything else. See the comment on RISCVMaskedPseudo for details. 147 const unsigned Opc = I->UnmaskedPseudo; 148 const MCInstrDesc &MCID = TII->get(Opc); 149 const bool HasPolicyOp = RISCVII::hasVecPolicyOp(MCID.TSFlags); 150 const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MCID); 151 #ifndef NDEBUG 152 const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode()); 153 assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) == 154 RISCVII::hasVecPolicyOp(MCID.TSFlags) && 155 "Masked and unmasked pseudos are inconsistent"); 156 assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure"); 157 #endif 158 (void)HasPolicyOp; 159 160 MI.setDesc(MCID); 161 162 // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs? 163 unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs(); 164 MI.removeOperand(MaskOpIdx); 165 166 // The unmasked pseudo will no longer be constrained to the vrnov0 reg class, 167 // so try and relax it to vr. 168 MRI->recomputeRegClass(MI.getOperand(0).getReg()); 169 unsigned PassthruOpIdx = MI.getNumExplicitDefs(); 170 if (HasPassthru) { 171 if (MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister) 172 MRI->recomputeRegClass(MI.getOperand(PassthruOpIdx).getReg()); 173 } else 174 MI.removeOperand(PassthruOpIdx); 175 176 return true; 177 } 178 179 bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) { 180 if (skipFunction(MF.getFunction())) 181 return false; 182 183 // Skip if the vector extension is not enabled. 184 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 185 if (!ST.hasVInstructions()) 186 return false; 187 188 TII = ST.getInstrInfo(); 189 MRI = &MF.getRegInfo(); 190 TRI = MRI->getTargetRegisterInfo(); 191 192 bool Changed = false; 193 194 // Masked pseudos coming out of isel will have their mask operand in the form: 195 // 196 // $v0:vr = COPY %mask:vr 197 // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr 198 // 199 // Because $v0 isn't in SSA, keep track of it so we can check the mask operand 200 // on each pseudo. 201 MachineInstr *CurrentV0Def; 202 for (MachineBasicBlock &MBB : MF) { 203 CurrentV0Def = nullptr; 204 for (MachineInstr &MI : MBB) { 205 Changed |= convertToUnmasked(MI, CurrentV0Def); 206 Changed |= convertVMergeToVMv(MI, CurrentV0Def); 207 208 if (MI.definesRegister(RISCV::V0, TRI)) 209 CurrentV0Def = &MI; 210 } 211 } 212 213 return Changed; 214 } 215 216 FunctionPass *llvm::createRISCVFoldMasksPass() { return new RISCVFoldMasks(); } 217