1 //===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
// This file implements a machine function pass that inserts reads/writes of the
// CSRs used by RISC-V instructions.
//
// Currently the pass implements:
// - Saving and writing frm before an RVV floating-point instruction with a
//   static rounding mode, and restoring the saved value after it.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "MCTargetDesc/RISCVBaseInfo.h"
18 #include "RISCV.h"
19 #include "RISCVSubtarget.h"
20 #include "llvm/CodeGen/MachineFunctionPass.h"
21 using namespace llvm;
22 
23 #define DEBUG_TYPE "riscv-insert-read-write-csr"
24 #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass"
25 
26 namespace {
27 
28 class RISCVInsertReadWriteCSR : public MachineFunctionPass {
29   const TargetInstrInfo *TII;
30 
31 public:
32   static char ID;
33 
RISCVInsertReadWriteCSR()34   RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {}
35 
36   bool runOnMachineFunction(MachineFunction &MF) override;
37 
getAnalysisUsage(AnalysisUsage & AU) const38   void getAnalysisUsage(AnalysisUsage &AU) const override {
39     AU.setPreservesCFG();
40     MachineFunctionPass::getAnalysisUsage(AU);
41   }
42 
getPassName() const43   StringRef getPassName() const override {
44     return RISCV_INSERT_READ_WRITE_CSR_NAME;
45   }
46 
47 private:
48   bool emitWriteRoundingMode(MachineBasicBlock &MBB);
49 };
50 
51 } // end anonymous namespace
52 
53 char RISCVInsertReadWriteCSR::ID = 0;
54 
INITIALIZE_PASS(RISCVInsertReadWriteCSR,DEBUG_TYPE,RISCV_INSERT_READ_WRITE_CSR_NAME,false,false)55 INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE,
56                 RISCV_INSERT_READ_WRITE_CSR_NAME, false, false)
57 
58 // This function also swaps frm and restores it when encountering an RVV
59 // floating point instruction with a static rounding mode.
60 bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) {
61   bool Changed = false;
62   for (MachineInstr &MI : MBB) {
63     int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
64     if (FRMIdx < 0)
65       continue;
66 
67     unsigned FRMImm = MI.getOperand(FRMIdx).getImm();
68 
69     // The value is a hint to this pass to not alter the frm value.
70     if (FRMImm == RISCVFPRndMode::DYN)
71       continue;
72 
73     Changed = true;
74 
75     // Save
76     MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
77     Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
78     BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm),
79             SavedFRM)
80         .addImm(FRMImm);
81     MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
82                                             /*IsImp*/ true));
83     // Restore
84     MachineInstrBuilder MIB =
85         BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
86             .addReg(SavedFRM);
87     MBB.insertAfter(MI, MIB);
88   }
89   return Changed;
90 }
91 
runOnMachineFunction(MachineFunction & MF)92 bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) {
93   // Skip if the vector extension is not enabled.
94   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
95   if (!ST.hasVInstructions())
96     return false;
97 
98   TII = ST.getInstrInfo();
99 
100   bool Changed = false;
101 
102   for (MachineBasicBlock &MBB : MF)
103     Changed |= emitWriteRoundingMode(MBB);
104 
105   return Changed;
106 }
107 
createRISCVInsertReadWriteCSRPass()108 FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() {
109   return new RISCVInsertReadWriteCSR();
110 }
111