1 //===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
// This file implements the machine function pass that inserts reads/writes of
// the CSRs used by RISC-V instructions.
10 //
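// Currently the pass implements naive insertion of a write to vxrm before an
// RVV fixed-point instruction, and of a save-and-set of frm before (plus a
// restore after) an instruction whose frm rounding-mode operand is not DYN.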
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "MCTargetDesc/RISCVBaseInfo.h"
17 #include "RISCV.h"
18 #include "RISCVSubtarget.h"
19 #include "llvm/CodeGen/MachineFunctionPass.h"
20 using namespace llvm;
21 
22 #define DEBUG_TYPE "riscv-insert-read-write-csr"
23 #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass"
24 
25 namespace {
26 
27 class RISCVInsertReadWriteCSR : public MachineFunctionPass {
28   const TargetInstrInfo *TII;
29 
30 public:
31   static char ID;
32 
33   RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {
34     initializeRISCVInsertReadWriteCSRPass(*PassRegistry::getPassRegistry());
35   }
36 
37   bool runOnMachineFunction(MachineFunction &MF) override;
38 
39   void getAnalysisUsage(AnalysisUsage &AU) const override {
40     AU.setPreservesCFG();
41     MachineFunctionPass::getAnalysisUsage(AU);
42   }
43 
44   StringRef getPassName() const override {
45     return RISCV_INSERT_READ_WRITE_CSR_NAME;
46   }
47 
48 private:
49   bool emitWriteRoundingMode(MachineBasicBlock &MBB);
50 };
51 
52 } // end anonymous namespace
53 
54 char RISCVInsertReadWriteCSR::ID = 0;
55 
56 INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE,
57                 RISCV_INSERT_READ_WRITE_CSR_NAME, false, false)
58 
59 // Returns the index to the rounding mode immediate value if any, otherwise the
60 // function will return None.
61 static std::optional<unsigned> getRoundModeIdx(const MachineInstr &MI) {
62   uint64_t TSFlags = MI.getDesc().TSFlags;
63   if (!RISCVII::hasRoundModeOp(TSFlags))
64     return std::nullopt;
65 
66   // The operand order
67   // -------------------------------------
68   // | n-1 (if any)   | n-2  | n-3 | n-4 |
69   // | policy         | sew  | vl  | rm  |
70   // -------------------------------------
71   return MI.getNumExplicitOperands() - RISCVII::hasVecPolicyOp(TSFlags) - 3;
72 }
73 
74 // This function inserts a write to vxrm when encountering an RVV fixed-point
75 // instruction.
76 bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) {
77   bool Changed = false;
78   for (MachineInstr &MI : MBB) {
79     if (auto RoundModeIdx = getRoundModeIdx(MI)) {
80       if (RISCVII::usesVXRM(MI.getDesc().TSFlags)) {
81         unsigned VXRMImm = MI.getOperand(*RoundModeIdx).getImm();
82 
83         Changed = true;
84 
85         BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteVXRMImm))
86             .addImm(VXRMImm);
87         MI.addOperand(MachineOperand::CreateReg(RISCV::VXRM, /*IsDef*/ false,
88                                                 /*IsImp*/ true));
89       } else { // FRM
90         unsigned FRMImm = MI.getOperand(*RoundModeIdx).getImm();
91 
92         // The value is a hint to this pass to not alter the frm value.
93         if (FRMImm == RISCVFPRndMode::DYN)
94           continue;
95 
96         Changed = true;
97 
98         // Save
99         MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
100         Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
101         BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm),
102                 SavedFRM)
103             .addImm(FRMImm);
104         MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
105                                                 /*IsImp*/ true));
106         // Restore
107         MachineInstrBuilder MIB =
108             BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
109                 .addReg(SavedFRM);
110         MBB.insertAfter(MI, MIB);
111       }
112     }
113   }
114   return Changed;
115 }
116 
117 bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) {
118   // Skip if the vector extension is not enabled.
119   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
120   if (!ST.hasVInstructions())
121     return false;
122 
123   TII = ST.getInstrInfo();
124 
125   bool Changed = false;
126 
127   for (MachineBasicBlock &MBB : MF)
128     Changed |= emitWriteRoundingMode(MBB);
129 
130   return Changed;
131 }
132 
133 FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() {
134   return new RISCVInsertReadWriteCSR();
135 }
136