1 //===- RISCVFoldMasks.cpp - MI Vector Pseudo Mask Peepholes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 //
9 // This pass performs various peephole optimisations that fold masks into vector
10 // pseudo instructions after instruction selection.
11 //
12 // Currently it converts
13 // PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
14 // ->
15 // PseudoVMV_V_V %false, %true, %vl, %sew
16 //
17 //===---------------------------------------------------------------------===//
18 
19 #include "RISCV.h"
20 #include "RISCVISelDAGToDAG.h"
21 #include "RISCVSubtarget.h"
22 #include "llvm/CodeGen/MachineFunctionPass.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/TargetInstrInfo.h"
25 #include "llvm/CodeGen/TargetRegisterInfo.h"
26 
27 using namespace llvm;
28 
29 #define DEBUG_TYPE "riscv-fold-masks"
30 
31 namespace {
32 
33 class RISCVFoldMasks : public MachineFunctionPass {
34 public:
35   static char ID;
36   const TargetInstrInfo *TII;
37   MachineRegisterInfo *MRI;
38   const TargetRegisterInfo *TRI;
RISCVFoldMasks()39   RISCVFoldMasks() : MachineFunctionPass(ID) {}
40 
41   bool runOnMachineFunction(MachineFunction &MF) override;
getRequiredProperties() const42   MachineFunctionProperties getRequiredProperties() const override {
43     return MachineFunctionProperties().set(
44         MachineFunctionProperties::Property::IsSSA);
45   }
46 
getPassName() const47   StringRef getPassName() const override { return "RISC-V Fold Masks"; }
48 
49 private:
50   bool convertToUnmasked(MachineInstr &MI, MachineInstr *MaskDef);
51   bool convertVMergeToVMv(MachineInstr &MI, MachineInstr *MaskDef);
52 
53   bool isAllOnesMask(MachineInstr *MaskDef);
54 };
55 
56 } // namespace
57 
58 char RISCVFoldMasks::ID = 0;
59 
60 INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false)
61 
isAllOnesMask(MachineInstr * MaskDef)62 bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) {
63   if (!MaskDef)
64     return false;
65   assert(MaskDef->isCopy() && MaskDef->getOperand(0).getReg() == RISCV::V0);
66   Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
67   if (!SrcReg.isVirtual())
68     return false;
69   MaskDef = MRI->getVRegDef(SrcReg);
70   if (!MaskDef)
71     return false;
72 
73   // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
74   // undefined behaviour if it's the wrong bitwidth, so we could choose to
75   // assume that it's all-ones? Same applies to its VL.
76   switch (MaskDef->getOpcode()) {
77   case RISCV::PseudoVMSET_M_B1:
78   case RISCV::PseudoVMSET_M_B2:
79   case RISCV::PseudoVMSET_M_B4:
80   case RISCV::PseudoVMSET_M_B8:
81   case RISCV::PseudoVMSET_M_B16:
82   case RISCV::PseudoVMSET_M_B32:
83   case RISCV::PseudoVMSET_M_B64:
84     return true;
85   default:
86     return false;
87   }
88 }
89 
90 // Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
91 // (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
convertVMergeToVMv(MachineInstr & MI,MachineInstr * V0Def)92 bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI, MachineInstr *V0Def) {
93 #define CASE_VMERGE_TO_VMV(lmul)                                               \
94   case RISCV::PseudoVMERGE_VVM_##lmul:                                         \
95     NewOpc = RISCV::PseudoVMV_V_V_##lmul;                                      \
96     break;
97   unsigned NewOpc;
98   switch (MI.getOpcode()) {
99   default:
100     return false;
101     CASE_VMERGE_TO_VMV(MF8)
102     CASE_VMERGE_TO_VMV(MF4)
103     CASE_VMERGE_TO_VMV(MF2)
104     CASE_VMERGE_TO_VMV(M1)
105     CASE_VMERGE_TO_VMV(M2)
106     CASE_VMERGE_TO_VMV(M4)
107     CASE_VMERGE_TO_VMV(M8)
108   }
109 
110   Register MergeReg = MI.getOperand(1).getReg();
111   Register FalseReg = MI.getOperand(2).getReg();
112   // Check merge == false (or merge == undef)
113   if (MergeReg != RISCV::NoRegister && TRI->lookThruCopyLike(MergeReg, MRI) !=
114                                            TRI->lookThruCopyLike(FalseReg, MRI))
115     return false;
116 
117   assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
118   if (!isAllOnesMask(V0Def))
119     return false;
120 
121   MI.setDesc(TII->get(NewOpc));
122   MI.removeOperand(1);  // Merge operand
123   MI.tieOperands(0, 1); // Tie false to dest
124   MI.removeOperand(3);  // Mask operand
125   MI.addOperand(
126       MachineOperand::CreateImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED));
127 
128   // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
129   // register class for the destination and merge operands e.g. VRNoV0 -> VR
130   MRI->recomputeRegClass(MI.getOperand(0).getReg());
131   MRI->recomputeRegClass(MI.getOperand(1).getReg());
132   return true;
133 }
134 
convertToUnmasked(MachineInstr & MI,MachineInstr * MaskDef)135 bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI,
136                                        MachineInstr *MaskDef) {
137   const RISCV::RISCVMaskedPseudoInfo *I =
138       RISCV::getMaskedPseudoInfo(MI.getOpcode());
139   if (!I)
140     return false;
141 
142   if (!isAllOnesMask(MaskDef))
143     return false;
144 
145   // There are two classes of pseudos in the table - compares and
146   // everything else.  See the comment on RISCVMaskedPseudo for details.
147   const unsigned Opc = I->UnmaskedPseudo;
148   const MCInstrDesc &MCID = TII->get(Opc);
149   const bool HasPolicyOp = RISCVII::hasVecPolicyOp(MCID.TSFlags);
150   const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MCID);
151 #ifndef NDEBUG
152   const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode());
153   assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) ==
154              RISCVII::hasVecPolicyOp(MCID.TSFlags) &&
155          "Masked and unmasked pseudos are inconsistent");
156   assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure");
157 #endif
158   (void)HasPolicyOp;
159 
160   MI.setDesc(MCID);
161 
162   // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
163   unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs();
164   MI.removeOperand(MaskOpIdx);
165 
166   // The unmasked pseudo will no longer be constrained to the vrnov0 reg class,
167   // so try and relax it to vr.
168   MRI->recomputeRegClass(MI.getOperand(0).getReg());
169   unsigned PassthruOpIdx = MI.getNumExplicitDefs();
170   if (HasPassthru) {
171     if (MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister)
172       MRI->recomputeRegClass(MI.getOperand(PassthruOpIdx).getReg());
173   } else
174     MI.removeOperand(PassthruOpIdx);
175 
176   return true;
177 }
178 
runOnMachineFunction(MachineFunction & MF)179 bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
180   if (skipFunction(MF.getFunction()))
181     return false;
182 
183   // Skip if the vector extension is not enabled.
184   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
185   if (!ST.hasVInstructions())
186     return false;
187 
188   TII = ST.getInstrInfo();
189   MRI = &MF.getRegInfo();
190   TRI = MRI->getTargetRegisterInfo();
191 
192   bool Changed = false;
193 
194   // Masked pseudos coming out of isel will have their mask operand in the form:
195   //
196   // $v0:vr = COPY %mask:vr
197   // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
198   //
199   // Because $v0 isn't in SSA, keep track of it so we can check the mask operand
200   // on each pseudo.
201   MachineInstr *CurrentV0Def;
202   for (MachineBasicBlock &MBB : MF) {
203     CurrentV0Def = nullptr;
204     for (MachineInstr &MI : MBB) {
205       Changed |= convertToUnmasked(MI, CurrentV0Def);
206       Changed |= convertVMergeToVMv(MI, CurrentV0Def);
207 
208       if (MI.definesRegister(RISCV::V0, TRI))
209         CurrentV0Def = &MI;
210     }
211   }
212 
213   return Changed;
214 }
215 
createRISCVFoldMasksPass()216 FunctionPass *llvm::createRISCVFoldMasksPass() { return new RISCVFoldMasks(); }
217