1 //===-- X86FastTileConfig.cpp - Fast Tile Register Configure---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to configure the shapes of AMX physical tile registers.
/// AMX tile registers need to be configured before use. The ldtilecfg
/// instruction is inserted before the fast register allocation pass, but at
/// that point the shape of each physical tile register is not yet known
/// because register allocation has not been done. This pass runs after the
/// register allocation pass. It collects the shape information of each
/// physical tile register and stores it in the stack slot that has been
/// allocated for loading the tile configuration.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "X86.h"
21 #include "X86InstrBuilder.h"
22 #include "X86MachineFunctionInfo.h"
23 #include "X86RegisterInfo.h"
24 #include "X86Subtarget.h"
25 #include "llvm/CodeGen/MachineFrameInfo.h"
26 #include "llvm/CodeGen/MachineFunctionPass.h"
27 #include "llvm/CodeGen/MachineInstr.h"
28 #include "llvm/CodeGen/MachineRegisterInfo.h"
29 #include "llvm/CodeGen/Passes.h"
30 #include "llvm/CodeGen/TargetInstrInfo.h"
31 #include "llvm/CodeGen/TargetRegisterInfo.h"
32 #include "llvm/InitializePasses.h"
33 
34 using namespace llvm;
35 
36 #define DEBUG_TYPE "fasttileconfig"
37 
38 namespace {
39 
40 class X86FastTileConfig : public MachineFunctionPass {
41   // context
42   MachineFunction *MF = nullptr;
43   const TargetInstrInfo *TII = nullptr;
44   MachineRegisterInfo *MRI = nullptr;
45   const TargetRegisterInfo *TRI = nullptr;
46   X86MachineFunctionInfo *X86FI = nullptr;
47 
48   bool configBasicBlock(MachineBasicBlock &MBB);
49 
50 public:
51   X86FastTileConfig() : MachineFunctionPass(ID) {}
52 
53   /// Return the pass name.
54   StringRef getPassName() const override {
55     return "Fast Tile Register Configure";
56   }
57 
58   void getAnalysisUsage(AnalysisUsage &AU) const override {
59     AU.setPreservesAll();
60     MachineFunctionPass::getAnalysisUsage(AU);
61   }
62 
63   /// Perform register allocation.
64   bool runOnMachineFunction(MachineFunction &MFunc) override;
65 
66   MachineFunctionProperties getRequiredProperties() const override {
67     return MachineFunctionProperties().set(
68         MachineFunctionProperties::Property::NoPHIs);
69   }
70 
71   static char ID;
72 };
73 
74 } // end anonymous namespace
75 
76 char X86FastTileConfig::ID = 0;
77 
78 INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE,
79                       "Fast Tile Register Configure", false, false)
80 INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE,
81                     "Fast Tile Register Configure", false, false)
82 
83 static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {
84   // There is no phi instruction after register allocation.
85   assert(MI.isPHI() == false);
86   // The instruction must have 3 operands: tile def, row, col.
87   // It should be AMX pseudo instruction that have shape operand.
88   if (MI.isDebugInstr() || MI.isCopy() || MI.getNumOperands() < 3 ||
89       !MI.isPseudo())
90     return false;
91   MachineOperand &MO = MI.getOperand(0);
92 
93   if (MO.isReg()) {
94     Register Reg = MO.getReg();
95     // FIXME it may be used after Greedy RA and the physical
96     // register is not rewritten yet.
97     if (Reg.isVirtual() &&
98         MRI->getRegClass(Reg)->getID() == X86::TILERegClassID)
99       return true;
100     if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
101       return true;
102   }
103 
104   return false;
105 }
106 
107 // PreTileConfig should configure the tile registers based on basic
108 // block.
109 bool X86FastTileConfig::configBasicBlock(MachineBasicBlock &MBB) {
110   bool Change = false;
111   SmallVector<std::pair<unsigned, ShapeT>, 6> ShapeInfos;
112   for (MachineInstr &MI : reverse(MBB)) {
113     if (!isTileDef(MRI, MI) && MI.getOpcode() != X86::PLDTILECFGV)
114       continue;
115     // AMX instructions that define tile register.
116     if (MI.getOpcode() != X86::PLDTILECFGV) {
117       MachineOperand &Row = MI.getOperand(1);
118       MachineOperand &Col = MI.getOperand(2);
119       unsigned TMMIdx = MI.getOperand(0).getReg() - X86::TMM0;
120       ShapeInfos.push_back({TMMIdx, ShapeT(&Row, &Col)});
121     } else { // PLDTILECFGV
122       // Rewrite the shape information to memory. Stack slot should have
123       // been initialized to zero in pre config.
124       int SS = MI.getOperand(0).getIndex(); // tile config stack slot.
125       for (auto &ShapeInfo : ShapeInfos) {
126         DebugLoc DL;
127         unsigned TMMIdx = ShapeInfo.first;
128         Register RowReg = ShapeInfo.second.getRow()->getReg();
129         Register ColReg = ShapeInfo.second.getCol()->getReg();
130         // Here is the data format for the tile config.
131         // 0      palette
132         // 1      start_row
133         // 2-15   reserved, must be zero
134         // 16-17  tile0.colsb Tile 0 bytes per row.
135         // 18-19  tile1.colsb Tile 1 bytes per row.
136         // 20-21  tile2.colsb Tile 2 bytes per row.
137         // ... (sequence continues)
138         // 30-31  tile7.colsb Tile 7 bytes per row.
139         // 32-47  reserved, must be zero
140         // 48     tile0.rows Tile 0 rows.
141         // 49     tile1.rows Tile 1 rows.
142         // 50     tile2.rows Tile 2 rows.
143         // ... (sequence continues)
144         // 55     tile7.rows Tile 7 rows.
145         // 56-63  reserved, must be zero
146         int RowOffset = 48 + TMMIdx;
147         int ColOffset = 16 + TMMIdx * 2;
148 
149         Register SubRowReg = TRI->getSubReg(RowReg, X86::sub_8bit);
150         BuildMI(MBB, MI, DL, TII->get(X86::IMPLICIT_DEF), SubRowReg);
151         MachineInstrBuilder StoreRow =
152             BuildMI(MBB, MI, DL, TII->get(X86::MOV8mr));
153         addFrameReference(StoreRow, SS, RowOffset).addReg(SubRowReg);
154 
155         MachineInstrBuilder StoreCol =
156             BuildMI(MBB, MI, DL, TII->get(X86::MOV16mr));
157         addFrameReference(StoreCol, SS, ColOffset).addReg(ColReg);
158       }
159       ShapeInfos.clear();
160       Change = true;
161     }
162   }
163 
164   if (Change)
165     X86FI->setHasVirtualTileReg(true);
166 
167   return Change;
168 }
169 
170 bool X86FastTileConfig::runOnMachineFunction(MachineFunction &MFunc) {
171   MF = &MFunc;
172   MRI = &MFunc.getRegInfo();
173   const TargetSubtargetInfo *ST = &MFunc.getSubtarget<X86Subtarget>();
174   TRI = ST->getRegisterInfo();
175   TII = MFunc.getSubtarget().getInstrInfo();
176   X86FI = MFunc.getInfo<X86MachineFunctionInfo>();
177   bool Change = false;
178 
179   // Loop over all of the basic blocks, eliminating virtual register references
180   for (MachineBasicBlock &MBB : MFunc)
181     Change |= configBasicBlock(MBB);
182 
183   return Change;
184 }
185 
186 FunctionPass *llvm::createX86FastTileConfigPass() {
187   return new X86FastTileConfig();
188 }
189